From 140ca3a795a903268ad004e5109c916437696c21 Mon Sep 17 00:00:00 2001 From: Robert Jensen Date: Tue, 10 Dec 2024 18:28:41 -0500 Subject: [PATCH 1/5] feat: add JSON type, bindparam support --- sqlalchemy_bigquery/__init__.py | 4 +-- sqlalchemy_bigquery/_json.py | 8 +++++ sqlalchemy_bigquery/_types.py | 3 ++ sqlalchemy_bigquery/base.py | 17 ++++++++- tests/unit/test__json.py | 64 +++++++++++++++++++++++++++++++++ 5 files changed, 93 insertions(+), 3 deletions(-) create mode 100644 sqlalchemy_bigquery/_json.py create mode 100644 tests/unit/test__json.py diff --git a/sqlalchemy_bigquery/__init__.py b/sqlalchemy_bigquery/__init__.py index 1e506125..7ec79a29 100644 --- a/sqlalchemy_bigquery/__init__.py +++ b/sqlalchemy_bigquery/__init__.py @@ -23,7 +23,6 @@ import warnings from .version import __version__ - from .base import BigQueryDialect, dialect from ._types import ( ARRAY, @@ -37,6 +36,7 @@ FLOAT64, INT64, INTEGER, + JSON, NUMERIC, RECORD, STRING, @@ -44,7 +44,6 @@ TIME, TIMESTAMP, ) - from . import _versions_helpers sys_major, sys_minor, sys_micro = _versions_helpers.extract_runtime_version() @@ -74,6 +73,7 @@ "FLOAT64", "INT64", "INTEGER", + "JSON", "NUMERIC", "RECORD", "STRING", diff --git a/sqlalchemy_bigquery/_json.py b/sqlalchemy_bigquery/_json.py new file mode 100644 index 00000000..e12ac6fb --- /dev/null +++ b/sqlalchemy_bigquery/_json.py @@ -0,0 +1,8 @@ +import sqlalchemy + + +class JSON(sqlalchemy.sql.sqltypes.JSON): + def bind_expression(self, bindvalue): + # JSON query parameters have type STRING + # This hook ensures that the rendered expression has type JSON + return sqlalchemy.func.PARSE_JSON(bindvalue, type_=self) diff --git a/sqlalchemy_bigquery/_types.py b/sqlalchemy_bigquery/_types.py index 8399e978..6a268ce9 100644 --- a/sqlalchemy_bigquery/_types.py +++ b/sqlalchemy_bigquery/_types.py @@ -27,6 +27,7 @@ except ImportError: # pragma: NO COVER pass +from ._json import JSON from ._struct import STRUCT _type_map = { @@ -41,6 +42,7 @@ "FLOAT": sqlalchemy.types.Float, "INT64": sqlalchemy.types.Integer, "INTEGER": sqlalchemy.types.Integer, + "JSON": JSON, "NUMERIC": sqlalchemy.types.Numeric, "RECORD": STRUCT, "STRING": sqlalchemy.types.String, @@ -61,6 +63,7 @@ FLOAT = _type_map["FLOAT"] INT64 = _type_map["INT64"] INTEGER = _type_map["INTEGER"] +JSON = _type_map["JSON"] NUMERIC = _type_map["NUMERIC"] RECORD = _type_map["RECORD"] STRING = _type_map["STRING"] diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index c36ca1b1..236f1ab3 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -59,7 +59,7 @@ import re from .parse_url import parse_url -from . import _helpers, _struct, _types +from . import _helpers, _json, _struct, _types import sqlalchemy_bigquery_vendored.sqlalchemy.postgresql.base as vendored_postgresql # Illegal characters is intended to be all characters that are not explicitly @@ -547,6 +547,13 @@ def visit_bindparam( bq_type = self.dialect.type_compiler.process(type_) bq_type = self.__remove_type_parameter(bq_type) + if bq_type == "JSON": + # FIXME: JSON is not a member of `SqlParameterScalarTypes` in the DBAPI + # For now, we hack around this by: + # - Rewriting the bindparam type to STRING + # - Applying a bind expression that converts the parameter back to JSON + bq_type = "STRING" + assert_(param != "%s", f"Unexpected param: {param}") if bindparam.expanding: # pragma: NO COVER @@ -641,6 +648,9 @@ def visit_NUMERIC(self, type_, **kw): visit_DECIMAL = visit_NUMERIC + def visit_JSON(self, type_, **kw): + return "JSON" + class BigQueryDDLCompiler(DDLCompiler): option_datatype_mapping = { @@ -1076,6 +1086,7 @@ class BigQueryDialect(DefaultDialect): sqlalchemy.sql.sqltypes.TIMESTAMP: BQTimestamp, sqlalchemy.sql.sqltypes.ARRAY: BQArray, sqlalchemy.sql.sqltypes.Enum: sqlalchemy.sql.sqltypes.Enum, + sqlalchemy.sql.sqltypes.JSON: _json.JSON, } def __init__( @@ -1086,6 +1097,8 @@ def __init__( credentials_info=None, credentials_base64=None, list_tables_page_size=1000, + json_serializer=None, + json_deserializer=None, *args, **kwargs, ): @@ -1098,6 +1111,8 @@ def __init__( self.identifier_preparer = self.preparer(self) self.dataset_id = None self.list_tables_page_size = list_tables_page_size + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer @classmethod def dbapi(cls): diff --git a/tests/unit/test__json.py b/tests/unit/test__json.py new file mode 100644 index 00000000..0667b13d --- /dev/null +++ b/tests/unit/test__json.py @@ -0,0 +1,64 @@ +import json +from unittest import mock + +import pytest +import sqlalchemy + + +@pytest.fixture +def json_table(metadata): + from sqlalchemy_bigquery import JSON + + return sqlalchemy.Table("json_table", metadata, sqlalchemy.Column("json", JSON)) + + +@pytest.fixture +def json_data(): + return {"foo": "bar"} + + +def test_set_json_serde(faux_conn, metadata, json_table, json_data): + from sqlalchemy_bigquery import JSON + + json_serializer = mock.Mock(side_effect=json.dumps) + json_deserializer = mock.Mock(side_effect=json.loads) + + engine = sqlalchemy.create_engine( + "bigquery://myproject/mydataset", + json_serializer=json_serializer, + json_deserializer=json_deserializer, + ) + + json_column = json_table.c.json + + process_bind = json_column.type.bind_processor(engine.dialect) + process_bind(json_data) + assert json_serializer.mock_calls == [mock.call(json_data)] + + process_result = json_column.type.result_processor(engine.dialect, JSON) + process_result(json.dumps(json_data)) + assert json_deserializer.mock_calls == [mock.call(json.dumps(json_data))] + + +def test_json_create(faux_conn, metadata, json_table, json_data): + expr = sqlalchemy.schema.CreateTable(json_table) + sql = expr.compile(faux_conn.engine).string + assert sql == ("\nCREATE TABLE `json_table` (\n" "\t`json` JSON\n" ") \n\n") + + +def test_json_insert(faux_conn, metadata, json_table, json_data): + expr = sqlalchemy.insert(json_table).values(json=json_data) + sql = expr.compile(faux_conn.engine).string + assert ( + sql == "INSERT INTO `json_table` (`json`) VALUES (PARSE_JSON(%(json:STRING)s))" + ) + + +def test_json_where(faux_conn, metadata, json_table, json_data): + expr = sqlalchemy.select(json_table.c.json).where(json_table.c.json == json_data) + sql = expr.compile(faux_conn.engine).string + assert sql == ( + "SELECT `json_table`.`json` \n" + "FROM `json_table` \n" + "WHERE `json_table`.`json` = PARSE_JSON(%(json_1:STRING)s)" + ) From d0fd734ff9c7cdb5be497dbcb8b79cbeccdea9a3 Mon Sep 17 00:00:00 2001 From: Robert Jensen Date: Tue, 10 Dec 2024 19:31:06 -0500 Subject: [PATCH 2/5] formatting fixes --- sqlalchemy_bigquery/__init__.py | 2 ++ tests/unit/test__json.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sqlalchemy_bigquery/__init__.py b/sqlalchemy_bigquery/__init__.py index 7ec79a29..567015ee 100644 --- a/sqlalchemy_bigquery/__init__.py +++ b/sqlalchemy_bigquery/__init__.py @@ -23,6 +23,7 @@ import warnings from .version import __version__ + from .base import BigQueryDialect, dialect from ._types import ( ARRAY, @@ -44,6 +45,7 @@ TIME, TIMESTAMP, ) + from . import _versions_helpers sys_major, sys_minor, sys_micro = _versions_helpers.extract_runtime_version() diff --git a/tests/unit/test__json.py b/tests/unit/test__json.py index 0667b13d..117da42e 100644 --- a/tests/unit/test__json.py +++ b/tests/unit/test__json.py @@ -43,7 +43,7 @@ def test_set_json_serde(faux_conn, metadata, json_table, json_data): def test_json_create(faux_conn, metadata, json_table, json_data): expr = sqlalchemy.schema.CreateTable(json_table) sql = expr.compile(faux_conn.engine).string - assert sql == ("\nCREATE TABLE `json_table` (\n" "\t`json` JSON\n" ") \n\n") + assert sql == "\nCREATE TABLE `json_table` (\n\t`json` JSON\n) \n\n" def test_json_insert(faux_conn, metadata, json_table, json_data): From 65d676d3a06b768bb6cdf18f5fbb9522da1b58c8 Mon Sep 17 00:00:00 2001 From: Robert Jensen Date: Wed, 11 Dec 2024 16:42:50 -0500 Subject: [PATCH 3/5] move bindparam workaround to visit_JSON --- sqlalchemy_bigquery/base.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 236f1ab3..c1f3f0db 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -548,10 +548,7 @@ def visit_bindparam( bq_type = self.__remove_type_parameter(bq_type) if bq_type == "JSON": - # FIXME: JSON is not a member of `SqlParameterScalarTypes` in the DBAPI - # For now, we hack around this by: - # - Rewriting the bindparam type to STRING - # - Applying a bind expression that converts the parameter back to JSON + bq_type = "STRING" assert_(param != "%s", f"Unexpected param: {param}") @@ -649,6 +646,14 @@ def visit_NUMERIC(self, type_, **kw): visit_DECIMAL = visit_NUMERIC def visit_JSON(self, type_, **kw): + if isinstance( + kw.get("type_expression"), sqlalchemy.sql.expression.BindParameter + ): # bindparam + # FIXME: JSON is not a member of `SqlParameterScalarTypes` in the DBAPI + # For now, we hack around this by: + # - Rewriting the bindparam type to STRING + # - Applying a bind expression that converts the parameter back to JSON + return "STRING" return "JSON" From b1680803a8290470f0b67de7227ec42887640d9b Mon Sep 17 00:00:00 2001 From: Robert Jensen Date: Wed, 11 Dec 2024 16:45:51 -0500 Subject: [PATCH 4/5] Revert "move bindparam workaround to visit_JSON" This reverts commit 65d676d3a06b768bb6cdf18f5fbb9522da1b58c8. --- sqlalchemy_bigquery/base.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index c1f3f0db..236f1ab3 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -548,7 +548,10 @@ def visit_bindparam( bq_type = self.__remove_type_parameter(bq_type) if bq_type == "JSON": - + # FIXME: JSON is not a member of `SqlParameterScalarTypes` in the DBAPI + # For now, we hack around this by: + # - Rewriting the bindparam type to STRING + # - Applying a bind expression that converts the parameter back to JSON bq_type = "STRING" assert_(param != "%s", f"Unexpected param: {param}") @@ -646,14 +649,6 @@ def visit_NUMERIC(self, type_, **kw): visit_DECIMAL = visit_NUMERIC def visit_JSON(self, type_, **kw): - if isinstance( - kw.get("type_expression"), sqlalchemy.sql.expression.BindParameter - ): # bindparam - # FIXME: JSON is not a member of `SqlParameterScalarTypes` in the DBAPI - # For now, we hack around this by: - # - Rewriting the bindparam type to STRING - # - Applying a bind expression that converts the parameter back to JSON - return "STRING" return "JSON" From 2a03b311dbb60ca61116abdc11212a574e00907c Mon Sep 17 00:00:00 2001 From: Robert Jensen Date: Wed, 11 Dec 2024 17:08:26 -0500 Subject: [PATCH 5/5] move bindparam workaround to visit_JSON better --- sqlalchemy_bigquery/base.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 236f1ab3..afca6b9e 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -547,13 +547,6 @@ def visit_bindparam( bq_type = self.dialect.type_compiler.process(type_) bq_type = self.__remove_type_parameter(bq_type) - if bq_type == "JSON": - # FIXME: JSON is not a member of `SqlParameterScalarTypes` in the DBAPI - # For now, we hack around this by: - # - Rewriting the bindparam type to STRING - # - Applying a bind expression that converts the parameter back to JSON - bq_type = "STRING" - assert_(param != "%s", f"Unexpected param: {param}") if bindparam.expanding: # pragma: NO COVER @@ -649,7 +642,15 @@ def visit_NUMERIC(self, type_, **kw): visit_DECIMAL = visit_NUMERIC def visit_JSON(self, type_, **kw): - return "JSON" + if isinstance( + kw.get("type_expression"), Column + ): # column def + return "JSON" + # FIXME: JSON is not a member of `SqlParameterScalarTypes` in the DBAPI + # For now, we hack around this by: + # - Rewriting the bindparam type to STRING + # - Applying a bind expression that converts the parameter back to JSON + return "STRING" class BigQueryDDLCompiler(DDLCompiler):