1
1
use std:: fmt;
2
- use std:: ops:: { Bound , Range , RangeBounds } ;
2
+ use std:: ops:: { Bound , RangeBounds } ;
3
3
4
4
/// __注意⚠__ この実装は遅いので time limit の厳しい問題には代わりに ACL のセグメントツリーを使うこと。
5
5
///
6
6
/// セグメントツリーです。
7
7
#[ derive( Clone ) ]
8
8
pub struct SegmentTree < T , F > {
9
+ original_n : usize ,
9
10
n : usize ,
10
11
dat : Vec < T > ,
11
12
e : T ,
12
13
multiply : F ,
13
14
}
14
15
16
+ // https://hcpc-hokudai.github.io/archive/structure_segtree_001.pdf
15
17
impl < T , F > SegmentTree < T , F >
16
18
where
17
19
T : Clone ,
@@ -21,33 +23,43 @@ where
21
23
///
22
24
/// `multiply` は fold に使う二項演算です。
23
25
pub fn new ( n : usize , e : T , multiply : F ) -> Self {
26
+ let original_n = n;
24
27
let n = n. next_power_of_two ( ) ;
25
28
Self {
29
+ original_n,
26
30
n,
27
- dat : vec ! [ e. clone( ) ; n * 2 - 1 ] ,
31
+ dat : vec ! [ e. clone( ) ; n * 2 ] , // dat[0] is unused
28
32
e,
29
33
multiply,
30
34
}
31
35
}
32
36
33
37
/// 列の `i` 番目の要素を取得します。
34
38
pub fn get ( & self , i : usize ) -> & T {
35
- & self . dat [ i + self . n - 1 ]
39
+ assert ! ( i < self . original_n) ;
40
+ & self . dat [ i + self . n ]
36
41
}
37
42
38
43
/// 列の `i` 番目の要素を `x` で更新します。
39
- pub fn update ( & mut self , i : usize , x : T ) {
40
- let mut k = i + self . n - 1 ;
41
- self . dat [ k] = x;
42
- while k > 0 {
43
- k = ( k - 1 ) / 2 ;
44
- self . dat [ k] = ( self . multiply ) ( & self . dat [ k * 2 + 1 ] , & self . dat [ k * 2 + 2 ] ) ;
44
+ pub fn set ( & mut self , i : usize , x : T ) {
45
+ self . update ( i, |_| x) ;
46
+ }
47
+
48
+ /// 列の `i` 番目の要素を `f` で更新します。
49
+ pub fn update < U > ( & mut self , i : usize , f : U )
50
+ where
51
+ U : FnOnce ( & T ) -> T ,
52
+ {
53
+ assert ! ( i < self . original_n) ;
54
+ let mut k = i + self . n ;
55
+ self . dat [ k] = f ( & self . dat [ k] ) ;
56
+ while k > 1 {
57
+ k >>= 1 ;
58
+ self . dat [ k] = ( self . multiply ) ( & self . dat [ k << 1 ] , & self . dat [ k << 1 | 1 ] ) ;
45
59
}
46
60
}
47
61
48
62
/// `range` が `l..r` として、`multiply(l番目の要素, multiply(..., multiply(r-2番目の要素, r-1番目の要素)))` の値を返します。
49
- ///
50
- /// 実際のアルゴリズムは、結合法則を使って `1 + (2 + (3 + 4))` ではなく `(1 + 2) + (3 + 4)` のように計算しています。
51
63
pub fn fold ( & self , range : impl RangeBounds < usize > ) -> T {
52
64
let start = match range. start_bound ( ) {
53
65
Bound :: Included ( & start) => start,
@@ -59,21 +71,31 @@ where
59
71
Bound :: Excluded ( & end) => end,
60
72
Bound :: Unbounded => self . n ,
61
73
} ;
62
- assert ! ( end <= self . n ) ;
63
- self . _fold ( & ( start..end ) , 0 , 0 .. self . n )
74
+ assert ! ( start <= end && end <= self . original_n ) ;
75
+ self . _fold ( start, end )
64
76
}
65
- fn _fold ( & self , range : & Range < usize > , i : usize , i_range : Range < usize > ) -> T {
66
- if range. end <= i_range. start || i_range. end <= range. start {
67
- return self . e . clone ( ) ;
68
- }
69
- if range. start <= i_range. start && i_range. end <= range. end {
70
- return self . dat [ i] . clone ( ) ;
77
+
78
+ fn _fold ( & self , mut l : usize , mut r : usize ) -> T {
79
+ let mut acc_l = self . e . clone ( ) ;
80
+ let mut acc_r = self . e . clone ( ) ;
81
+ l += self . n ;
82
+ r += self . n ;
83
+ while l < r {
84
+ if l & 1 == 1 {
85
+ // 右の子だったらいま足しておかないといけない
86
+ // 左の子だったら祖先のどれかで足されるのでよい
87
+ acc_l = ( self . multiply ) ( & acc_l, & self . dat [ l] ) ;
88
+ l += 1 ;
89
+ }
90
+ if r & 1 == 1 {
91
+ // r が exclusive であることに注意する
92
+ r -= 1 ;
93
+ acc_r = ( self . multiply ) ( & self . dat [ r] , & acc_r) ;
94
+ }
95
+ l >>= 1 ;
96
+ r >>= 1 ;
71
97
}
72
- let m = ( i_range. start + i_range. end ) / 2 ;
73
- ( self . multiply ) (
74
- & self . _fold ( range, i * 2 + 1 , i_range. start ..m) ,
75
- & self . _fold ( range, i * 2 + 2 , m..i_range. end ) ,
76
- )
98
+ ( self . multiply ) ( & acc_l, & acc_r)
77
99
}
78
100
}
79
101
82
104
T : fmt:: Debug ,
83
105
{
84
106
fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
85
- write ! ( f, "{:?}" , & self . dat[ ( self . n - 1 ) ..] )
107
+ write ! ( f, "{:?}" , & self . dat[ self . n..] )
86
108
}
87
109
}
88
110
@@ -95,7 +117,7 @@ mod tests {
95
117
let s = "abcdefgh" ;
96
118
let mut seg = SegmentTree :: new ( s. len ( ) , String :: new ( ) , |a, b| format ! ( "{a}{b}" ) ) ;
97
119
for ( i, c) in s. chars ( ) . enumerate ( ) {
98
- seg. update ( i, c. to_string ( ) ) ;
120
+ seg. set ( i, c. to_string ( ) ) ;
99
121
}
100
122
101
123
for i in 0 ..s. len ( ) {
@@ -117,7 +139,7 @@ mod tests {
117
139
fn single_element ( ) {
118
140
let mut seg = SegmentTree :: new ( 1 , 0 , |a, b| a + b) ;
119
141
assert_eq ! ( seg. get( 0 ) , & 0 ) ;
120
- seg. update ( 0 , 42 ) ;
142
+ seg. set ( 0 , 42 ) ;
121
143
assert_eq ! ( seg. get( 0 ) , & 42 ) ;
122
144
}
123
145
}
0 commit comments