Skip to content

Commit 7528e6d

Browse files
committed
Clean up TypeVar naems and mark type aliases
1 parent 8951230 commit 7528e6d

8 files changed

+358
-318
lines changed

asyncpg/connect_utils.py

+48-36
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import sys
2222
import time
2323
import typing
24+
import typing_extensions
2425
import urllib.parse
2526
import warnings
2627

@@ -37,33 +38,44 @@
3738
if typing.TYPE_CHECKING:
3839
from . import connection
3940

40-
_Connection = typing.TypeVar(
41-
'_Connection',
41+
_ConnectionT = typing.TypeVar(
42+
'_ConnectionT',
4243
bound='connection.Connection[typing.Any]'
4344
)
44-
_Protocol = typing.TypeVar('_Protocol', bound='protocol.Protocol[typing.Any]')
45-
_AsyncProtocol = typing.TypeVar(
46-
'_AsyncProtocol', bound='asyncio.protocols.Protocol'
45+
_ProtocolT = typing.TypeVar(
46+
'_ProtocolT',
47+
bound='protocol.Protocol[typing.Any]'
4748
)
48-
_Record = typing.TypeVar('_Record', bound=protocol.Record)
49-
_SSLMode = typing.TypeVar('_SSLMode', bound='SSLMode')
49+
_AsyncProtocolT = typing.TypeVar(
50+
'_AsyncProtocolT', bound='asyncio.protocols.Protocol'
51+
)
52+
_RecordT = typing.TypeVar('_RecordT', bound=protocol.Record)
53+
_SSLModeT = typing.TypeVar('_SSLModeT', bound='SSLMode')
5054

51-
_TPTupleType = typing.Tuple[asyncio.WriteTransport, _AsyncProtocol]
52-
AddrType = typing.Union[typing.Tuple[str, int], str]
53-
SSLStringValues = compat.Literal[
55+
_TPTupleType: typing_extensions.TypeAlias = typing.Tuple[
56+
asyncio.WriteTransport,
57+
_AsyncProtocolT
58+
]
59+
_SSLStringValues = compat.Literal[
5460
'disable', 'prefer', 'allow', 'require', 'verify-ca', 'verify-full'
5561
]
56-
_ParsedSSLType = typing.Union[
62+
AddrType: typing_extensions.TypeAlias = typing.Union[
63+
typing.Tuple[str, int],
64+
str
65+
]
66+
_ParsedSSLType: typing_extensions.TypeAlias = typing.Union[
5767
ssl_module.SSLContext, compat.Literal[False]
5868
]
59-
SSLType = typing.Union[_ParsedSSLType, SSLStringValues, bool]
60-
HostType = typing.Union[typing.List[str], str]
61-
PortListType = typing.Union[
69+
SSLType: typing_extensions.TypeAlias = typing.Union[
70+
_ParsedSSLType, _SSLStringValues, bool
71+
]
72+
HostType: typing_extensions.TypeAlias = typing.Union[typing.List[str], str]
73+
PortListType: typing_extensions.TypeAlias = typing.Union[
6274
typing.List[typing.Union[int, str]],
6375
typing.List[int],
6476
typing.List[str],
6577
]
66-
PortType = typing.Union[
78+
PortType: typing_extensions.TypeAlias = typing.Union[
6779
PortListType,
6880
int,
6981
str
@@ -80,13 +92,13 @@ class SSLMode(enum.IntEnum):
8092

8193
@classmethod
8294
def parse(
83-
cls: typing.Type[_SSLMode],
84-
sslmode: typing.Union[str, _SSLMode]
85-
) -> _SSLMode:
95+
cls: typing.Type[_SSLModeT],
96+
sslmode: typing.Union[str, _SSLModeT]
97+
) -> _SSLModeT:
8698
if isinstance(sslmode, cls):
8799
return sslmode
88100
return typing.cast(
89-
_SSLMode,
101+
_SSLModeT,
90102
getattr(cls, typing.cast(str, sslmode).replace('-', '_'))
91103
)
92104

@@ -798,14 +810,14 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
798810

799811
@typing.overload
800812
async def _create_ssl_connection(
801-
protocol_factory: typing.Callable[[], _Protocol],
813+
protocol_factory: typing.Callable[[], _ProtocolT],
802814
host: str,
803815
port: int,
804816
*,
805817
loop: asyncio.AbstractEventLoop,
806818
ssl_context: ssl_module.SSLContext,
807819
ssl_is_advisory: typing.Optional[bool] = False
808-
) -> _TPTupleType[_Protocol]:
820+
) -> _TPTupleType[_ProtocolT]:
809821
...
810822

811823

@@ -824,7 +836,7 @@ async def _create_ssl_connection(
824836

825837
async def _create_ssl_connection(
826838
protocol_factory: typing.Union[
827-
typing.Callable[[], _Protocol],
839+
typing.Callable[[], _ProtocolT],
828840
typing.Callable[[], '_CancelProto']
829841
],
830842
host: str,
@@ -886,7 +898,7 @@ async def _create_ssl_connection(
886898

887899
try:
888900
new_tr, pg_proto = typing.cast(
889-
typing.Tuple[asyncio.WriteTransport, _Protocol],
901+
typing.Tuple[asyncio.WriteTransport, _ProtocolT],
890902
await conn_factory(sock=sock)
891903
)
892904
pg_proto.is_ssl = do_ssl_upgrade
@@ -903,9 +915,9 @@ async def _connect_addr(
903915
timeout: float,
904916
params: _ConnectionParameters,
905917
config: _ClientConfiguration,
906-
connection_class: typing.Type[_Connection],
907-
record_class: typing.Type[_Record]
908-
) -> _Connection:
918+
connection_class: typing.Type[_ConnectionT],
919+
record_class: typing.Type[_RecordT]
920+
) -> _ConnectionT:
909921
assert loop is not None
910922

911923
if timeout <= 0:
@@ -956,22 +968,22 @@ async def __connect_addr(
956968
addr: AddrType,
957969
loop: asyncio.AbstractEventLoop,
958970
config: _ClientConfiguration,
959-
connection_class: typing.Type[_Connection],
960-
record_class: typing.Type[_Record],
971+
connection_class: typing.Type[_ConnectionT],
972+
record_class: typing.Type[_RecordT],
961973
params_input: _ConnectionParameters,
962-
) -> _Connection:
974+
) -> _ConnectionT:
963975
connected = _create_future(loop)
964976

965977
proto_factory: typing.Callable[
966-
[], 'protocol.Protocol[_Record]'
978+
[], 'protocol.Protocol[_RecordT]'
967979
] = lambda: protocol.Protocol(
968980
addr, connected, params, record_class, loop)
969981

970982
if isinstance(addr, str):
971983
# UNIX socket
972984
connector = typing.cast(
973985
typing.Coroutine[
974-
typing.Any, None, _TPTupleType['protocol.Protocol[_Record]']
986+
typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]']
975987
],
976988
loop.create_unix_connection(proto_factory, addr)
977989
)
@@ -981,7 +993,7 @@ async def __connect_addr(
981993
# SSL connection
982994
connector = typing.cast(
983995
typing.Coroutine[
984-
typing.Any, None, _TPTupleType['protocol.Protocol[_Record]']
996+
typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]']
985997
],
986998
loop.create_connection(
987999
proto_factory, *addr, ssl=params.ssl
@@ -995,7 +1007,7 @@ async def __connect_addr(
9951007
else:
9961008
connector = typing.cast(
9971009
typing.Coroutine[
998-
typing.Any, None, _TPTupleType['protocol.Protocol[_Record]']
1010+
typing.Any, None, _TPTupleType['protocol.Protocol[_RecordT]']
9991011
],
10001012
loop.create_connection(proto_factory, *addr)
10011013
)
@@ -1053,10 +1065,10 @@ async def _connect(
10531065
*,
10541066
loop: typing.Optional[asyncio.AbstractEventLoop],
10551067
timeout: float,
1056-
connection_class: typing.Type[_Connection],
1057-
record_class: typing.Type[_Record],
1068+
connection_class: typing.Type[_ConnectionT],
1069+
record_class: typing.Type[_RecordT],
10581070
**kwargs: typing.Any
1059-
) -> _Connection:
1071+
) -> _ConnectionT:
10601072
if loop is None:
10611073
loop = asyncio.get_event_loop()
10621074

0 commit comments

Comments
 (0)