Skip to content

Commit 7357ac6

Browse files
authored
feat: parse query parameters in PostgreSQL query (#1732)
* fix: PostgreSQL supports newline in quoted literals and identifiers PostgreSQL supports newline characters in string literals and quoted identifiers. Trying to execute a statement with a string literal or quoted identifier that contained a newline character would cause an 'Unclosed string literal' error. Fixes #1730 * feat: parse query parameters in PostgreSQL query Adds a helper method to get the parameters from a PostgreSQL query. This is needed for DESCRIBE statement messages in PGAdapter, as it must return the data types of all query parameters in a query string. Even though this parser is not able to determine the parameter types, it is able to determine the number of parameters. This again makes it possible to PGAdapter to return Oid.UNSPECIFIED for each parameter in the query string, which is enough for most clients.
1 parent f403d99 commit 7357ac6

File tree

2 files changed

+108
-27
lines changed

2 files changed

+108
-27
lines changed

google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java

Lines changed: 74 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import com.google.cloud.spanner.ErrorCode;
2222
import com.google.cloud.spanner.SpannerExceptionFactory;
2323
import com.google.common.base.Preconditions;
24+
import java.util.HashSet;
25+
import java.util.Set;
26+
import javax.annotation.Nullable;
2427

2528
@InternalApi
2629
public class PostgreSQLStatementParser extends AbstractStatementParser {
@@ -149,30 +152,76 @@ ParametersInfo convertPositionalParametersToNamedParametersInternal(char paramCh
149152
return new ParametersInfo(paramIndex - 1, named.toString());
150153
}
151154

152-
private int skip(String sql, int currentIndex, StringBuilder result) {
155+
/**
156+
* Note: This is an internal API and breaking changes can be made without prior notice.
157+
*
158+
* <p>Returns the PostgreSQL-style query parameters ($1, $2, ...) in the given SQL string. The
159+
* SQL-string is assumed to not contain any comments. Use {@link #removeCommentsAndTrim(String)}
160+
* to remove all comments before calling this method. Occurrences of query-parameter like strings
161+
* inside quoted identifiers or string literals are ignored.
162+
*
163+
* <p>The following example will return a set containing ("$1", "$2"). <code>
164+
* select col1, col2, "col$4"
165+
* from some_table
166+
* where col1=$1 and col2=$2
167+
* and not col3=$1 and col4='$3'
168+
* </code>
169+
*
170+
* @param sql the SQL-string to check for parameters. Must not contain comments.
171+
* @return A set containing all the parameters in the SQL-string.
172+
*/
173+
@InternalApi
174+
public Set<String> getQueryParameters(String sql) {
175+
Preconditions.checkNotNull(sql);
176+
int maxCount = countOccurrencesOf('$', sql);
177+
Set<String> parameters = new HashSet<>(maxCount);
178+
int currentIndex = 0;
179+
while (currentIndex < sql.length() - 1) {
180+
char c = sql.charAt(currentIndex);
181+
if (c == '$' && Character.isDigit(sql.charAt(currentIndex + 1))) {
182+
// Look ahead for the first non-digit. That is the end of the query parameter.
183+
int endIndex = currentIndex + 2;
184+
while (endIndex < sql.length() && Character.isDigit(sql.charAt(endIndex))) {
185+
endIndex++;
186+
}
187+
parameters.add(sql.substring(currentIndex, endIndex));
188+
currentIndex = endIndex;
189+
} else {
190+
currentIndex = skip(sql, currentIndex, null);
191+
}
192+
}
193+
return parameters;
194+
}
195+
196+
private int skip(String sql, int currentIndex, @Nullable StringBuilder result) {
153197
char currentChar = sql.charAt(currentIndex);
154198
if (currentChar == SINGLE_QUOTE || currentChar == DOUBLE_QUOTE) {
155-
result.append(currentChar);
199+
appendIfNotNull(result, currentChar);
156200
return skipQuoted(sql, currentIndex, currentChar, result);
157201
} else if (currentChar == DOLLAR) {
158202
String dollarTag = parseDollarQuotedString(sql, currentIndex + 1);
159203
if (dollarTag != null) {
160-
result.append(currentChar).append(dollarTag).append(currentChar);
204+
appendIfNotNull(result, currentChar, dollarTag, currentChar);
161205
return skipQuoted(
162206
sql, currentIndex + dollarTag.length() + 1, currentChar, dollarTag, result);
163207
}
164208
}
165209

166-
result.append(currentChar);
210+
appendIfNotNull(result, currentChar);
167211
return currentIndex + 1;
168212
}
169213

170-
private int skipQuoted(String sql, int startIndex, char startQuote, StringBuilder result) {
214+
private int skipQuoted(
215+
String sql, int startIndex, char startQuote, @Nullable StringBuilder result) {
171216
return skipQuoted(sql, startIndex, startQuote, null, result);
172217
}
173218

174219
private int skipQuoted(
175-
String sql, int startIndex, char startQuote, String dollarTag, StringBuilder result) {
220+
String sql,
221+
int startIndex,
222+
char startQuote,
223+
String dollarTag,
224+
@Nullable StringBuilder result) {
176225
boolean lastCharWasEscapeChar = false;
177226
int currentIndex = startIndex + 1;
178227
while (currentIndex < sql.length()) {
@@ -182,29 +231,41 @@ private int skipQuoted(
182231
// Check if this is the end of the current dollar quoted string.
183232
String tag = parseDollarQuotedString(sql, currentIndex + 1);
184233
if (tag != null && tag.equals(dollarTag)) {
185-
result.append(currentChar).append(tag).append(currentChar);
234+
appendIfNotNull(result, currentChar, dollarTag, currentChar);
186235
return currentIndex + tag.length() + 2;
187236
}
188237
} else if (lastCharWasEscapeChar) {
189238
lastCharWasEscapeChar = false;
190239
} else if (sql.length() > currentIndex + 1 && sql.charAt(currentIndex + 1) == startQuote) {
191240
// This is an escaped quote (e.g. 'foo''bar')
192-
result.append(currentChar).append(currentChar);
241+
appendIfNotNull(result, currentChar);
242+
appendIfNotNull(result, currentChar);
193243
currentIndex += 2;
194244
continue;
195245
} else {
196-
result.append(currentChar);
246+
appendIfNotNull(result, currentChar);
197247
return currentIndex + 1;
198248
}
199-
} else if (currentChar == '\\') {
200-
lastCharWasEscapeChar = true;
201249
} else {
202-
lastCharWasEscapeChar = false;
250+
lastCharWasEscapeChar = currentChar == '\\';
203251
}
204252
currentIndex++;
205-
result.append(currentChar);
253+
appendIfNotNull(result, currentChar);
206254
}
207255
throw SpannerExceptionFactory.newSpannerException(
208256
ErrorCode.INVALID_ARGUMENT, "SQL statement contains an unclosed literal: " + sql);
209257
}
258+
259+
private void appendIfNotNull(@Nullable StringBuilder result, char currentChar) {
260+
if (result != null) {
261+
result.append(currentChar);
262+
}
263+
}
264+
265+
private void appendIfNotNull(
266+
@Nullable StringBuilder result, char prefix, String tag, char suffix) {
267+
if (result != null) {
268+
result.append(prefix).append(tag).append(suffix);
269+
}
270+
}
210271
}

google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import static org.junit.Assert.assertFalse;
2525
import static org.junit.Assert.assertTrue;
2626
import static org.junit.Assert.fail;
27+
import static org.junit.Assume.assumeTrue;
2728

2829
import com.google.cloud.spanner.Dialect;
2930
import com.google.cloud.spanner.ErrorCode;
@@ -33,6 +34,7 @@
3334
import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType;
3435
import com.google.cloud.spanner.connection.ClientSideStatementImpl.CompileException;
3536
import com.google.common.collect.ImmutableMap;
37+
import com.google.common.collect.ImmutableSet;
3638
import com.google.common.truth.Truth;
3739
import java.io.File;
3840
import java.io.FileNotFoundException;
@@ -42,7 +44,6 @@
4244
import java.util.Set;
4345
import java.util.regex.Matcher;
4446
import java.util.regex.Pattern;
45-
import org.junit.Assume;
4647
import org.junit.Before;
4748
import org.junit.Test;
4849
import org.junit.runner.RunWith;
@@ -158,7 +159,7 @@ public void testRemoveComments() {
158159

159160
@Test
160161
public void testGoogleStandardSQLRemoveCommentsGsql() {
161-
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
162+
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
162163

163164
assertThat(parser.removeCommentsAndTrim("/*GSQL*/")).isEqualTo("");
164165
assertThat(parser.removeCommentsAndTrim("/*GSQL*/SELECT * FROM FOO"))
@@ -183,7 +184,7 @@ public void testGoogleStandardSQLRemoveCommentsGsql() {
183184

184185
@Test
185186
public void testPostgreSQLDialectRemoveCommentsGsql() {
186-
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
187+
assumeTrue(dialect == Dialect.POSTGRESQL);
187188

188189
assertThat(parser.removeCommentsAndTrim("/*GSQL*/")).isEqualTo("/*GSQL*/");
189190
assertThat(parser.removeCommentsAndTrim("/*GSQL*/SELECT * FROM FOO"))
@@ -273,7 +274,7 @@ public void testStatementWithCommentContainingSlashAndNoAsteriskOnNewLine() {
273274

274275
@Test
275276
public void testPostgresSQLDialectDollarQuoted() {
276-
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
277+
assumeTrue(dialect == Dialect.POSTGRESQL);
277278

278279
assertThat(parser.removeCommentsAndTrim("$$foo$$")).isEqualTo("$$foo$$");
279280
assertThat(parser.removeCommentsAndTrim("$$--foo$$")).isEqualTo("$$--foo$$");
@@ -296,7 +297,7 @@ public void testPostgresSQLDialectDollarQuoted() {
296297

297298
@Test
298299
public void testPostgreSQLDialectSupportsEmbeddedComments() {
299-
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
300+
assumeTrue(dialect == Dialect.POSTGRESQL);
300301

301302
final String sql =
302303
"/* This is a comment /* This is an embedded comment */ This is after the embedded comment */ SELECT 1";
@@ -305,7 +306,7 @@ public void testPostgreSQLDialectSupportsEmbeddedComments() {
305306

306307
@Test
307308
public void testGoogleStandardSQLDialectDoesNotSupportEmbeddedComments() {
308-
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
309+
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
309310

310311
final String sql =
311312
"/* This is a comment /* This is an embedded comment */ This is after the embedded comment */ SELECT 1";
@@ -315,7 +316,7 @@ public void testGoogleStandardSQLDialectDoesNotSupportEmbeddedComments() {
315316

316317
@Test
317318
public void testPostgreSQLDialectUnterminatedComment() {
318-
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
319+
assumeTrue(dialect == Dialect.POSTGRESQL);
319320

320321
final String sql =
321322
"/* This is a comment /* This is still a comment */ this is unterminated SELECT 1";
@@ -334,7 +335,7 @@ public void testPostgreSQLDialectUnterminatedComment() {
334335

335336
@Test
336337
public void testGoogleStandardSqlDialectDialectUnterminatedComment() {
337-
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
338+
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
338339

339340
final String sql =
340341
"/* This is a comment /* This is still a comment */ this is unterminated SELECT 1";
@@ -360,7 +361,7 @@ public void testShowStatements() {
360361

361362
@Test
362363
public void testGoogleStandardSQLDialectStatementWithHashTagSingleLineComment() {
363-
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
364+
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
364365

365366
// Supports # based comments
366367
assertThat(
@@ -382,7 +383,7 @@ public void testGoogleStandardSQLDialectStatementWithHashTagSingleLineComment()
382383

383384
@Test
384385
public void testPostgreSQLDialectStatementWithHashTagSingleLineComment() {
385-
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
386+
assumeTrue(dialect == Dialect.POSTGRESQL);
386387

387388
// Does not support # based comments
388389
assertThat(
@@ -615,7 +616,7 @@ public void testIsQuery() {
615616

616617
@Test
617618
public void testGoogleStandardSQLDialectIsQuery_QueryHints() {
618-
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
619+
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
619620

620621
// Supports query hints, PostgreSQL dialect does NOT
621622
// Valid query hints.
@@ -663,7 +664,7 @@ public void testGoogleStandardSQLDialectIsQuery_QueryHints() {
663664

664665
@Test
665666
public void testIsUpdate_QueryHints() {
666-
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
667+
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
667668

668669
// Supports query hints, PostgreSQL dialect does NOT
669670
// Valid query hints.
@@ -1093,7 +1094,7 @@ public void testConvertPositionalParametersToNamedParametersWithGsqlException()
10931094

10941095
@Test
10951096
public void testGoogleStandardSQLDialectConvertPositionalParametersToNamedParameters() {
1096-
Assume.assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
1097+
assumeTrue(dialect == Dialect.GOOGLE_STANDARD_SQL);
10971098

10981099
assertThat(
10991100
parser.convertPositionalParametersToNamedParameters(
@@ -1203,7 +1204,7 @@ public void testGoogleStandardSQLDialectConvertPositionalParametersToNamedParame
12031204

12041205
@Test
12051206
public void testPostgreSQLDialectDialectConvertPositionalParametersToNamedParameters() {
1206-
Assume.assumeTrue(dialect == Dialect.POSTGRESQL);
1207+
assumeTrue(dialect == Dialect.POSTGRESQL);
12071208

12081209
assertThat(
12091210
parser.convertPositionalParametersToNamedParameters(
@@ -1318,6 +1319,25 @@ public void testPostgreSQLDialectDialectConvertPositionalParametersToNamedParame
13181319
+ "and col8 between $12 and $13")));
13191320
}
13201321

1322+
@Test
1323+
public void testPostgreSQLGetQueryParameters() {
1324+
assumeTrue(dialect == Dialect.POSTGRESQL);
1325+
1326+
PostgreSQLStatementParser parser = (PostgreSQLStatementParser) this.parser;
1327+
assertEquals(ImmutableSet.of(), parser.getQueryParameters("select * from foo"));
1328+
assertEquals(
1329+
ImmutableSet.of("$1"), parser.getQueryParameters("select * from foo where bar=$1"));
1330+
assertEquals(
1331+
ImmutableSet.of("$1", "$2", "$3"),
1332+
parser.getQueryParameters("select $2 from foo where bar=$1 and baz=$3"));
1333+
assertEquals(
1334+
ImmutableSet.of("$1", "$3"),
1335+
parser.getQueryParameters("select '$2' from foo where bar=$1 and baz in ($1, $3)"));
1336+
assertEquals(
1337+
ImmutableSet.of("$1"),
1338+
parser.getQueryParameters("select '$2' from foo where bar=$1 and baz=$foo"));
1339+
}
1340+
13211341
private void assertUnclosedLiteral(String sql) {
13221342
try {
13231343
parser.convertPositionalParametersToNamedParameters('?', sql);

0 commit comments

Comments
 (0)