From 3f26f76056d0da62600a2fca5f3c23e6ee8bbae5 Mon Sep 17 00:00:00 2001 From: ansipunk Date: Sun, 3 Mar 2024 14:39:47 +0500 Subject: [PATCH] S01E08 --- databases/backends/asyncpg.py | 66 +++++++++++++++++++++----- databases/backends/dialects/psycopg.py | 43 +---------------- databases/backends/psycopg.py | 5 +- databases/core.py | 2 +- 4 files changed, 58 insertions(+), 58 deletions(-) diff --git a/databases/backends/asyncpg.py b/databases/backends/asyncpg.py index ff61fe26..124f7af1 100644 --- a/databases/backends/asyncpg.py +++ b/databases/backends/asyncpg.py @@ -4,10 +4,11 @@ import asyncpg from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement +from sqlalchemy.sql.ddl import DDLElement from databases.backends.common.records import Record, create_column_maps -from databases.backends.dialects.psycopg import compile_query, get_dialect -from databases.core import DatabaseURL +from databases.backends.dialects.psycopg import dialect as psycopg_dialect +from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, @@ -24,9 +25,20 @@ def __init__( ) -> None: self._database_url = DatabaseURL(database_url) self._options = options - self._dialect = get_dialect() + self._dialect = self._get_dialect() self._pool = None + def _get_dialect(self) -> Dialect: + dialect = psycopg_dialect(paramstyle="pyformat") + dialect.implicit_returning = True + dialect.supports_native_enum = True + dialect.supports_smallserial = True # 9.2+ + dialect._backslash_escapes = False + dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+ + dialect._has_native_hstore = True + dialect.supports_native_decimal = True + return dialect + def _get_connection_kwargs(self) -> dict: url_options = self._database_url.options @@ -87,7 +99,7 @@ async def release(self) -> None: async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, result_columns = compile_query(query, self._dialect) + query_str, args, result_columns = self._compile(query) rows = await self._connection.fetch(query_str, *args) dialect = self._dialect column_maps = create_column_maps(result_columns) @@ -95,7 +107,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, result_columns = compile_query(query, self._dialect) + query_str, args, result_columns = self._compile(query) row = await self._connection.fetchrow(query_str, *args) if row is None: return None @@ -123,7 +135,7 @@ async def fetch_val( async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, _ = compile_query(query, self._dialect) + query_str, args, _ = self._compile(query) return await self._connection.fetchval(query_str, *args) async def execute_many(self, queries: typing.List[ClauseElement]) -> None: @@ -132,14 +144,14 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: # loop through multiple executes here, which should all end up # using the same prepared statement. for single_query in queries: - single_query, args, _ = compile_query(single_query, self._dialect) + single_query, args, _ = self._compile(single_query) await self._connection.execute(single_query, *args) async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, result_columns = compile_query(query, self._dialect) + query_str, args, result_columns = self._compile(query) column_maps = create_column_maps(result_columns) async for row in self._connection.cursor(query_str, *args): yield Record(row, result_columns, self._dialect, column_maps) @@ -147,10 +159,40 @@ async def iterate( def transaction(self) -> TransactionBackend: return AsyncpgTransaction(connection=self) - @property - def raw_connection(self) -> asyncpg.connection.Connection: - assert self._connection is not None, "Connection is not acquired" - return self._connection + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: + compiled = query.compile( + dialect=self._dialect, compile_kwargs={"render_postcompile": True} + ) + + if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + + processors = compiled._bind_processors + args = [ + processors[key](val) if key in processors else val + for key, val in compiled_params + ] + result_map = compiled._result_columns + else: + compiled_query = compiled.string + args = [] + result_map = None + + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled_query, args, result_map + + @property + def raw_connection(self) -> asyncpg.connection.Connection: + assert self._connection is not None, "Connection is not acquired" + return self._connection class AsyncpgTransaction(TransactionBackend): diff --git a/databases/backends/dialects/psycopg.py b/databases/backends/dialects/psycopg.py index 80bf5b76..07bd1880 100644 --- a/databases/backends/dialects/psycopg.py +++ b/databases/backends/dialects/psycopg.py @@ -10,9 +10,6 @@ from sqlalchemy import types, util from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext from sqlalchemy.engine import processors -from sqlalchemy.engine.interfaces import Dialect -from sqlalchemy.sql import ClauseElement -from sqlalchemy.sql.ddl import DDLElement from sqlalchemy.types import Float, Numeric @@ -46,42 +43,4 @@ class PGDialect_psycopg(PGDialect): execution_ctx_cls = PGExecutionContext_psycopg -def get_dialect() -> Dialect: - dialect = PGDialect_psycopg(paramstyle="pyformat") - dialect.implicit_returning = True - dialect.supports_native_enum = True - dialect.supports_smallserial = True # 9.2+ - dialect._backslash_escapes = False - dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+ - dialect._has_native_hstore = True - dialect.supports_native_decimal = True - return dialect - - -def compile_query( - query: ClauseElement, dialect: Dialect -) -> typing.Tuple[str, list, tuple]: - compiled = query.compile( - dialect=dialect, compile_kwargs={"render_postcompile": True} - ) - - if not isinstance(query, DDLElement): - compiled_params = sorted(compiled.params.items()) - - mapping = { - key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) - } - compiled_query = compiled.string % mapping - - processors = compiled._bind_processors - args = [ - processors[key](val) if key in processors else val - for key, val in compiled_params - ] - result_map = compiled._result_columns - else: - compiled_query = compiled.string - args = [] - result_map = None - - return compiled_query, args, result_map +dialect = PGDialect_psycopg diff --git a/databases/backends/psycopg.py b/databases/backends/psycopg.py index bb623cfa..f83f4917 100644 --- a/databases/backends/psycopg.py +++ b/databases/backends/psycopg.py @@ -39,9 +39,8 @@ async def connect(self) -> None: if self._pool is not None: return - self._pool = psycopg_pool.AsyncConnectionPool( - self._database_url._url, open=False, **self._options - ) + url = self._database_url._url.replace("postgresql+psycopg", "postgresql") + self._pool = psycopg_pool.AsyncConnectionPool(url, open=False, **self._options) # TODO: Add configurable timeouts await self._pool.open() diff --git a/databases/core.py b/databases/core.py index c09f8814..cba06ced 100644 --- a/databases/core.py +++ b/databases/core.py @@ -44,10 +44,10 @@ class Database: SUPPORTED_BACKENDS = { "postgres": "databases.backends.asyncpg:AsyncpgBackend", + "postgresql": "databases.backends.asyncpg:AsyncpgBackend", "postgresql+aiopg": "databases.backends.aiopg:AiopgBackend", "postgresql+asyncpg": "databases.backends.asyncpg:AsyncpgBackend", "postgresql+psycopg": "databases.backends.psycopg:PsycopgBackend", - "postgresql": "databases.backends.psycopg:PsycopgBackend", "mysql": "databases.backends.mysql:MySQLBackend", "mysql+aiomysql": "databases.backends.asyncmy:MySQLBackend", "mysql+asyncmy": "databases.backends.asyncmy:AsyncMyBackend",