From 154f8ea21684f48c90c009c18d635215e0bb3e07 Mon Sep 17 00:00:00 2001 From: himkt Date: Mon, 15 Dec 2025 09:25:26 +0900 Subject: [PATCH 1/2] refactor(tree): generic version of segment trees --- src/tree.rs | 350 +++++++++++++++++++++++++++++----------------------- 1 file changed, 193 insertions(+), 157 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index dfc4d66..bdcb087 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -1,15 +1,45 @@ -#[derive(Debug, Clone)] -pub struct SegmentTree { - data: Vec, - mode: Mode, +use std::ops::Add; + +pub trait Zero { + fn zero() -> Self; } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Mode { - RangeUpdate(Op), - RangeGet(Op), +pub trait Bounded { + fn min_value() -> Self; + fn max_value() -> Self; } +macro_rules! impl_monoid_traits { + ($t:ty) => { + impl Zero for $t { + fn zero() -> Self { + 0 + } + } + impl Bounded for $t { + fn min_value() -> Self { + <$t>::MIN + } + fn max_value() -> Self { + <$t>::MAX + } + } + }; +} + +impl_monoid_traits!(i8); +impl_monoid_traits!(i16); +impl_monoid_traits!(i32); +impl_monoid_traits!(i64); +impl_monoid_traits!(i128); +impl_monoid_traits!(u8); +impl_monoid_traits!(u16); +impl_monoid_traits!(u32); +impl_monoid_traits!(u64); +impl_monoid_traits!(u128); +impl_monoid_traits!(isize); +impl_monoid_traits!(usize); + #[derive(Debug, Clone, PartialEq, Eq)] pub enum Op { Max, @@ -17,194 +47,200 @@ pub enum Op { Add, } -// Segment tree implementation. All operations are 0-origin. -// Note that a half-open interval [l, r) is used as a range representation. -impl SegmentTree { - const SEQ_LEN: usize = 1 << 20; - const MAX: i64 = 1_000_000_000_000; - const MIN: i64 = -1_000_000_000_000; +#[derive(Debug, Clone)] +pub struct RangeGetTree { + data: Vec, + op: Op, +} - pub fn new(mode: Mode) -> Self { - let default = match &mode { - Mode::RangeGet(op) => SegmentTree::default(op), - Mode::RangeUpdate(op) => SegmentTree::default(op), - }; +impl RangeGetTree +where + T: Clone + Add + Ord + Zero + Bounded, +{ + const SEQ_LEN: usize = 1 << 20; + pub fn new(op: Op) -> Self { + let identity = Self::identity_for(&op); Self { - data: vec![default; 2 * SegmentTree::SEQ_LEN], - mode, + data: vec![identity; 2 * Self::SEQ_LEN], + op, } } - /// Return an appropriate default value for the given operation. - pub fn default(op: &Op) -> i64 { + fn identity_for(op: &Op) -> T { match op { - Op::Add => 0, - Op::Max => SegmentTree::MIN, - Op::Min => SegmentTree::MAX, + Op::Add => T::zero(), + Op::Max => T::min_value(), + Op::Min => T::max_value(), } } - /// Get an i-th element of from the tree. - pub fn get_one(&mut self, mut index: usize) -> i64 { - index += SegmentTree::SEQ_LEN; - let mut ret = 0; - - if let Mode::RangeUpdate(op) = &self.mode { - let operator = match op { - Op::Add => |ret: &mut i64, v: i64| *ret += v, - _ => panic!("Operator {:?} is not supported.", op), - }; - - operator(&mut ret, self.data[index]); - while index > 0 { - index /= 2; - operator(&mut ret, self.data[index]); - } - } else { - panic!("Mode {:?} is not supported.", &self.mode); - } - - ret + pub fn identity(&self) -> T { + Self::identity_for(&self.op) } - pub fn get_range(&self, l: usize, r: usize) -> i64 { - if let Mode::RangeGet(op) = &self.mode { - self.range_query_recursive(op, l, r, 0, SegmentTree::SEQ_LEN, 1) - } else { - panic!("Mode {:?} is not supported.", &self.mode); + fn operate(&self, a: T, b: T) -> T { + match &self.op { + Op::Add => a + b, + Op::Max => a.max(b), + Op::Min => a.min(b), } } - /// Update an i-th element to `value`. - pub fn update_one(&mut self, mut index: usize, value: i64) { - index += SegmentTree::SEQ_LEN; - - if let Mode::RangeGet(op) = &self.mode { - match op { - Op::Add => self.data[index] += value, - _ => self.data[index] = value, - } - while index > 0 { - index /= 2; - let lv = self.data[index * 2]; - let rv = self.data[index * 2 + 1]; - match op { - Op::Add => self.data[index] = lv + rv, - Op::Max => self.data[index] = lv.max(rv), - Op::Min => self.data[index] = lv.min(rv), - }; - } - } else { - panic!("Mode {:?} is not supported.", &self.mode); - } + pub fn get(&self, l: usize, r: usize) -> T { + self.range_query_recursive(l, r, 0, Self::SEQ_LEN, 1) } - /// Add `value` to the range `[l, r)`. - pub fn update_range(&mut self, mut l: usize, mut r: usize, value: i64) { - if let Mode::RangeUpdate(op) = &self.mode { - let operate_and_assign_one = match op { - Op::Add => |ret: &mut i64, v: i64| *ret += v, - _ => panic!(), - }; - - l += SegmentTree::SEQ_LEN; - r += SegmentTree::SEQ_LEN; - - while l < r { - if l % 2 == 1 { - operate_and_assign_one(&mut self.data[l], value); - l += 1; - } - l /= 2; - - if r % 2 == 1 { - operate_and_assign_one(&mut self.data[r - 1], value); - r -= 1; - } - r /= 2; - } - } else { - panic!("Mode {:?} is not supported.", &self.mode); + pub fn update(&mut self, mut index: usize, value: T) { + index += Self::SEQ_LEN; + self.data[index] = self.operate(self.data[index].clone(), value); + while index > 1 { + index /= 2; + let lv = self.data[index * 2].clone(); + let rv = self.data[index * 2 + 1].clone(); + self.data[index] = self.operate(lv, rv); } } - fn range_query_recursive(&self, op: &Op, ql: usize, qr: usize, sl: usize, sr: usize, pos: usize) -> i64 { + fn range_query_recursive(&self, ql: usize, qr: usize, sl: usize, sr: usize, pos: usize) -> T { if qr <= sl || sr <= ql { - return SegmentTree::default(op); + return self.identity(); } if ql <= sl && sr <= qr { - return self.data[pos]; + return self.data[pos].clone(); } - let sm = (sl + sr) / 2; - let lv = self.range_query_recursive(op, ql, qr, sl, sm, pos * 2); - let rv = self.range_query_recursive(op, ql, qr, sm, sr, pos * 2 + 1); - match op { - Op::Add => lv + rv, - Op::Max => lv.max(rv), - Op::Min => lv.min(rv), - } + let lv = self.range_query_recursive(ql, qr, sl, sm, pos * 2); + let rv = self.range_query_recursive(ql, qr, sm, sr, pos * 2 + 1); + self.operate(lv, rv) } } #[cfg(test)] -mod test_segment_tree { - use crate::tree::Mode; - use crate::tree::Op; - use crate::tree::SegmentTree; - - #[test] - fn it_works_raq() { - let mut raq = SegmentTree::new(Mode::RangeUpdate(Op::Add)); - raq.update_range(1, 2, 1); - raq.update_range(2, 4, 2); - raq.update_range(3, 4, 3); - assert_eq!(raq.get_one(0), 0); - assert_eq!(raq.get_one(2), 2); - assert_eq!(raq.get_one(3), 5); - } +mod test_range_get_tree { + use super::{Op, RangeGetTree}; #[test] fn it_works_rsq() { - let mut rsq = SegmentTree::new(Mode::RangeGet(Op::Add)); - rsq.update_one(0, 3); - rsq.update_one(2, 3); - rsq.update_one(3, 1); - rsq.update_one(4, 4); - assert_eq!(rsq.get_range(0, 3), 6); - assert_eq!(rsq.get_range(1, 3), 3); - assert_eq!(rsq.get_range(2, 4), 4); - assert_eq!(rsq.get_range(3, 4), 1); - assert_eq!(rsq.get_range(1, 6), 8); - assert_eq!(rsq.get_range(0, 0), 0); + let mut rsq: RangeGetTree = RangeGetTree::new(Op::Add); + rsq.update(0, 3); + rsq.update(2, 3); + rsq.update(3, 1); + rsq.update(4, 4); + assert_eq!(rsq.get(0, 3), 6); + assert_eq!(rsq.get(1, 3), 3); + assert_eq!(rsq.get(2, 4), 4); + assert_eq!(rsq.get(3, 4), 1); + assert_eq!(rsq.get(1, 6), 8); + assert_eq!(rsq.get(0, 0), rsq.identity()); } #[test] fn it_works_rmaxq() { - let mut rmaxq = SegmentTree::new(Mode::RangeGet(Op::Max)); - rmaxq.update_one(0, 10); - rmaxq.update_one(2, 101); - rmaxq.update_one(100, 1001); - assert_eq!(rmaxq.get_range(0, 1), 10); - assert_eq!(rmaxq.get_range(0, 2), 10); - assert_eq!(rmaxq.get_range(0, 3), 101); - assert_eq!(rmaxq.get_range(0, 100100), 1001); - assert_eq!(rmaxq.get_range(101, 1000), SegmentTree::MIN); - assert_eq!(rmaxq.get_range(0, 0), SegmentTree::MIN); + let mut rmaxq: RangeGetTree = RangeGetTree::new(Op::Max); + rmaxq.update(0, 10); + rmaxq.update(2, 101); + rmaxq.update(100, 1001); + assert_eq!(rmaxq.get(0, 1), 10); + assert_eq!(rmaxq.get(0, 2), 10); + assert_eq!(rmaxq.get(0, 3), 101); + assert_eq!(rmaxq.get(0, 100100), 1001); + assert_eq!(rmaxq.get(0, 0), rmaxq.identity()); } #[test] fn it_works_rminq() { - let mut rminq = SegmentTree::new(Mode::RangeGet(Op::Min)); - rminq.update_one(0, 101); - rminq.update_one(2, 10); - rminq.update_one(100, 1001); - assert_eq!(rminq.get_range(0, 1), 101); - assert_eq!(rminq.get_range(0, 2), 101); - assert_eq!(rminq.get_range(0, 3), 10); - assert_eq!(rminq.get_range(0, 100100), 10); - assert_eq!(rminq.get_range(101, 1000), SegmentTree::MAX); - assert_eq!(rminq.get_range(0, 0), SegmentTree::MAX); + let mut rminq: RangeGetTree = RangeGetTree::new(Op::Min); + rminq.update(0, 101); + rminq.update(2, 10); + rminq.update(100, 1001); + assert_eq!(rminq.get(0, 1), 101); + assert_eq!(rminq.get(0, 2), 101); + assert_eq!(rminq.get(0, 3), 10); + assert_eq!(rminq.get(0, 100100), 10); + assert_eq!(rminq.get(0, 0), rminq.identity()); + } +} + +#[derive(Debug, Clone)] +pub struct RangeUpdateTree { + data: Vec, + op: Op, +} + +impl RangeUpdateTree +where + T: Clone + Add + Ord + Zero + Bounded, +{ + const SEQ_LEN: usize = 1 << 20; + + pub fn new(op: Op) -> Self { + let identity = Self::identity_for(&op); + Self { + data: vec![identity; 2 * Self::SEQ_LEN], + op, + } + } + + fn identity_for(op: &Op) -> T { + match op { + Op::Add => T::zero(), + _ => panic!("Unsupported op for RangeUpdateTree: {:?}", op), + } + } + + pub fn identity(&self) -> T { + Self::identity_for(&self.op) + } + + fn operate(&self, a: T, b: T) -> T { + match &self.op { + Op::Add => a + b, + _ => panic!("Unsupported op for RangeUpdateTree: {:?}", &self.op), + } + } + + pub fn get(&self, mut index: usize) -> T { + index += Self::SEQ_LEN; + let mut ret = self.identity(); + ret = self.operate(ret, self.data[index].clone()); + while index > 1 { + index /= 2; + ret = self.operate(ret, self.data[index].clone()); + } + ret + } + + pub fn update(&mut self, mut l: usize, mut r: usize, value: T) { + l += Self::SEQ_LEN; + r += Self::SEQ_LEN; + while l < r { + if l % 2 == 1 { + self.data[l] = self.operate(self.data[l].clone(), value.clone()); + l += 1; + } + l /= 2; + if r % 2 == 1 { + self.data[r - 1] = self.operate(self.data[r - 1].clone(), value.clone()); + r -= 1; + } + r /= 2; + } + } +} + +#[cfg(test)] +mod test_range_update_tree { + use super::{Op, RangeUpdateTree}; + + #[test] + fn it_works_raq() { + let mut raq: RangeUpdateTree = RangeUpdateTree::new(Op::Add); + raq.update(1, 2, 1); + raq.update(2, 4, 2); + raq.update(3, 4, 3); + assert_eq!(raq.get(0), raq.identity()); + assert_eq!(raq.get(2), 2); + assert_eq!(raq.get(3), 5); } } From 87ed094af4d001c28357538f67ff533a8a3f8ee6 Mon Sep 17 00:00:00 2001 From: himkt Date: Mon, 15 Dec 2025 09:30:13 +0900 Subject: [PATCH 2/2] Update src/tree.rs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/tree.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tree.rs b/src/tree.rs index bdcb087..3ce68fd 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -202,8 +202,7 @@ where pub fn get(&self, mut index: usize) -> T { index += Self::SEQ_LEN; - let mut ret = self.identity(); - ret = self.operate(ret, self.data[index].clone()); + let mut ret = self.operate(self.identity(), self.data[index].clone()); while index > 1 { index /= 2; ret = self.operate(ret, self.data[index].clone());