1
- use std:: iter:: Peekable ;
2
-
3
1
use pgt_text_size:: { TextRange , TextSize } ;
2
+ use std:: iter:: Peekable ;
4
3
5
4
pub ( crate ) struct TokenNavigator {
6
5
tokens : Peekable < std:: vec:: IntoIter < WordWithIndex > > ,
@@ -101,73 +100,139 @@ impl WordWithIndex {
101
100
}
102
101
}
103
102
104
- /// Note: A policy name within quotation marks will be considered a single word.
105
- pub ( crate ) fn sql_to_words ( sql : & str ) -> Result < Vec < WordWithIndex > , String > {
106
- let mut words = vec ! [ ] ;
107
-
108
- let mut start_of_word: Option < usize > = None ;
109
- let mut current_word = String :: new ( ) ;
110
- let mut in_quotation_marks = false ;
111
-
112
- for ( current_position, current_char) in sql. char_indices ( ) {
113
- if ( current_char. is_ascii_whitespace ( ) || current_char == ';' )
114
- && !current_word. is_empty ( )
115
- && start_of_word. is_some ( )
116
- && !in_quotation_marks
117
- {
118
- words. push ( WordWithIndex {
119
- word : current_word,
120
- start : start_of_word. unwrap ( ) ,
121
- end : current_position,
122
- } ) ;
123
-
124
- current_word = String :: new ( ) ;
125
- start_of_word = None ;
126
- } else if ( current_char. is_ascii_whitespace ( ) || current_char == ';' )
127
- && current_word. is_empty ( )
128
- {
129
- // do nothing
130
- } else if current_char == '"' && start_of_word. is_none ( ) {
131
- in_quotation_marks = true ;
132
- current_word. push ( current_char) ;
133
- start_of_word = Some ( current_position) ;
134
- } else if current_char == '"' && start_of_word. is_some ( ) {
135
- current_word. push ( current_char) ;
136
- in_quotation_marks = false ;
137
- } else if start_of_word. is_some ( ) {
138
- current_word. push ( current_char)
103
+ pub ( crate ) struct SubStatementParser {
104
+ start_of_word : Option < usize > ,
105
+ current_word : String ,
106
+ in_quotation_marks : bool ,
107
+ is_fn_call : bool ,
108
+ words : Vec < WordWithIndex > ,
109
+ }
110
+
111
+ impl SubStatementParser {
112
+ pub ( crate ) fn parse ( sql : & str ) -> Result < Vec < WordWithIndex > , String > {
113
+ let mut parser = SubStatementParser {
114
+ start_of_word : None ,
115
+ current_word : String :: new ( ) ,
116
+ in_quotation_marks : false ,
117
+ is_fn_call : false ,
118
+ words : vec ! [ ] ,
119
+ } ;
120
+
121
+ parser. collect_words ( sql) ;
122
+
123
+ if parser. in_quotation_marks {
124
+ Err ( "String was not closed properly." . into ( ) )
139
125
} else {
140
- start_of_word = Some ( current_position) ;
141
- current_word. push ( current_char) ;
126
+ Ok ( parser. words )
142
127
}
143
128
}
144
129
145
- if let Some ( start_of_word) = start_of_word {
146
- if !current_word. is_empty ( ) {
147
- words. push ( WordWithIndex {
148
- word : current_word,
149
- start : start_of_word,
150
- end : sql. len ( ) ,
151
- } ) ;
130
+ pub fn collect_words ( & mut self , sql : & str ) {
131
+ for ( pos, c) in sql. char_indices ( ) {
132
+ match c {
133
+ '"' => {
134
+ if !self . has_started_word ( ) {
135
+ self . in_quotation_marks = true ;
136
+ self . add_char ( c) ;
137
+ self . start_word ( pos) ;
138
+ } else {
139
+ self . in_quotation_marks = false ;
140
+ self . add_char ( c) ;
141
+ }
142
+ }
143
+
144
+ '(' => {
145
+ if !self . has_started_word ( ) {
146
+ self . push_char_as_word ( c, pos) ;
147
+ } else {
148
+ self . add_char ( c) ;
149
+ self . is_fn_call = true ;
150
+ }
151
+ }
152
+
153
+ ')' => {
154
+ if self . is_fn_call {
155
+ self . add_char ( c) ;
156
+ self . is_fn_call = false ;
157
+ } else {
158
+ if self . has_started_word ( ) {
159
+ self . push_word ( pos) ;
160
+ }
161
+ self . push_char_as_word ( c, pos) ;
162
+ }
163
+ }
164
+
165
+ _ => {
166
+ if c. is_ascii_whitespace ( ) || c == ';' {
167
+ if self . in_quotation_marks {
168
+ self . add_char ( c) ;
169
+ } else if !self . is_empty ( ) && self . has_started_word ( ) {
170
+ self . push_word ( pos) ;
171
+ }
172
+ } else if self . has_started_word ( ) {
173
+ self . add_char ( c) ;
174
+ } else {
175
+ self . start_word ( pos) ;
176
+ self . add_char ( c)
177
+ }
178
+ }
179
+ }
180
+ }
181
+
182
+ if self . has_started_word ( ) && !self . is_empty ( ) {
183
+ self . push_word ( sql. len ( ) )
152
184
}
153
185
}
154
186
155
- if in_quotation_marks {
156
- Err ( "String was not closed properly." . into ( ) )
157
- } else {
158
- Ok ( words)
187
+ fn is_empty ( & self ) -> bool {
188
+ self . current_word . is_empty ( )
189
+ }
190
+
191
+ fn add_char ( & mut self , c : char ) {
192
+ self . current_word . push ( c)
193
+ }
194
+
195
+ fn start_word ( & mut self , pos : usize ) {
196
+ self . start_of_word = Some ( pos) ;
197
+ }
198
+
199
+ fn has_started_word ( & self ) -> bool {
200
+ self . start_of_word . is_some ( )
201
+ }
202
+
203
+ fn push_char_as_word ( & mut self , c : char , pos : usize ) {
204
+ self . words . push ( WordWithIndex {
205
+ word : String :: from ( c) ,
206
+ start : pos,
207
+ end : pos + 1 ,
208
+ } ) ;
209
+ }
210
+
211
+ fn push_word ( & mut self , current_position : usize ) {
212
+ self . words . push ( WordWithIndex {
213
+ word : self . current_word . clone ( ) ,
214
+ start : self . start_of_word . unwrap ( ) ,
215
+ end : current_position,
216
+ } ) ;
217
+ self . current_word = String :: new ( ) ;
218
+ self . start_of_word = None ;
159
219
}
160
220
}
161
221
222
+ /// Note: A policy name within quotation marks will be considered a single word.
223
+ pub ( crate ) fn sql_to_words ( sql : & str ) -> Result < Vec < WordWithIndex > , String > {
224
+ SubStatementParser :: parse ( sql)
225
+ }
226
+
162
227
#[ cfg( test) ]
163
228
mod tests {
164
- use crate :: context:: base_parser:: { WordWithIndex , sql_to_words} ;
229
+ use crate :: context:: base_parser:: { SubStatementParser , WordWithIndex , sql_to_words} ;
165
230
166
231
#[ test]
167
232
fn determines_positions_correctly ( ) {
168
- let query = "\n create policy \" my cool pol\" \n \t on auth.users\n \t as permissive\n \t for select\n \t \t to public\n \t \t using (true );" . to_string ( ) ;
233
+ let query = "\n create policy \" my cool pol\" \n \t on auth.users\n \t as permissive\n \t for select\n \t \t to public\n \t \t using (auth.uid() );" . to_string ( ) ;
169
234
170
- let words = sql_to_words ( query. as_str ( ) ) . unwrap ( ) ;
235
+ let words = SubStatementParser :: parse ( query. as_str ( ) ) . unwrap ( ) ;
171
236
172
237
assert_eq ! ( words[ 0 ] , to_word( "create" , 1 , 7 ) ) ;
173
238
assert_eq ! ( words[ 1 ] , to_word( "policy" , 8 , 14 ) ) ;
@@ -181,7 +246,9 @@ mod tests {
181
246
assert_eq ! ( words[ 9 ] , to_word( "to" , 73 , 75 ) ) ;
182
247
assert_eq ! ( words[ 10 ] , to_word( "public" , 78 , 84 ) ) ;
183
248
assert_eq ! ( words[ 11 ] , to_word( "using" , 87 , 92 ) ) ;
184
- assert_eq ! ( words[ 12 ] , to_word( "(true)" , 93 , 99 ) ) ;
249
+ assert_eq ! ( words[ 12 ] , to_word( "(" , 93 , 94 ) ) ;
250
+ assert_eq ! ( words[ 13 ] , to_word( "auth.uid()" , 94 , 104 ) ) ;
251
+ assert_eq ! ( words[ 14 ] , to_word( ")" , 104 , 105 ) ) ;
185
252
}
186
253
187
254
#[ test]
0 commit comments