Skip to content

Commit

Permalink
fix: improve source resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
Newbie012 committed Jan 13, 2025
1 parent b63d77c commit 42326c2
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 98 deletions.
5 changes: 5 additions & 0 deletions .changeset/slow-stingrays-know.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@ts-safeql/generate": patch
---

fixed an issue when the wrong type was returned in some cases when using CTEs
17 changes: 9 additions & 8 deletions packages/generate/src/ast-describe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ function getDescribedCoalesceExpr({
.at(0);

if (type === undefined) {
return [unknownCoalesce];
return [];
}

return [
Expand Down Expand Up @@ -794,7 +794,7 @@ function getDescribedColumnRef({
return getDescribedColumnByResolvedColumns({
alias: alias,
context: context,
resolved: context.resolver.getAllResolvedColumns(),
resolved: context.resolver.getAllResolvedColumns().map((x) => x.column),
});
}

Expand Down Expand Up @@ -822,15 +822,16 @@ function getDescribedColumnRef({
}

if (isColumnTableColumnRef(node.fields)) {
const resolved = context.resolver.getColumnsByTargetField({
kind: "column",
table: node.fields[0].String.sval,
column: node.fields[1].String.sval,
});

return getDescribedColumnByResolvedColumns({
alias: alias,
context: context,
resolved:
context.resolver.getColumnsByTargetField({
kind: "column",
table: node.fields[0].String.sval,
column: node.fields[1].String.sval,
}) ?? [],
resolved: resolved ?? [],
});
}

Expand Down
212 changes: 122 additions & 90 deletions packages/generate/src/ast-get-sources.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export type SourcesResolver = ReturnType<typeof getSources>;

type SourcesOptions = {
select: LibPgQueryAST.SelectStmt;
prevSources?: Map<string, SelectSource>;
nonNullableColumns: Set<string>;
pgColsBySchemaAndTableName: Map<string, Map<string, PgColRow[]>>;
relations: FlattenedRelationWithJoins[];
Expand All @@ -17,51 +18,65 @@ export type ResolvedColumn = {
isNotNull: boolean;
};

type SelectSource =
| {
kind: "table";
schemaName: string;
name: string;
original: string;
alias?: string;
columns: ResolvedColumn[];
}
| { kind: "cte" | "subselect"; name: string; sources: SourcesResolver };

type TargetField =
| { kind: "unknown"; field: string }
| { kind: "column"; table: string; column: string };

export function getSources({
pgColsBySchemaAndTableName,
relations,
prevSources,
select,
nonNullableColumns,
}: SourcesOptions) {
const { columns, sources: sourcesEntries } = getColumnSources(select.fromClause ?? []);
const sources = new Map(sourcesEntries);

function getAllResolvedColumns() {
return columns.map((x) => resolveColumn(x.column));
const ctes = getColumnCTEs(select.withClause?.ctes ?? []);
const sources: Map<string, SelectSource> = new Map([
...(prevSources?.entries() ?? []),
...getColumnSources(select.fromClause ?? []).entries(),
]);

function getSourceColumns(source: SelectSource) {
switch (source.kind) {
case "cte":
case "subselect":
return source.sources.getAllResolvedColumns();
case "table":
return source.columns.map((column) => ({ column, source }));
}
}

function getResolvedColumnsInTable(sourceName: string) {
return columns.filter((x) => x.source.name === sourceName).map((x) => resolveColumn(x.column));
function getAllResolvedColumns(): { column: ResolvedColumn; source: SelectSource }[] {
return [...sources.values()].map(getSourceColumns).flat();
}

function getColumnByTableAndColumnName(p: { table: string; column: string }) {
const columnSource = columns.find((x) => {
if (x.column.colName !== p.column) {
return false;
}
function getResolvedColumnsInTable(sourceName: string): ResolvedColumn[] {
return fmap(sources.get(sourceName), getSourceColumns)?.map((x) => x.column) ?? [];
}

switch (x.source.kind) {
case "table":
return (x.source.alias ?? x.source.name) === p.table;
case "subselect":
return x.source.name === p.table;
}
});
function getColumnByTableAndColumnName(p: {
table: string;
column: string;
}): ResolvedColumn | null {
const source = sources.get(p.table);

if (columnSource === undefined) {
if (source === undefined) {
return null;
}

const resolved =
columnSource.source.kind === "table" && columnSource.source.alias !== undefined
? resolveColumn({ ...columnSource.column, tableName: columnSource.source.alias })
: resolveColumn(columnSource.column);
const resolved = getSourceColumns(source).find((x) => x.column.column.colName === p.column);

return resolved;
return resolved?.column ?? null;
}

function getColumnsByTargetField(field: TargetField): ResolvedColumn[] | null {
Expand All @@ -73,12 +88,12 @@ export function getSources({
const source = sources.get(field.field);

if (source !== undefined) {
return columns.filter((x) => x.source === source).map((x) => resolveColumn(x.column));
return getSourceColumns(source).map((x) => x.column);
}

for (const { column } of columns) {
if (column.colName === field.field) {
return [resolveColumn(column)];
for (const { column } of getAllResolvedColumns()) {
if (column.column.colName === field.field) {
return [column];
}
}

Expand All @@ -87,8 +102,8 @@ export function getSources({
}
}

function checkIsNullableDueToRelation(column: PgColRow) {
const findByJoin = relations.find((x) => (x.alias ?? x.joinRelName) === column.tableName);
function checkIsNullableDueToRelation(tableName: string): boolean {
const findByJoin = relations.find((x) => (x.alias ?? x.joinRelName) === tableName);

if (findByJoin !== undefined) {
switch (findByJoin.joinType) {
Expand All @@ -109,7 +124,7 @@ export function getSources({
}
}

const findByRel = relations.filter((x) => x.relName === column.tableName);
const findByRel = relations.filter((x) => x.relName === tableName);

for (const rel of findByRel) {
switch (rel.joinType) {
Expand All @@ -133,27 +148,17 @@ export function getSources({
return false;
}

function resolveColumn(col: PgColRow): ResolvedColumn {
const isNullableDueToRelation = checkIsNullableDueToRelation(col);
function resolveColumn(col: PgColRow, tableName: string): ResolvedColumn {
const isNullableDueToRelation = checkIsNullableDueToRelation(tableName);
const isNotNullBasedOnAST =
nonNullableColumns.has(col.colName) ||
nonNullableColumns.has(`${col.tableName}.${col.colName}`);
nonNullableColumns.has(col.colName) || nonNullableColumns.has(`${tableName}.${col.colName}`);
const isNotNullInTable = col.colNotNull;

const isNonNullable = isNotNullBasedOnAST || (isNotNullInTable && !isNullableDueToRelation);

return { column: col, isNotNull: isNonNullable };
}

type SelectSource =
| { kind: "table"; schemaName: string; name: string; original: string; alias?: string }
| { kind: "subselect"; name: string };

type ColumnWithSource = {
column: PgColRow;
source: SelectSource;
};

function resolveRangeVarSchema(node: LibPgQueryAST.RangeVar): string {
if (node.schemaname !== undefined) {
return node.schemaname;
Expand All @@ -172,66 +177,93 @@ export function getSources({
return "public";
}

function getColumnSources(nodes: LibPgQueryAST.Node[]): {
columns: ColumnWithSource[];
sources: [string, SelectSource][];
} {
const columns: ColumnWithSource[] = [];
const sources: [string, SelectSource][] = [];
function getColumnCTEs(ctes: LibPgQueryAST.Node[]): Map<string, SourcesResolver> {
const map = new Map<string, SourcesResolver>();

for (const node of nodes) {
if (node.RangeVar !== undefined) {
const source: SelectSource = {
kind: "table",
schemaName: resolveRangeVarSchema(node.RangeVar),
original: node.RangeVar.relname,
name: node.RangeVar.alias?.aliasname ?? node.RangeVar.relname,
alias: node.RangeVar.alias?.aliasname,
};
for (const cte of ctes) {
if (cte.CommonTableExpr?.ctequery?.SelectStmt === undefined) continue;
if (cte.CommonTableExpr?.ctename === undefined) continue;

sources.push([source.name, source]);
const resolver = getSources({
pgColsBySchemaAndTableName,
prevSources,
nonNullableColumns,
relations,
select: cte.CommonTableExpr.ctequery.SelectStmt,
});

for (const column of pgColsBySchemaAndTableName
.get(source.schemaName)
?.get(source.original) ?? []) {
columns.push({ column, source });
}
}
map.set(cte.CommonTableExpr.ctename, resolver);
}

if (node.JoinExpr?.larg !== undefined) {
const resolved = getColumnSources([node.JoinExpr.larg]);
columns.push(...resolved.columns);
sources.push(...resolved.sources);
}
return map;
}

if (node.JoinExpr?.rarg !== undefined) {
const resolved = getColumnSources([node.JoinExpr.rarg]);
columns.push(...resolved.columns);
sources.push(...resolved.sources);
function getNodeColumnAndSources(node: LibPgQueryAST.Node): SelectSource[] {
if (node.RangeVar !== undefined) {
const cte = ctes.get(node.RangeVar.relname);

if (cte !== undefined) {
return [{ kind: "cte", name: node.RangeVar.relname, sources: cte }];
}

if (node.RangeSubselect?.subquery?.SelectStmt?.fromClause !== undefined) {
const source: SelectSource = {
kind: "subselect",
name: node.RangeSubselect.alias?.aliasname ?? "subselect",
};
const schemaName = resolveRangeVarSchema(node.RangeVar);
const realTableName = node.RangeVar.relname;
const tableName = node.RangeVar.alias?.aliasname ?? realTableName;
const tableColumns = pgColsBySchemaAndTableName.get(schemaName)?.get(realTableName) ?? [];

sources.push([source.name, source]);
return [
{
kind: "table",
schemaName: schemaName,
original: realTableName,
name: node.RangeVar.alias?.aliasname ?? node.RangeVar.relname,
alias: node.RangeVar.alias?.aliasname,
columns: tableColumns.map((col) => resolveColumn(col, tableName)),
},
];
}

const resolvedColumns = getColumnSources(
node.RangeSubselect.subquery.SelectStmt.fromClause,
).columns.map((x) => x.column);
const sources: SelectSource[] = [];

for (const column of resolvedColumns) {
columns.push({ column, source });
}
}
if (node.JoinExpr?.larg !== undefined) {
sources.push(...getNodeColumnAndSources(node.JoinExpr.larg));
}

if (node.JoinExpr?.rarg !== undefined) {
sources.push(...getNodeColumnAndSources(node.JoinExpr.rarg));
}

return { columns, sources };
if (node.RangeSubselect?.subquery?.SelectStmt?.fromClause !== undefined) {
sources.push({
kind: "subselect",
name: node.RangeSubselect.alias?.aliasname ?? "subselect",
sources: getSources({
nonNullableColumns,
pgColsBySchemaAndTableName,
relations,
prevSources: new Map([
...(prevSources?.entries() ?? []),
...sources.map((x) => [x.name, x] as const),
]),
select: node.RangeSubselect.subquery.SelectStmt,
}),
});
}

return sources;
}

function getColumnSources(nodes: LibPgQueryAST.Node[]): Map<string, SelectSource> {
return new Map(
nodes
.map(getNodeColumnAndSources)
.flat()
.map((x) => [x.name, x]),
);
}

return {
getNodeColumnAndSources: getNodeColumnAndSources,
getResolvedColumnsInTable: getResolvedColumnsInTable,
getAllResolvedColumns: getAllResolvedColumns,
getColumnsByTargetField: getColumnsByTargetField,
Expand Down
Loading

0 comments on commit 42326c2

Please sign in to comment.