From f8abd57e8c626be7b515e1c808c80ec09f9a1cb8 Mon Sep 17 00:00:00 2001 From: Newbie012 Date: Sat, 7 Dec 2024 22:57:58 +0200 Subject: [PATCH] fix: resolve enum comparison with string literals This patch addresses an issue where the SafeQL ESLint plugin would fail when comparing enums with string literals. Additional test cases have been added to ensure proper functionality in various SQL contexts. --- .changeset/gentle-trainers-beam.md | 5 + .../eslint-plugin/src/rules/check-sql.test.ts | 92 +++++++- .../src/utils/query-context.test.ts | 205 ++++++++++++++++++ .../eslint-plugin/src/utils/query-context.ts | 132 +++++++++++ .../eslint-plugin/src/utils/ts-pg.utils.ts | 11 +- 5 files changed, 441 insertions(+), 4 deletions(-) create mode 100644 .changeset/gentle-trainers-beam.md create mode 100644 packages/eslint-plugin/src/utils/query-context.test.ts create mode 100644 packages/eslint-plugin/src/utils/query-context.ts diff --git a/.changeset/gentle-trainers-beam.md b/.changeset/gentle-trainers-beam.md new file mode 100644 index 00000000..2a74e566 --- /dev/null +++ b/.changeset/gentle-trainers-beam.md @@ -0,0 +1,5 @@ +--- +"@ts-safeql/eslint-plugin": patch +--- + +fixed an issue where safeql would fail when comparing enum with string literals diff --git a/packages/eslint-plugin/src/rules/check-sql.test.ts b/packages/eslint-plugin/src/rules/check-sql.test.ts index 105cb071..151945e0 100644 --- a/packages/eslint-plugin/src/rules/check-sql.test.ts +++ b/packages/eslint-plugin/src/rules/check-sql.test.ts @@ -5,12 +5,12 @@ import { } from "@ts-safeql/test-utils"; import { InvalidTestCase, RuleTester } from "@typescript-eslint/rule-tester"; -import { afterAll, beforeAll, describe, it } from "vitest"; +import { normalizeIndent } from "@ts-safeql/shared"; import path from "path"; import { Sql } from "postgres"; +import { afterAll, beforeAll, describe, it } from "vitest"; import rules from "."; import { RuleOptionConnection, RuleOptions } from "./RuleOptions"; -import { normalizeIndent } from "@ts-safeql/shared"; const tsconfigRootDir = path.resolve(__dirname, "../../"); const project = "tsconfig.json"; @@ -57,6 +57,11 @@ const runMigrations1 = >(sql: Sql agency_id INT NOT NULL REFERENCES agency(id) ); + CREATE TABLE certification_metadata ( + id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + certification certification NOT NULL + ); + CREATE TABLE test_date_column ( id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, date_col DATE NOT NULL, @@ -837,6 +842,79 @@ RuleTester.describe("check-sql", () => { ], }); + ruleTester.run("one-of transformation", rules["check-sql"], { + valid: [ + { + filename, + name: "control", + options: withConnection(connections.withTag), + code: normalizeIndent` + function union(cert: "HHA" | "RN", cert2: "LPN" | "CNA") { + return sql\`SELECT FROM caregiver WHERE certification = \${cert}\` + } + `, + }, + { + filename, + name: "control", + options: withConnection(connections.withTag), + code: normalizeIndent` + function union(cert: "HHA" | "RN", cert2: "LPN" | "CNA") { + return sql\`UPDATE caregiver SET certification = \${cert}::certification WHERE id = 1\` + } + `, + }, + { + filename, + name: "join context", + options: withConnection(connections.withTag), + code: normalizeIndent` + function joinTest(cert: "HHA" | "RN") { + return sql\`SELECT FROM caregiver c JOIN certification_metadata ct ON c.certification = \${cert}\` + } + `, + }, + { + filename, + name: "case context", + options: withConnection(connections.withTag), + code: normalizeIndent` + function caseTest(cert: "HHA" | "RN") { + return sql<{ is_certified: number }>\` + SELECT CASE WHEN certification = \${cert} THEN 1 ELSE 0 END AS is_certified + FROM caregiver + \` + } + `, + }, + { + filename, + name: "having context", + options: withConnection(connections.withTag), + code: normalizeIndent` + function havingTest(cert: "HHA" | "RN") { + return sql\`SELECT FROM caregiver GROUP BY certification HAVING certification = \${cert}\` + } + `, + }, + { + filename, + name: "returning context", + options: withConnection(connections.withTag), + code: normalizeIndent` + function returningTest(cert: "HHA" | "RN") { + return sql<{ one_of: boolean | null }>\` + UPDATE caregiver + SET id = DEFAULT + WHERE FALSE + RETURNING certification = \${cert} AS one_of\` + } + `, + }, + ], + invalid: [], + }); + ruleTester.run("position", rules["check-sql"], { valid: [ { @@ -907,6 +985,16 @@ RuleTester.describe("check-sql", () => { line: 2, columns: [61, 69], }), + invalidPositionTestCase({ + code: normalizeIndent` + function run(cert: "HHA" | "RN'") { + return sql\`select id from caregiver where certification = \${cert}\` + } + `, + error: `invalid input value for enum certification: "RN'"`, + line: 2, + columns: [61, 68], + }), invalidPositionTestCase({ code: "sql`select id, id from caregiver`", error: `Duplicate columns: caregiver.id, caregiver.id`, diff --git a/packages/eslint-plugin/src/utils/query-context.test.ts b/packages/eslint-plugin/src/utils/query-context.test.ts new file mode 100644 index 00000000..14b72242 --- /dev/null +++ b/packages/eslint-plugin/src/utils/query-context.test.ts @@ -0,0 +1,205 @@ +import { normalizeIndent } from "@ts-safeql/shared"; +import { describe, expect, it } from "vitest"; +import { getQueryContext } from "./query-context"; + +describe("getQueryContext", () => { + it("should handle queries with no keywords", () => { + const query = ""; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [] + `); + }); + + it("should parse queries with unusual capitalization", () => { + const query = "SeLeCt * FrOm TBL WhErE ID = 10"; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "SELECT", + "FROM", + "WHERE", + ] + `); + }); + + it("should handle queries with comments", () => { + const query = ` + SELECT id, name -- Select columns + FROM people -- From table + WHERE age > 30 /* age filter */ + `; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "SELECT", + "FROM", + "WHERE", + ] + `); + }); + + it("should parse queries with UNION", () => { + const query = normalizeIndent` + SELECT name FROM tbl1 + UNION + SELECT name FROM tbl2 + ORDER BY name + `; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "SELECT", + "FROM", + "UNION", + "SELECT", + "FROM", + "ORDER BY", + ] + `); + }); + + it("should parse queries with JOINs", () => { + const query = normalizeIndent` + SELECT a.name, b.age + FROM tbl1 a + INNER JOIN tbl2 b ON a.id = b.id + WHERE b.age > 30 + `; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "SELECT", + "FROM", + "INNER JOIN", + "ON", + "WHERE", + ] + `); + }); + + it("should parse queries with nested functions", () => { + const query = normalizeIndent` + SELECT id, COUNT(*) AS total + FROM ( + SELECT id FROM tbl WHERE col = 5 + ) subquery + GROUP BY id + `; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "SELECT", + "FROM", + [ + "SELECT", + "FROM", + "WHERE", + ], + "GROUP BY", + ] + `); + }); + + it("should handle queries with placeholders", () => { + const query = "SELECT * FROM tbl WHERE id = $1 AND name = $2"; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "SELECT", + "FROM", + "WHERE", + ] + `); + }); + + it("should parse queries with CASE statements", () => { + const query = normalizeIndent` + SELECT id, + CASE WHEN col1 = 1 THEN 'A' + WHEN col2 = 2 THEN 'B' + ELSE 'C' END AS category + FROM tbl + `; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "SELECT", + "FROM", + ] + `); + }); + + it("should parse queries with window functions", () => { + const query = normalizeIndent` + SELECT id, ROW_NUMBER() OVER (PARTITION BY category ORDER BY created_at) AS row_num + FROM tbl + `; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "SELECT", + [ + "PARTITION BY", + "ORDER BY", + ], + "FROM", + ] + `); + }); + + it("should parse queries with complex expressions in SELECT", () => { + const query = "SELECT id, (col1 + col2) * col3 AS result FROM tbl"; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "SELECT", + "FROM", + ] + `); + }); + + it("should parse queries with DISTINCT ON", () => { + const query = "SELECT DISTINCT ON (col1) col1, col2 FROM tbl ORDER BY col1, col2"; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "SELECT", + "ON", + "FROM", + "ORDER BY", + ] + `); + }); + + it("should parse queries with multiple WITH clauses", () => { + const query = normalizeIndent` + WITH cte1 AS ( + SELECT id FROM tbl1 + ), + cte2 AS ( + SELECT id FROM tbl2 + ) + SELECT * FROM cte1 + INNER JOIN cte2 ON cte1.id = cte2.id + `; + + expect(getQueryContext(query)).toMatchInlineSnapshot(` + [ + "WITH", + [ + "SELECT", + "FROM", + ], + [ + "SELECT", + "FROM", + ], + "SELECT", + "FROM", + "INNER JOIN", + "ON", + ] + `); + }); +}); diff --git a/packages/eslint-plugin/src/utils/query-context.ts b/packages/eslint-plugin/src/utils/query-context.ts new file mode 100644 index 00000000..d6edeb1f --- /dev/null +++ b/packages/eslint-plugin/src/utils/query-context.ts @@ -0,0 +1,132 @@ +const keywords = [ + "WITH", + "SELECT", + "FROM", + "WHERE", + "GROUP BY", + "HAVING", + "WINDOW", + "ORDER BY", + "PARTITION BY", + "LIMIT", + "OFFSET", + "INSERT INTO", + "VALUES", + "UPDATE", + "SET", + "RETURNING", + "ON", + "JOIN", + "INNER JOIN", + "LEFT JOIN", + "RIGHT JOIN", + "FULL JOIN", + "FULL OUTER JOIN", + "CROSS JOIN", + "WHEN", + "USING", + "UNION", + "UNION ALL", + "INTERSECT", + "EXCEPT", +] as const; + +const keywordSet = new Set(keywords); +type Keyword = (typeof keywords)[number]; +type Context = (Keyword | Context)[]; + +export function isLastQueryContextOneOf(queryText: string, keywords: Keyword[]): boolean { + const contextKeywords = getLastQueryContext(queryText); + const lastKeyword = contextKeywords[contextKeywords.length - 1]; + + return keywords.some((keyword) => keyword === lastKeyword); +} + +export function getLastQueryContext(queryText: string): Keyword[] { + const context = getQueryContext(queryText); + + const iterate = (ctx: Context): Keyword[] => { + const last = ctx[ctx.length - 1]; + + if (Array.isArray(last)) { + return iterate(last); + } + + return ctx as Keyword[]; + }; + + return iterate(context); +} + +export function getQueryContext(queryText: string): Context { + const tokens = removePgComments(queryText) + .split(/(\s+|\(|\))/) + .filter((token) => token.trim() !== ""); + let index = 0; + + function parseQuery(): Context { + const context: Context = []; + + while (index < tokens.length) { + const token = tokens[index++].toUpperCase(); + + if (token === ")") { + // End of the current query context + return context; + } + + if (token === "(") { + // Start of a subquery + const subquery = parseQuery(); + if (subquery.length > 0) { + context.push(subquery); // Add valid subquery + } + continue; + } + + const previousToken = tokens[index - 2]?.toUpperCase(); + const nextToken = tokens[index]?.toUpperCase(); + + if (isOneOf(["ORDER", "GROUP", "PARTITION"], token) && nextToken === "BY") { + index++; // Consume "BY" + context.push(`${token} BY`); + continue; + } + + if (token === "JOIN") { + switch (previousToken) { + case "INNER": + case "LEFT": + case "RIGHT": + case "FULL": + case "CROSS": + context.push(`${previousToken} JOIN` as Keyword); + break; + case "OUTER": + context.push("FULL OUTER JOIN"); + break; + } + continue; + } + + if (keywordSet.has(token as Keyword)) { + context.push(token as Keyword); + continue; + } + + // Skip non-keyword tokens (identifiers, literals, etc.) + } + + return context; + } + + return parseQuery(); +} + +function removePgComments(query: string) { + return query.replace(/--.*(\r?\n|$)|\/\*[\s\S]*?\*\//g, "").trim(); +} + +function isOneOf(values: T[], value: string): value is T { + return values.includes(value as T); +} diff --git a/packages/eslint-plugin/src/utils/ts-pg.utils.ts b/packages/eslint-plugin/src/utils/ts-pg.utils.ts index 30e1c7c1..07e5215b 100644 --- a/packages/eslint-plugin/src/utils/ts-pg.utils.ts +++ b/packages/eslint-plugin/src/utils/ts-pg.utils.ts @@ -11,6 +11,7 @@ import ts, { TypeChecker } from "typescript"; import { RuleOptionConnection } from "../rules/RuleOptions"; import { E, pipe } from "./fp-ts"; import { TSUtils } from "./ts.utils"; +import { isLastQueryContextOneOf } from "./query-context"; export function mapTemplateLiteralToQueryText( quasi: TSESTree.TemplateLiteral, @@ -85,9 +86,15 @@ export function mapTemplateLiteralToQueryText( continue; } - if (pgTypeValue.kind === "one-of" && $queryText.trimEnd().endsWith("=")) { + const escapePgValue = (text: string) => text.replace(/'/g, "''"); + + if ( + pgTypeValue.kind === "one-of" && + $queryText.trimEnd().endsWith("=") && + isLastQueryContextOneOf($queryText, ["SELECT", "ON", "WHERE", "WHEN", "HAVING", "RETURNING"]) + ) { const textFromEquals = $queryText.slice($queryText.lastIndexOf("=")); - const placeholder = `IN (${pgTypeValue.types.map((t) => `'${t}'`).join(", ")})`; + const placeholder = `IN (${pgTypeValue.types.map((t) => `'${escapePgValue(t)}'`).join(", ")})`; const expressionText = sourceCode.text.slice( expression.range[0] - 2, expression.range[1] + 1,