diff --git a/databases/backends/psycopg.py b/databases/backends/psycopg.py index db9ec2fe..885e8336 100644 --- a/databases/backends/psycopg.py +++ b/databases/backends/psycopg.py @@ -3,10 +3,10 @@ import psycopg import psycopg_pool from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg from sqlalchemy.sql import ClauseElement from databases.backends.common.records import Record, create_column_maps -from databases.backends.dialects.psycopg import compile_query, get_dialect from databases.core import DatabaseURL from databases.interfaces import ( ConnectionBackend, @@ -29,7 +29,7 @@ def __init__( ) -> None: self._database_url = DatabaseURL(database_url) self._options = options - self._dialect = get_dialect() + self._dialect = PGDialect_psycopg() self._pool = None async def connect(self) -> None: @@ -86,7 +86,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: if self._connection is None: raise RuntimeError("Connection is not acquired") - query_str, args, result_columns = compile_query(query, self._dialect) + query_str, args, result_columns = self._compile(query) async with self._connection.cursor() as cursor: await cursor.execute(query_str, args) @@ -99,7 +99,7 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterfa if self._connection is None: raise RuntimeError("Connection is not acquired") - query_str, args, result_columns = compile_query(query, self._dialect) + query_str, args, result_columns = self._compile(query) async with self._connection.cursor() as cursor: await cursor.execute(query_str, args) @@ -125,7 +125,7 @@ async def execute(self, query: ClauseElement) -> typing.Any: if self._connection is None: raise RuntimeError("Connection is not acquired") - query_str, args, _ = compile_query(query, self._dialect) + query_str, args, _ = self._compile(query) async with self._connection.cursor() as cursor: await cursor.execute(query_str, args) @@ -141,7 +141,7 @@ async def iterate( if self._connection is None: raise RuntimeError("Connection is not acquired") - query_str, args, result_columns = compile_query(query, self._dialect) + query_str, args, result_columns = self._compile(query) column_maps = create_column_maps(result_columns) async with self._connection.cursor() as cursor: @@ -164,6 +164,17 @@ def raw_connection(self) -> typing.Any: raise RuntimeError("Connection is not acquired") return self._connection + def _compile( + self, query: ClauseElement, + ) -> typing.Tuple[str, typing.Mapping[str, typing.Any], tuple]: + compiled = query.compile(dialect=self._dialect) + + compiled_query = compiled.string + params = compiled.params + result_map = compiled._result_columns + + return compiled_query, params, result_map + class PsycopgTransaction(TransactionBackend): _connecttion: PsycopgConnection diff --git a/tests/test_databases.py b/tests/test_databases.py index 90771b1d..5c0b61d1 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1338,6 +1338,7 @@ async def test_queries_with_expose_backend_connection(database_url): "mysql+asyncmy", "mysql+aiomysql", "postgresql+aiopg", + "postgresql+psycopg", ]: insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)" else: