Skip to content

Commit 4442f30

Browse files
authored
Merge pull request #152 from ia7ck/segtree-iter
[segtree] 非再帰の実装にする
2 parents 0017b58 + bec2542 commit 4442f30

File tree

2 files changed

+51
-29
lines changed

2 files changed

+51
-29
lines changed

algo/segment_tree/examples/point_set_range_composite.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ fn main() {
1616
(c * a % mo, (c * b % mo + d) % mo)
1717
});
1818
for (i, &(a, b)) in ab.iter().enumerate() {
19-
seg.update(i, (a, b));
19+
seg.set(i, (a, b));
2020
}
2121
for _ in 0..q {
2222
input! {
@@ -28,7 +28,7 @@ fn main() {
2828
c: u64,
2929
d: u64,
3030
}
31-
seg.update(p, (c, d));
31+
seg.set(p, (c, d));
3232
} else {
3333
input! {
3434
l: usize,

algo/segment_tree/src/lib.rs

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
use std::fmt;
2-
use std::ops::{Bound, Range, RangeBounds};
2+
use std::ops::{Bound, RangeBounds};
33

44
/// __注意⚠__ この実装は遅いので time limit の厳しい問題には代わりに ACL のセグメントツリーを使うこと。
55
///
66
/// セグメントツリーです。
77
#[derive(Clone)]
88
pub struct SegmentTree<T, F> {
9+
original_n: usize,
910
n: usize,
1011
dat: Vec<T>,
1112
e: T,
1213
multiply: F,
1314
}
1415

16+
// https://hcpc-hokudai.github.io/archive/structure_segtree_001.pdf
1517
impl<T, F> SegmentTree<T, F>
1618
where
1719
T: Clone,
@@ -21,33 +23,43 @@ where
2123
///
2224
/// `multiply` は fold に使う二項演算です。
2325
pub fn new(n: usize, e: T, multiply: F) -> Self {
26+
let original_n = n;
2427
let n = n.next_power_of_two();
2528
Self {
29+
original_n,
2630
n,
27-
dat: vec![e.clone(); n * 2 - 1],
31+
dat: vec![e.clone(); n * 2], // dat[0] is unused
2832
e,
2933
multiply,
3034
}
3135
}
3236

3337
/// 列の `i` 番目の要素を取得します。
3438
pub fn get(&self, i: usize) -> &T {
35-
&self.dat[i + self.n - 1]
39+
assert!(i < self.original_n);
40+
&self.dat[i + self.n]
3641
}
3742

3843
/// 列の `i` 番目の要素を `x` で更新します。
39-
pub fn update(&mut self, i: usize, x: T) {
40-
let mut k = i + self.n - 1;
41-
self.dat[k] = x;
42-
while k > 0 {
43-
k = (k - 1) / 2;
44-
self.dat[k] = (self.multiply)(&self.dat[k * 2 + 1], &self.dat[k * 2 + 2]);
44+
pub fn set(&mut self, i: usize, x: T) {
45+
self.update(i, |_| x);
46+
}
47+
48+
/// 列の `i` 番目の要素を `f` で更新します。
49+
pub fn update<U>(&mut self, i: usize, f: U)
50+
where
51+
U: FnOnce(&T) -> T,
52+
{
53+
assert!(i < self.original_n);
54+
let mut k = i + self.n;
55+
self.dat[k] = f(&self.dat[k]);
56+
while k > 1 {
57+
k >>= 1;
58+
self.dat[k] = (self.multiply)(&self.dat[k << 1], &self.dat[k << 1 | 1]);
4559
}
4660
}
4761

4862
/// `range` が `l..r` として、`multiply(l番目の要素, multiply(..., multiply(r-2番目の要素, r-1番目の要素)))` の値を返します。
49-
///
50-
/// 実際のアルゴリズムは、結合法則を使って `1 + (2 + (3 + 4))` ではなく `(1 + 2) + (3 + 4)` のように計算しています。
5163
pub fn fold(&self, range: impl RangeBounds<usize>) -> T {
5264
let start = match range.start_bound() {
5365
Bound::Included(&start) => start,
@@ -59,21 +71,31 @@ where
5971
Bound::Excluded(&end) => end,
6072
Bound::Unbounded => self.n,
6173
};
62-
assert!(end <= self.n);
63-
self._fold(&(start..end), 0, 0..self.n)
74+
assert!(start <= end && end <= self.original_n);
75+
self._fold(start, end)
6476
}
65-
fn _fold(&self, range: &Range<usize>, i: usize, i_range: Range<usize>) -> T {
66-
if range.end <= i_range.start || i_range.end <= range.start {
67-
return self.e.clone();
68-
}
69-
if range.start <= i_range.start && i_range.end <= range.end {
70-
return self.dat[i].clone();
77+
78+
fn _fold(&self, mut l: usize, mut r: usize) -> T {
79+
let mut acc_l = self.e.clone();
80+
let mut acc_r = self.e.clone();
81+
l += self.n;
82+
r += self.n;
83+
while l < r {
84+
if l & 1 == 1 {
85+
// 右の子だったらいま足しておかないといけない
86+
// 左の子だったら祖先のどれかで足されるのでよい
87+
acc_l = (self.multiply)(&acc_l, &self.dat[l]);
88+
l += 1;
89+
}
90+
if r & 1 == 1 {
91+
// r が exclusive であることに注意する
92+
r -= 1;
93+
acc_r = (self.multiply)(&self.dat[r], &acc_r);
94+
}
95+
l >>= 1;
96+
r >>= 1;
7197
}
72-
let m = (i_range.start + i_range.end) / 2;
73-
(self.multiply)(
74-
&self._fold(range, i * 2 + 1, i_range.start..m),
75-
&self._fold(range, i * 2 + 2, m..i_range.end),
76-
)
98+
(self.multiply)(&acc_l, &acc_r)
7799
}
78100
}
79101

@@ -82,7 +104,7 @@ where
82104
T: fmt::Debug,
83105
{
84106
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85-
write!(f, "{:?}", &self.dat[(self.n - 1)..])
107+
write!(f, "{:?}", &self.dat[self.n..])
86108
}
87109
}
88110

@@ -95,7 +117,7 @@ mod tests {
95117
let s = "abcdefgh";
96118
let mut seg = SegmentTree::new(s.len(), String::new(), |a, b| format!("{a}{b}"));
97119
for (i, c) in s.chars().enumerate() {
98-
seg.update(i, c.to_string());
120+
seg.set(i, c.to_string());
99121
}
100122

101123
for i in 0..s.len() {
@@ -117,7 +139,7 @@ mod tests {
117139
fn single_element() {
118140
let mut seg = SegmentTree::new(1, 0, |a, b| a + b);
119141
assert_eq!(seg.get(0), &0);
120-
seg.update(0, 42);
142+
seg.set(0, 42);
121143
assert_eq!(seg.get(0), &42);
122144
}
123145
}

0 commit comments

Comments
 (0)