Skip to content

Commit c9f8b0f

Browse files
committed
Allow using custom Record class
Add the new `record_class` parameter to the `create_pool()` and `connect()` functions, as well as to the `cursor()`, `prepare()`, `fetch()` and `fetchrow()` connection methods. This not only allows adding custom functionality to the returned objects, but also assists with typing (see #577 for discussion). Fixes: #40.
1 parent 39040b3 commit c9f8b0f

17 files changed

+610
-105
lines changed

.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[flake8]
2-
ignore = E402,E731,W504,E252
2+
ignore = E402,E731,W503,W504,E252
33
exclude = .git,__pycache__,build,dist,.eggs,.github,.local

asyncpg/_testbase/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import unittest
2020

2121

22+
import asyncpg
2223
from asyncpg import cluster as pg_cluster
2324
from asyncpg import connection as pg_connection
2425
from asyncpg import pool as pg_pool
@@ -266,13 +267,15 @@ def create_pool(dsn=None, *,
266267
loop=None,
267268
pool_class=pg_pool.Pool,
268269
connection_class=pg_connection.Connection,
270+
record_class=asyncpg.Record,
269271
**connect_kwargs):
270272
return pool_class(
271273
dsn,
272274
min_size=min_size, max_size=max_size,
273275
max_queries=max_queries, loop=loop, setup=setup, init=init,
274276
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
275277
connection_class=connection_class,
278+
record_class=record_class,
276279
**connect_kwargs)
277280

278281

asyncpg/connect_utils.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -594,8 +594,16 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
594594
raise
595595

596596

597-
async def _connect_addr(*, addr, loop, timeout, params, config,
598-
connection_class):
597+
async def _connect_addr(
598+
*,
599+
addr,
600+
loop,
601+
timeout,
602+
params,
603+
config,
604+
connection_class,
605+
record_class
606+
):
599607
assert loop is not None
600608

601609
if timeout <= 0:
@@ -613,7 +621,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
613621
params = params._replace(password=password)
614622

615623
proto_factory = lambda: protocol.Protocol(
616-
addr, connected, params, loop)
624+
addr, connected, params, record_class, loop)
617625

618626
if isinstance(addr, str):
619627
# UNIX socket
@@ -649,7 +657,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
649657
return con
650658

651659

652-
async def _connect(*, loop, timeout, connection_class, **kwargs):
660+
async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
653661
if loop is None:
654662
loop = asyncio.get_event_loop()
655663

@@ -661,9 +669,14 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
661669
before = time.monotonic()
662670
try:
663671
con = await _connect_addr(
664-
addr=addr, loop=loop, timeout=timeout,
665-
params=params, config=config,
666-
connection_class=connection_class)
672+
addr=addr,
673+
loop=loop,
674+
timeout=timeout,
675+
params=params,
676+
config=config,
677+
connection_class=connection_class,
678+
record_class=record_class,
679+
)
667680
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
668681
last_error = ex
669682
else:

0 commit comments

Comments
 (0)