Skip to content

Commit 224d7fd

Browse files
feat(completions): complete in WITH CHECK and USING clauses (#422)
1 parent 4cb12df commit 224d7fd

File tree

8 files changed

+504
-74
lines changed

8 files changed

+504
-74
lines changed

crates/pgt_completions/src/context/base_parser.rs

Lines changed: 121 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
use std::iter::Peekable;
2-
31
use pgt_text_size::{TextRange, TextSize};
2+
use std::iter::Peekable;
43

54
pub(crate) struct TokenNavigator {
65
tokens: Peekable<std::vec::IntoIter<WordWithIndex>>,
@@ -101,73 +100,139 @@ impl WordWithIndex {
101100
}
102101
}
103102

104-
/// Note: A policy name within quotation marks will be considered a single word.
105-
pub(crate) fn sql_to_words(sql: &str) -> Result<Vec<WordWithIndex>, String> {
106-
let mut words = vec![];
107-
108-
let mut start_of_word: Option<usize> = None;
109-
let mut current_word = String::new();
110-
let mut in_quotation_marks = false;
111-
112-
for (current_position, current_char) in sql.char_indices() {
113-
if (current_char.is_ascii_whitespace() || current_char == ';')
114-
&& !current_word.is_empty()
115-
&& start_of_word.is_some()
116-
&& !in_quotation_marks
117-
{
118-
words.push(WordWithIndex {
119-
word: current_word,
120-
start: start_of_word.unwrap(),
121-
end: current_position,
122-
});
123-
124-
current_word = String::new();
125-
start_of_word = None;
126-
} else if (current_char.is_ascii_whitespace() || current_char == ';')
127-
&& current_word.is_empty()
128-
{
129-
// do nothing
130-
} else if current_char == '"' && start_of_word.is_none() {
131-
in_quotation_marks = true;
132-
current_word.push(current_char);
133-
start_of_word = Some(current_position);
134-
} else if current_char == '"' && start_of_word.is_some() {
135-
current_word.push(current_char);
136-
in_quotation_marks = false;
137-
} else if start_of_word.is_some() {
138-
current_word.push(current_char)
103+
pub(crate) struct SubStatementParser {
104+
start_of_word: Option<usize>,
105+
current_word: String,
106+
in_quotation_marks: bool,
107+
is_fn_call: bool,
108+
words: Vec<WordWithIndex>,
109+
}
110+
111+
impl SubStatementParser {
112+
pub(crate) fn parse(sql: &str) -> Result<Vec<WordWithIndex>, String> {
113+
let mut parser = SubStatementParser {
114+
start_of_word: None,
115+
current_word: String::new(),
116+
in_quotation_marks: false,
117+
is_fn_call: false,
118+
words: vec![],
119+
};
120+
121+
parser.collect_words(sql);
122+
123+
if parser.in_quotation_marks {
124+
Err("String was not closed properly.".into())
139125
} else {
140-
start_of_word = Some(current_position);
141-
current_word.push(current_char);
126+
Ok(parser.words)
142127
}
143128
}
144129

145-
if let Some(start_of_word) = start_of_word {
146-
if !current_word.is_empty() {
147-
words.push(WordWithIndex {
148-
word: current_word,
149-
start: start_of_word,
150-
end: sql.len(),
151-
});
130+
pub fn collect_words(&mut self, sql: &str) {
131+
for (pos, c) in sql.char_indices() {
132+
match c {
133+
'"' => {
134+
if !self.has_started_word() {
135+
self.in_quotation_marks = true;
136+
self.add_char(c);
137+
self.start_word(pos);
138+
} else {
139+
self.in_quotation_marks = false;
140+
self.add_char(c);
141+
}
142+
}
143+
144+
'(' => {
145+
if !self.has_started_word() {
146+
self.push_char_as_word(c, pos);
147+
} else {
148+
self.add_char(c);
149+
self.is_fn_call = true;
150+
}
151+
}
152+
153+
')' => {
154+
if self.is_fn_call {
155+
self.add_char(c);
156+
self.is_fn_call = false;
157+
} else {
158+
if self.has_started_word() {
159+
self.push_word(pos);
160+
}
161+
self.push_char_as_word(c, pos);
162+
}
163+
}
164+
165+
_ => {
166+
if c.is_ascii_whitespace() || c == ';' {
167+
if self.in_quotation_marks {
168+
self.add_char(c);
169+
} else if !self.is_empty() && self.has_started_word() {
170+
self.push_word(pos);
171+
}
172+
} else if self.has_started_word() {
173+
self.add_char(c);
174+
} else {
175+
self.start_word(pos);
176+
self.add_char(c)
177+
}
178+
}
179+
}
180+
}
181+
182+
if self.has_started_word() && !self.is_empty() {
183+
self.push_word(sql.len())
152184
}
153185
}
154186

155-
if in_quotation_marks {
156-
Err("String was not closed properly.".into())
157-
} else {
158-
Ok(words)
187+
fn is_empty(&self) -> bool {
188+
self.current_word.is_empty()
189+
}
190+
191+
fn add_char(&mut self, c: char) {
192+
self.current_word.push(c)
193+
}
194+
195+
fn start_word(&mut self, pos: usize) {
196+
self.start_of_word = Some(pos);
197+
}
198+
199+
fn has_started_word(&self) -> bool {
200+
self.start_of_word.is_some()
201+
}
202+
203+
fn push_char_as_word(&mut self, c: char, pos: usize) {
204+
self.words.push(WordWithIndex {
205+
word: String::from(c),
206+
start: pos,
207+
end: pos + 1,
208+
});
209+
}
210+
211+
fn push_word(&mut self, current_position: usize) {
212+
self.words.push(WordWithIndex {
213+
word: self.current_word.clone(),
214+
start: self.start_of_word.unwrap(),
215+
end: current_position,
216+
});
217+
self.current_word = String::new();
218+
self.start_of_word = None;
159219
}
160220
}
161221

222+
/// Note: A policy name within quotation marks will be considered a single word.
223+
pub(crate) fn sql_to_words(sql: &str) -> Result<Vec<WordWithIndex>, String> {
224+
SubStatementParser::parse(sql)
225+
}
226+
162227
#[cfg(test)]
163228
mod tests {
164-
use crate::context::base_parser::{WordWithIndex, sql_to_words};
229+
use crate::context::base_parser::{SubStatementParser, WordWithIndex, sql_to_words};
165230

166231
#[test]
167232
fn determines_positions_correctly() {
168-
let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string();
233+
let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (auth.uid());".to_string();
169234

170-
let words = sql_to_words(query.as_str()).unwrap();
235+
let words = SubStatementParser::parse(query.as_str()).unwrap();
171236

172237
assert_eq!(words[0], to_word("create", 1, 7));
173238
assert_eq!(words[1], to_word("policy", 8, 14));
@@ -181,7 +246,9 @@ mod tests {
181246
assert_eq!(words[9], to_word("to", 73, 75));
182247
assert_eq!(words[10], to_word("public", 78, 84));
183248
assert_eq!(words[11], to_word("using", 87, 92));
184-
assert_eq!(words[12], to_word("(true)", 93, 99));
249+
assert_eq!(words[12], to_word("(", 93, 94));
250+
assert_eq!(words[13], to_word("auth.uid()", 94, 104));
251+
assert_eq!(words[14], to_word(")", 104, 105));
185252
}
186253

187254
#[test]

crates/pgt_completions/src/context/mod.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ pub enum WrappingClause<'a> {
4747
SetStatement,
4848
AlterRole,
4949
DropRole,
50+
51+
/// `PolicyCheck` refers to either the `WITH CHECK` or the `USING` clause
52+
/// in a policy statement.
53+
/// ```sql
54+
/// CREATE POLICY "my pol" ON PUBLIC.USERS
55+
/// FOR SELECT
56+
/// USING (...) -- this one!
57+
/// ```
58+
PolicyCheck,
5059
}
5160

5261
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
@@ -78,6 +87,7 @@ pub(crate) enum NodeUnderCursor<'a> {
7887
text: NodeText,
7988
range: TextRange,
8089
kind: String,
90+
previous_node_kind: Option<String>,
8191
},
8292
}
8393

@@ -222,6 +232,7 @@ impl<'a> CompletionContext<'a> {
222232
text: revoke_context.node_text.into(),
223233
range: revoke_context.node_range,
224234
kind: revoke_context.node_kind.clone(),
235+
previous_node_kind: None,
225236
});
226237

227238
if revoke_context.node_kind == "revoke_table" {
@@ -249,6 +260,7 @@ impl<'a> CompletionContext<'a> {
249260
text: grant_context.node_text.into(),
250261
range: grant_context.node_range,
251262
kind: grant_context.node_kind.clone(),
263+
previous_node_kind: None,
252264
});
253265

254266
if grant_context.node_kind == "grant_table" {
@@ -276,6 +288,7 @@ impl<'a> CompletionContext<'a> {
276288
text: policy_context.node_text.into(),
277289
range: policy_context.node_range,
278290
kind: policy_context.node_kind.clone(),
291+
previous_node_kind: Some(policy_context.previous_node_kind),
279292
});
280293

281294
if policy_context.node_kind == "policy_table" {
@@ -295,7 +308,13 @@ impl<'a> CompletionContext<'a> {
295308
}
296309
"policy_role" => Some(WrappingClause::ToRoleAssignment),
297310
"policy_table" => Some(WrappingClause::From),
298-
_ => None,
311+
_ => {
312+
if policy_context.in_check_or_using_clause {
313+
Some(WrappingClause::PolicyCheck)
314+
} else {
315+
None
316+
}
317+
}
299318
};
300319
}
301320

@@ -785,7 +804,11 @@ impl<'a> CompletionContext<'a> {
785804
.is_some_and(|sib| kinds.contains(&sib.kind()))
786805
}
787806

788-
NodeUnderCursor::CustomNode { .. } => false,
807+
NodeUnderCursor::CustomNode {
808+
previous_node_kind, ..
809+
} => previous_node_kind
810+
.as_ref()
811+
.is_some_and(|k| kinds.contains(&k.as_str())),
789812
}
790813
})
791814
}

0 commit comments

Comments
 (0)