Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add overrides config option #83

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,47 @@ class Status(str, enum.Enum):
OPEN = "op!en"
CLOSED = "clo@sed"
```

### Override Column Types

Option: `overrides`

You can override the SQL to Python type mapping for specific columns using the `overrides` option. This is useful for columns with JSON data or other custom types.

Example configuration:

```yaml
options:
package: authors
emit_pydantic_models: true
overrides:
- column: "some_table.payload"
py_import: "my_lib.models"
py_type: "Payload"
```

This will:
1. Override the column `payload` in `some_table` to use the type `Payload`
2. Add an import for `my_lib.models` to the models file

Example output:

```python
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.28.0

import datetime
import pydantic
from typing import Any

import my_lib.models


class SomeTable(pydantic.BaseModel):
id: int
created_at: datetime.datetime
payload: my_lib.models.Payload
```

This is similar to the [overrides functionality in the Go version of sqlc](https://docs.sqlc.dev/en/stable/howto/overrides.html#overriding-types).
25 changes: 16 additions & 9 deletions internal/config.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
package python

type OverrideColumn struct {
Column string `json:"column"`
PyType string `json:"py_type"`
PyImport string `json:"py_import"`
}

type Config struct {
EmitExactTableNames bool `json:"emit_exact_table_names"`
EmitSyncQuerier bool `json:"emit_sync_querier"`
EmitAsyncQuerier bool `json:"emit_async_querier"`
Package string `json:"package"`
Out string `json:"out"`
EmitPydanticModels bool `json:"emit_pydantic_models"`
EmitStrEnum bool `json:"emit_str_enum"`
QueryParameterLimit *int32 `json:"query_parameter_limit"`
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
EmitExactTableNames bool `json:"emit_exact_table_names"`
EmitSyncQuerier bool `json:"emit_sync_querier"`
EmitAsyncQuerier bool `json:"emit_async_querier"`
Package string `json:"package"`
Out string `json:"out"`
EmitPydanticModels bool `json:"emit_pydantic_models"`
EmitStrEnum bool `json:"emit_str_enum"`
QueryParameterLimit *int32 `json:"query_parameter_limit"`
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
Overrides []OverrideColumn `json:"overrides"`
}
2 changes: 1 addition & 1 deletion internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
2 changes: 1 addition & 1 deletion internal/endtoend/testdata/emit_str_enum/sqlc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
11 changes: 11 additions & 0 deletions internal/endtoend/testdata/emit_type_overrides/db/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.28.0
import pydantic

import my_lib.models


class Book(pydantic.BaseModel):
id: int
payload: my_lib.models.Payload
92 changes: 92 additions & 0 deletions internal/endtoend/testdata/emit_type_overrides/db/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.28.0
# source: query.sql
from typing import AsyncIterator, Iterator, Optional

import my_lib.models
import sqlalchemy
import sqlalchemy.ext.asyncio

from db import models


CREATE_BOOK = """-- name: create_book \\:one
INSERT INTO books (payload)
VALUES (:p1)
RETURNING id, payload
"""


GET_BOOK = """-- name: get_book \\:one
SELECT id, payload FROM books
WHERE id = :p1 LIMIT 1
"""


LIST_BOOKS = """-- name: list_books \\:many
SELECT id, payload FROM books
ORDER BY id
"""


class Querier:
def __init__(self, conn: sqlalchemy.engine.Connection):
self._conn = conn

def create_book(self, *, payload: my_lib.models.Payload) -> Optional[models.Book]:
row = self._conn.execute(sqlalchemy.text(CREATE_BOOK), {"p1": payload}).first()
if row is None:
return None
return models.Book(
id=row[0],
payload=row[1],
)

def get_book(self, *, id: int) -> Optional[models.Book]:
row = self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": id}).first()
if row is None:
return None
return models.Book(
id=row[0],
payload=row[1],
)

def list_books(self) -> Iterator[models.Book]:
result = self._conn.execute(sqlalchemy.text(LIST_BOOKS))
for row in result:
yield models.Book(
id=row[0],
payload=row[1],
)


class AsyncQuerier:
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
self._conn = conn

async def create_book(self, *, payload: my_lib.models.Payload) -> Optional[models.Book]:
row = (await self._conn.execute(sqlalchemy.text(CREATE_BOOK), {"p1": payload})).first()
if row is None:
return None
return models.Book(
id=row[0],
payload=row[1],
)

async def get_book(self, *, id: int) -> Optional[models.Book]:
row = (await self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": id})).first()
if row is None:
return None
return models.Book(
id=row[0],
payload=row[1],
)

async def list_books(self) -> AsyncIterator[models.Book]:
result = await self._conn.stream(sqlalchemy.text(LIST_BOOKS))
async for row in result:
yield models.Book(
id=row[0],
payload=row[1],
)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from datetime import date

from pydantic import BaseModel

class Payload(BaseModel):
name: str
release_date: date
12 changes: 12 additions & 0 deletions internal/endtoend/testdata/emit_type_overrides/query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- name: GetBook :one
SELECT * FROM books
WHERE id = $1 LIMIT 1;

-- name: ListBooks :many
SELECT * FROM books
ORDER BY id;

-- name: CreateBook :one
INSERT INTO books (payload)
VALUES (sqlc.arg(payload))
RETURNING *;
4 changes: 4 additions & 0 deletions internal/endtoend/testdata/emit_type_overrides/schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CREATE TABLE books (
id SERIAL PRIMARY KEY,
payload JSONB NOT NULL
);
22 changes: 22 additions & 0 deletions internal/endtoend/testdata/emit_type_overrides/sqlc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
version: "2"
plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
sql:
- schema: schema.sql
queries: query.sql
engine: postgresql
codegen:
- plugin: py
out: db
options:
package: db
emit_pydantic_models: true
emit_sync_querier: true
emit_async_querier: true
overrides:
- column: "books.payload"
py_import: "my_lib.models"
py_type: "Payload"
2 changes: 1 addition & 1 deletion internal/endtoend/testdata/exec_result/sqlc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
2 changes: 1 addition & 1 deletion internal/endtoend/testdata/exec_rows/sqlc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
34 changes: 34 additions & 0 deletions internal/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,40 @@ func (q Query) ArgDictNode() *pyast.Node {
}

func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType {
// Parse the configuration
var conf Config
if len(req.PluginOptions) > 0 {
if err := json.Unmarshal(req.PluginOptions, &conf); err != nil {
log.Printf("failed to parse plugin options: %s", err)
}
}

// Check for overrides
if len(conf.Overrides) > 0 && col.Table != nil {
tableName := col.Table.Name
if col.Table.Schema != "" && col.Table.Schema != req.Catalog.DefaultSchema {
tableName = col.Table.Schema + "." + tableName
}

// Look for a matching override
for _, override := range conf.Overrides {
overrideKey := tableName + "." + col.Name
if override.Column == overrideKey {
// Found a match, use the override
typeStr := override.PyType
if override.PyImport != "" && !strings.Contains(typeStr, ".") {
typeStr = override.PyImport + "." + override.PyType
}
return pyType{
InnerType: typeStr,
IsArray: col.IsArray,
IsNull: !col.NotNull,
}
}
}
}

// No override found, use the standard type mapping
typ := pyInnerType(req, col)
return pyType{
InnerType: typ,
Expand Down
28 changes: 28 additions & 0 deletions internal/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,20 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS

pkg := make(map[string]importSpec)

// Add custom imports from overrides
for _, override := range i.C.Overrides {
if override.PyImport != "" {
// Check if it's a standard module or a package import
if strings.Contains(override.PyImport, ".") {
// It's a package import
pkg[override.PyImport] = importSpec{Module: override.PyImport}
} else {
// It's a standard import
std[override.PyImport] = importSpec{Module: override.PyImport}
}
}
}

return std, pkg
}

Expand Down Expand Up @@ -167,6 +181,20 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map
}
}

// Add custom imports from overrides for query files
for _, override := range i.C.Overrides {
if override.PyImport != "" {
// Check if it's a standard module or a package import
if strings.Contains(override.PyImport, ".") {
// It's a package import
pkg[override.PyImport] = importSpec{Module: override.PyImport}
} else {
// It's a standard import
std[override.PyImport] = importSpec{Module: override.PyImport}
}
}
}

return std, pkg
}

Expand Down