Add JSON column type support to sqlalchemy_bigquery.

Summary of what this patch does (derived from the hunks below):

* sqlalchemy_bigquery/_json.py (new): a JSON type subclassing
  sqlalchemy.sql.sqltypes.JSON whose bind_expression() wraps the bound
  value in PARSE_JSON(...), typed as JSON.
* sqlalchemy_bigquery/_types.py: registers "JSON" in _type_map and adds the
  module-level JSON alias next to the other type aliases.
* sqlalchemy_bigquery/__init__.py: adds JSON to the two name lists touched
  in the package __init__ (an import list and a list of quoted names).
* sqlalchemy_bigquery/base.py: adds visit_JSON(), which renders "JSON" when
  the "type_expression" keyword is a Column and "STRING" otherwise; maps
  sqlalchemy.sql.sqltypes.JSON to _json.JSON inside BigQueryDialect; and
  lets BigQueryDialect.__init__ accept json_serializer / json_deserializer,
  stored as _json_serializer / _json_deserializer.
* tests/unit/test__json.py (new): unit tests for serializer wiring and for
  the compiled CREATE TABLE / INSERT / SELECT ... WHERE statements.

NOTE(review): this patch was recovered from a copy whose newlines and runs
of whitespace had been collapsed. Context-line indentation, the two spaces
before inline comments, and the position of blank context lines were
reconstructed from the hunk line counts and Python syntax; confirm against
the upstream files before applying.

diff --git a/sqlalchemy_bigquery/__init__.py b/sqlalchemy_bigquery/__init__.py
index 1e506125..567015ee 100644
--- a/sqlalchemy_bigquery/__init__.py
+++ b/sqlalchemy_bigquery/__init__.py
@@ -37,6 +37,7 @@
     FLOAT64,
     INT64,
     INTEGER,
+    JSON,
     NUMERIC,
     RECORD,
     STRING,
@@ -74,6 +75,7 @@
     "FLOAT64",
     "INT64",
     "INTEGER",
+    "JSON",
     "NUMERIC",
     "RECORD",
     "STRING",
diff --git a/sqlalchemy_bigquery/_json.py b/sqlalchemy_bigquery/_json.py
new file mode 100644
index 00000000..e12ac6fb
--- /dev/null
+++ b/sqlalchemy_bigquery/_json.py
@@ -0,0 +1,8 @@
+import sqlalchemy
+
+
+class JSON(sqlalchemy.sql.sqltypes.JSON):
+    def bind_expression(self, bindvalue):
+        # JSON query parameters have type STRING
+        # This hook ensures that the rendered expression has type JSON
+        return sqlalchemy.func.PARSE_JSON(bindvalue, type_=self)
diff --git a/sqlalchemy_bigquery/_types.py b/sqlalchemy_bigquery/_types.py
index 8399e978..6a268ce9 100644
--- a/sqlalchemy_bigquery/_types.py
+++ b/sqlalchemy_bigquery/_types.py
@@ -27,6 +27,7 @@
 except ImportError:  # pragma: NO COVER
     pass
 
+from ._json import JSON
 from ._struct import STRUCT
 
 _type_map = {
@@ -41,6 +42,7 @@
     "FLOAT": sqlalchemy.types.Float,
     "INT64": sqlalchemy.types.Integer,
     "INTEGER": sqlalchemy.types.Integer,
+    "JSON": JSON,
     "NUMERIC": sqlalchemy.types.Numeric,
     "RECORD": STRUCT,
     "STRING": sqlalchemy.types.String,
@@ -61,6 +63,7 @@
 FLOAT = _type_map["FLOAT"]
 INT64 = _type_map["INT64"]
 INTEGER = _type_map["INTEGER"]
+JSON = _type_map["JSON"]
 NUMERIC = _type_map["NUMERIC"]
 RECORD = _type_map["RECORD"]
 STRING = _type_map["STRING"]
diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py
index c36ca1b1..afca6b9e 100644
--- a/sqlalchemy_bigquery/base.py
+++ b/sqlalchemy_bigquery/base.py
@@ -59,7 +59,7 @@
 import re
 
 from .parse_url import parse_url
-from . import _helpers, _struct, _types
+from . import _helpers, _json, _struct, _types
 import sqlalchemy_bigquery_vendored.sqlalchemy.postgresql.base as vendored_postgresql
 
 # Illegal characters is intended to be all characters that are not explicitly
@@ -641,6 +641,17 @@ def visit_NUMERIC(self, type_, **kw):
 
     visit_DECIMAL = visit_NUMERIC
 
+    def visit_JSON(self, type_, **kw):
+        if isinstance(
+            kw.get("type_expression"), Column
+        ):  # column def
+            return "JSON"
+        # FIXME: JSON is not a member of `SqlParameterScalarTypes` in the DBAPI
+        # For now, we hack around this by:
+        # - Rewriting the bindparam type to STRING
+        # - Applying a bind expression that converts the parameter back to JSON
+        return "STRING"
+
 
 class BigQueryDDLCompiler(DDLCompiler):
     option_datatype_mapping = {
@@ -1076,6 +1087,7 @@ class BigQueryDialect(DefaultDialect):
         sqlalchemy.sql.sqltypes.TIMESTAMP: BQTimestamp,
         sqlalchemy.sql.sqltypes.ARRAY: BQArray,
         sqlalchemy.sql.sqltypes.Enum: sqlalchemy.sql.sqltypes.Enum,
+        sqlalchemy.sql.sqltypes.JSON: _json.JSON,
     }
 
     def __init__(
@@ -1086,6 +1098,8 @@ def __init__(
         credentials_info=None,
         credentials_base64=None,
         list_tables_page_size=1000,
+        json_serializer=None,
+        json_deserializer=None,
         *args,
         **kwargs,
     ):
@@ -1098,6 +1112,8 @@ def __init__(
         self.identifier_preparer = self.preparer(self)
         self.dataset_id = None
         self.list_tables_page_size = list_tables_page_size
+        self._json_serializer = json_serializer
+        self._json_deserializer = json_deserializer
 
     @classmethod
     def dbapi(cls):
diff --git a/tests/unit/test__json.py b/tests/unit/test__json.py
new file mode 100644
index 00000000..117da42e
--- /dev/null
+++ b/tests/unit/test__json.py
@@ -0,0 +1,64 @@
+import json
+from unittest import mock
+
+import pytest
+import sqlalchemy
+
+
+@pytest.fixture
+def json_table(metadata):
+    from sqlalchemy_bigquery import JSON
+
+    return sqlalchemy.Table("json_table", metadata, sqlalchemy.Column("json", JSON))
+
+
+@pytest.fixture
+def json_data():
+    return {"foo": "bar"}
+
+
+def test_set_json_serde(faux_conn, metadata, json_table, json_data):
+    from sqlalchemy_bigquery import JSON
+
+    json_serializer = mock.Mock(side_effect=json.dumps)
+    json_deserializer = mock.Mock(side_effect=json.loads)
+
+    engine = sqlalchemy.create_engine(
+        "bigquery://myproject/mydataset",
+        json_serializer=json_serializer,
+        json_deserializer=json_deserializer,
+    )
+
+    json_column = json_table.c.json
+
+    process_bind = json_column.type.bind_processor(engine.dialect)
+    process_bind(json_data)
+    assert json_serializer.mock_calls == [mock.call(json_data)]
+
+    process_result = json_column.type.result_processor(engine.dialect, JSON)
+    process_result(json.dumps(json_data))
+    assert json_deserializer.mock_calls == [mock.call(json.dumps(json_data))]
+
+
+def test_json_create(faux_conn, metadata, json_table, json_data):
+    expr = sqlalchemy.schema.CreateTable(json_table)
+    sql = expr.compile(faux_conn.engine).string
+    assert sql == "\nCREATE TABLE `json_table` (\n\t`json` JSON\n) \n\n"
+
+
+def test_json_insert(faux_conn, metadata, json_table, json_data):
+    expr = sqlalchemy.insert(json_table).values(json=json_data)
+    sql = expr.compile(faux_conn.engine).string
+    assert (
+        sql == "INSERT INTO `json_table` (`json`) VALUES (PARSE_JSON(%(json:STRING)s))"
+    )
+
+
+def test_json_where(faux_conn, metadata, json_table, json_data):
+    expr = sqlalchemy.select(json_table.c.json).where(json_table.c.json == json_data)
+    sql = expr.compile(faux_conn.engine).string
+    assert sql == (
+        "SELECT `json_table`.`json` \n"
+        "FROM `json_table` \n"
+        "WHERE `json_table`.`json` = PARSE_JSON(%(json_1:STRING)s)"
+    )