Skip to content

Commit 226d29e

Browse files
authored
fix: correct type inference for arithmetic operations (#359)
This update addresses an issue where the inferred type was incorrect when dealing with arithmetic operations in SQL queries. The changes ensure that the type is accurately resolved based on the operation and operand types.
1 parent 8502e67 commit 226d29e

10 files changed

+1083
-222
lines changed

.changeset/chatty-hairs-hide.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@ts-safeql/generate": patch
3+
---
4+
5+
fixed an issue where the inferred typed was incorrect when dealing with arithmetic operations

packages/eslint-plugin/src/rules/check-sql.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1290,7 +1290,7 @@ RuleTester.describe("check-sql", () => {
12901290
await sql<Caregiver[]>\`
12911291
SELECT
12921292
CASE WHEN caregiver.id IS NOT NULL
1293-
THEN jsonb_build_object('is_test', caregiver.middle_name NOT LIKE '%test%')
1293+
THEN jsonb_build_object('is_test', caregiver.first_name LIKE '%test%')
12941294
ELSE NULL
12951295
END AS meta
12961296
FROM

packages/generate/src/ast-decribe.utils.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@ export function isSingleCell<T>(arr: T[]): arr is [T] {
2828
return arr.length === 1;
2929
}
3030

31-
function isTuple<T>(arr: T[]): arr is [T, T] {
31+
export function isTuple<T>(arr: T[]): arr is [T, T] {
3232
return arr.length === 2;
3333
}

packages/generate/src/ast-describe.ts

+112-21
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import { fmap, normalizeIndent } from "@ts-safeql/shared";
1+
import { defaultTypeExprMapping, fmap, normalizeIndent } from "@ts-safeql/shared";
22
import * as LibPgQueryAST from "@ts-safeql/sql-ast";
33
import {
44
isColumnStarRef,
55
isColumnTableColumnRef,
66
isColumnTableStarRef,
77
isColumnUnknownRef,
88
isSingleCell,
9+
isTuple,
910
} from "./ast-decribe.utils";
1011
import { ResolvedColumn, SourcesResolver, getSources } from "./ast-get-sources";
1112
import { PgColRow, PgEnumsMaps, PgTypesMap } from "./generate";
@@ -20,7 +21,7 @@ type ASTDescriptionOptions = {
2021
pgColsBySchemaAndTableName: Map<string, Map<string, PgColRow[]>>;
2122
pgTypes: PgTypesMap;
2223
pgEnums: PgEnumsMaps;
23-
pgFns: Map<string, string>;
24+
pgFns: Map<string, { ts: string; pg: string }>;
2425
};
2526

2627
type ASTDescriptionContext = ASTDescriptionOptions & {
@@ -38,7 +39,7 @@ export type ASTDescribedColumnType =
3839
| { kind: "union"; value: ASTDescribedColumnType[] }
3940
| { kind: "array"; value: ASTDescribedColumnType }
4041
| { kind: "object"; value: [string, ASTDescribedColumnType][] }
41-
| { kind: "type"; value: string }
42+
| { kind: "type"; value: string; type: string }
4243
| { kind: "literal"; value: string; base: ASTDescribedColumnType };
4344

4445
export function getASTDescription(params: ASTDescriptionOptions): Map<number, ASTDescribedColumn> {
@@ -82,20 +83,32 @@ export function getASTDescription(params: ASTDescriptionOptions): Map<number, AS
8283
p: { oid: number; baseOid: number | null } | { name: string },
8384
): ASTDescribedColumnType => {
8485
if ("name" in p) {
85-
return { kind: "type", value: params.typesMap.get(p.name)?.value ?? "unknown" };
86+
return {
87+
kind: "type",
88+
value: params.typesMap.get(p.name)?.value ?? "unknown",
89+
type: p.name,
90+
};
8691
}
8792

8893
const typeByOid = getTypeByOid(p.oid);
8994

9095
if (typeByOid.override) {
91-
const baseType: ASTDescribedColumnType = { kind: "type", value: typeByOid.value };
96+
const baseType: ASTDescribedColumnType = {
97+
kind: "type",
98+
value: typeByOid.value,
99+
type: params.pgTypes.get(p.oid)?.name ?? "unknown",
100+
};
92101
return typeByOid.isArray ? { kind: "array", value: baseType } : baseType;
93102
}
94103

95104
const typeByBaseOid = fmap(p.baseOid, getTypeByOid);
96105

97106
if (typeByBaseOid?.override === true) {
98-
const baseType: ASTDescribedColumnType = { kind: "type", value: typeByBaseOid.value };
107+
const baseType: ASTDescribedColumnType = {
108+
kind: "type",
109+
value: typeByBaseOid.value,
110+
type: params.pgTypes.get(p.baseOid!)?.name ?? "unknown",
111+
};
99112
return typeByBaseOid.isArray ? { kind: "array", value: baseType } : baseType;
100113
}
101114

@@ -104,13 +117,21 @@ export function getASTDescription(params: ASTDescriptionOptions): Map<number, AS
104117
if (enumValue !== undefined) {
105118
return {
106119
kind: "union",
107-
value: enumValue.values.map((value) => ({ kind: "type", value: `'${value}'` })),
120+
value: enumValue.values.map((value) => ({
121+
kind: "type",
122+
value: `'${value}'`,
123+
type: enumValue.name,
124+
})),
108125
};
109126
}
110127

111128
const { isArray, value } = typeByBaseOid ?? typeByOid;
112129

113-
const type: ASTDescribedColumnType = { kind: "type", value: value };
130+
const type: ASTDescribedColumnType = {
131+
kind: "type",
132+
value: value,
133+
type: params.pgTypes.get(p.oid)?.name ?? "unknown",
134+
};
114135

115136
return isArray ? { kind: "array", value: type } : type;
116137
},
@@ -215,15 +236,81 @@ function getDescribedNode(params: {
215236

216237
function getDescribedAExpr({
217238
alias,
239+
node,
218240
context,
219241
}: GetDescribedParamsOf<LibPgQueryAST.AExpr>): ASTDescribedColumn[] {
242+
const name = alias ?? "?column?";
243+
244+
if (node.lexpr === undefined && node.rexpr !== undefined) {
245+
const described = getDescribedNode({ alias, node: node.rexpr, context }).at(0);
246+
const type = fmap(described, (x) => getBaseType(x.type));
247+
248+
if (type === null) return [];
249+
250+
return [{ name, type }];
251+
}
252+
253+
if (node.lexpr === undefined || node.rexpr === undefined) {
254+
return [];
255+
}
256+
257+
const getResolvedNullableValueOrNull = (node: LibPgQueryAST.Node) => {
258+
const column = getDescribedNode({ alias: undefined, node, context }).at(0);
259+
260+
if (column === undefined) return null;
261+
262+
if (column.type.kind === "array") {
263+
return { value: "array", nullable: false };
264+
}
265+
266+
if (column.type.kind === "type") {
267+
return { value: column.type.type, nullable: false };
268+
}
269+
270+
if (column.type.kind === "literal" && column.type.base.kind === "type") {
271+
return { value: column.type.base.type, nullable: false };
272+
}
273+
274+
if (column.type.kind === "union" && isTuple(column.type.value)) {
275+
let nullable = false;
276+
let value: string | undefined = undefined;
277+
278+
for (const type of column.type.value) {
279+
if (type.kind !== "type") return null;
280+
if (type.value === "null") nullable = true;
281+
if (type.value !== "null") value = type.type;
282+
}
283+
284+
if (value === undefined) return null;
285+
286+
return { value, nullable };
287+
}
288+
289+
return null;
290+
};
291+
292+
const lnode = getResolvedNullableValueOrNull(node.lexpr);
293+
const rnode = getResolvedNullableValueOrNull(node.rexpr);
294+
295+
if (lnode === null || rnode === null) {
296+
return [];
297+
}
298+
299+
const operator = concatStringNodes(node.name);
300+
const resolved: string | undefined =
301+
defaultTypeExprMapping[`${lnode.value} ${operator} ${rnode.value}`];
302+
303+
if (resolved === undefined) {
304+
return [];
305+
}
306+
220307
return [
221308
{
222-
name: alias ?? "?column?",
309+
name: name,
223310
type: resolveType({
224311
context: context,
225-
nullable: false,
226-
type: context.toTypeScriptType({ name: "boolean" }),
312+
nullable: !context.nonNullableColumns.has(name) && (lnode.nullable || rnode.nullable),
313+
type: context.toTypeScriptType({ name: resolved }),
227314
}),
228315
},
229316
];
@@ -239,7 +326,7 @@ function getDescribedNullTest({
239326
type: resolveType({
240327
context: context,
241328
nullable: false,
242-
type: context.toTypeScriptType({ name: "boolean" }),
329+
type: context.toTypeScriptType({ name: "bool" }),
243330
}),
244331
},
245332
];
@@ -298,7 +385,7 @@ function getDescribedBoolExpr({
298385
type: resolveType({
299386
context: context,
300387
nullable: false,
301-
type: context.toTypeScriptType({ name: "boolean" }),
388+
type: context.toTypeScriptType({ name: "bool" }),
302389
}),
303390
},
304391
];
@@ -317,7 +404,7 @@ function getDescribedSubLink({
317404
nullable: false,
318405
type: (() => {
319406
if (node.subLinkType === LibPgQueryAST.SubLinkType.EXISTS_SUBLINK) {
320-
return context.toTypeScriptType({ name: "boolean" });
407+
return context.toTypeScriptType({ name: "bool" });
321408
}
322409

323410
return context.toTypeScriptType({ name: "unknown" });
@@ -412,7 +499,7 @@ function mergeDescribedColumnTypes(types: ASTDescribedColumnType[]): ASTDescribe
412499

413500
if (!seenSymbols.has("boolean") && seenSymbols.has("true") && seenSymbols.has("false")) {
414501
seenSymbols.add("boolean");
415-
result.push({ kind: "type", value: "boolean" });
502+
result.push({ kind: "type", value: "boolean", type: "bool" });
416503
}
417504

418505
if (seenSymbols.has("boolean") && (seenSymbols.has("true") || seenSymbols.has("false"))) {
@@ -537,15 +624,15 @@ function getDescribedFuncCallByPgFn({
537624

538625
const pgFnValue =
539626
args.length === 0
540-
? context.pgFns.get(functionName)
627+
? (context.pgFns.get(functionName) ?? context.pgFns.get(`${functionName}(string)`))
541628
: (context.pgFns.get(`${functionName}(${args.join(", ")})`) ??
542629
context.pgFns.get(`${functionName}(any)`) ??
543630
context.pgFns.get(`${functionName}(unknown)`));
544631

545632
const type = resolveType({
546633
context: context,
547634
nullable: !context.nonNullableColumns.has(name),
548-
type: { kind: "type", value: pgFnValue ?? "unknown" },
635+
type: { kind: "type", value: pgFnValue?.ts ?? "unknown", type: pgFnValue?.pg ?? "unknown" },
549636
});
550637

551638
return [{ name, type }];
@@ -758,7 +845,11 @@ function getDescribedColumnByResolvedColumns(params: {
758845
?.get(column.colName);
759846

760847
if (overridenType !== undefined) {
761-
return { kind: "type", value: overridenType };
848+
return {
849+
kind: "type",
850+
value: overridenType,
851+
type: params.context.pgTypes.get(column.colTypeOid)?.name ?? "unknown",
852+
};
762853
}
763854

764855
return params.context.toTypeScriptType({
@@ -789,7 +880,7 @@ function getDescribedAConst({
789880
return {
790881
kind: "literal",
791882
value: node.boolval.boolval ? "true" : "false",
792-
base: context.toTypeScriptType({ name: "boolean" }),
883+
base: context.toTypeScriptType({ name: "bool" }),
793884
};
794885
case node.bsval !== undefined:
795886
return context.toTypeScriptType({ name: "bytea" });
@@ -838,7 +929,7 @@ function asNonNullableType(type: ASTDescribedColumnType): ASTDescribedColumnType
838929
);
839930

840931
if (filtered.length === 0) {
841-
return { kind: "type", value: "unknown" };
932+
return { kind: "type", value: "unknown", type: "unknown" };
842933
}
843934

844935
if (filtered.length === 1) {
@@ -848,7 +939,7 @@ function asNonNullableType(type: ASTDescribedColumnType): ASTDescribedColumnType
848939
return { kind: "union", value: filtered };
849940
}
850941
case "type":
851-
return type.value === "null" ? { kind: "type", value: "unknown" } : type;
942+
return type.value === "null" ? { kind: "type", value: "unknown", type: "unknown" } : type;
852943
}
853944
}
854945

0 commit comments

Comments
 (0)