From 83dafd1be519f7884fdeea25c2ab29f1b04ada39 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Wed, 14 Jan 2026 16:05:37 +0000 Subject: [PATCH] feat: add halve/decay to countmin sketch --- datasketches/src/countmin/sketch.rs | 48 +++++++++++++++++++++++++++++ datasketches/tests/countmin_test.rs | 48 +++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/datasketches/src/countmin/sketch.rs b/datasketches/src/countmin/sketch.rs index 4f8225b..fba8f47 100644 --- a/datasketches/src/countmin/sketch.rs +++ b/datasketches/src/countmin/sketch.rs @@ -180,6 +180,54 @@ impl CountMinSketch { } } + /// Divides every counter by two, truncating toward zero. + /// + /// Useful for exponential decay where counts represent recent activity. + /// + /// # Examples + /// + /// ```rust + /// # use datasketches::countmin::CountMinSketch; + /// let mut sketch = CountMinSketch::new(4, 128); + /// sketch.update_with_weight("apple", 3); + /// sketch.halve(); + /// assert!(sketch.estimate("apple") >= 1); + /// ``` + pub fn halve(&mut self) { + self.counts.iter_mut().for_each(|c| *c /= 2); + self.total_weight /= 2; + } + + /// Multiplies every counter by `decay` and truncates back into `i64`. + /// + /// Values are truncated toward zero after multiplication; choose `decay` in `(0, 1]`. + /// The total weight is scaled by the same factor to keep bounds consistent. + /// + /// # Examples + /// + /// ```rust + /// # use datasketches::countmin::CountMinSketch; + /// let mut sketch = CountMinSketch::new(4, 128); + /// sketch.update_with_weight("apple", 3); + /// sketch.decay(0.5); + /// assert!(sketch.estimate("apple") >= 1); + /// ``` + /// + /// # Panics + /// + /// Panics if `decay` is not finite or is outside `(0, 1]`. + pub fn decay(&mut self, decay: f64) { + assert!(decay.is_finite(), "decay must be finite"); + assert!( + decay > 0.0 && decay <= 1.0, + "decay must be within (0, 1]" + ); + self.counts + .iter_mut() + .for_each(|c| *c = (*c as f64 * decay) as i64); + self.total_weight = (self.total_weight as f64 * decay) as i64; + } + /// Returns the estimated frequency of the given item. /// /// # Examples diff --git a/datasketches/tests/countmin_test.rs b/datasketches/tests/countmin_test.rs index b4b3685..84590ba 100644 --- a/datasketches/tests/countmin_test.rs +++ b/datasketches/tests/countmin_test.rs @@ -68,6 +68,54 @@ fn test_negative_weights() { assert_eq!(sketch.total_weight(), 3); } +#[test] +fn test_halve() { + let buckets = CountMinSketch::suggest_num_buckets(0.01); + let hashes = CountMinSketch::suggest_num_hashes(0.9); + let mut sketch = CountMinSketch::new(hashes, buckets); + + for i in 0..1000usize { + for _ in 0..i { + sketch.update(i as u64); + } + } + + for i in 0..1000usize { + assert!(sketch.estimate(i as u64) >= i as i64); + } + + sketch.halve(); + + for i in 0..1000usize { + assert!(sketch.estimate(i as u64) >= (i as i64) / 2); + } +} + +#[test] +fn test_decay() { + let buckets = CountMinSketch::suggest_num_buckets(0.01); + let hashes = CountMinSketch::suggest_num_hashes(0.9); + let mut sketch = CountMinSketch::new(hashes, buckets); + + for i in 0..1000usize { + for _ in 0..i { + sketch.update(i as u64); + } + } + + for i in 0..1000usize { + assert!(sketch.estimate(i as u64) >= i as i64); + } + + const FACTOR: f64 = 0.5; + sketch.decay(FACTOR); + + for i in 0..1000usize { + let expected = ((i as f64) * FACTOR).floor() as i64; + assert!(sketch.estimate(i as u64) >= expected); + } +} + #[test] fn test_merge() { let mut left = CountMinSketch::new(3, 64);