21
21
import sys
22
22
import time
23
23
import typing
24
+ import typing_extensions
24
25
import urllib .parse
25
26
import warnings
26
27
37
38
if typing .TYPE_CHECKING :
38
39
from . import connection
39
40
40
- _Connection = typing .TypeVar (
41
- '_Connection ' ,
41
+ _ConnectionT = typing .TypeVar (
42
+ '_ConnectionT ' ,
42
43
bound = 'connection.Connection[typing.Any]'
43
44
)
44
- _Protocol = typing .TypeVar ('_Protocol' , bound = 'protocol.Protocol[typing.Any]' )
45
- _AsyncProtocol = typing . TypeVar (
46
- '_AsyncProtocol' , bound = 'asyncio.protocols. Protocol'
45
+ _ProtocolT = typing .TypeVar (
46
+ '_ProtocolT' ,
47
+ bound = 'protocol. Protocol[typing.Any] '
47
48
)
48
- _Record = typing .TypeVar ('_Record' , bound = protocol .Record )
49
- _SSLMode = typing .TypeVar ('_SSLMode' , bound = 'SSLMode' )
49
+ _AsyncProtocolT = typing .TypeVar (
50
+ '_AsyncProtocolT' , bound = 'asyncio.protocols.Protocol'
51
+ )
52
+ _RecordT = typing .TypeVar ('_RecordT' , bound = protocol .Record )
53
+ _SSLModeT = typing .TypeVar ('_SSLModeT' , bound = 'SSLMode' )
50
54
51
- _TPTupleType = typing .Tuple [asyncio .WriteTransport , _AsyncProtocol ]
52
- AddrType = typing .Union [typing .Tuple [str , int ], str ]
53
- SSLStringValues = compat .Literal [
55
+ _TPTupleType : typing_extensions .TypeAlias = typing .Tuple [
56
+ asyncio .WriteTransport ,
57
+ _AsyncProtocolT
58
+ ]
59
+ _SSLStringValues = compat .Literal [
54
60
'disable' , 'prefer' , 'allow' , 'require' , 'verify-ca' , 'verify-full'
55
61
]
56
- _ParsedSSLType = typing .Union [
62
+ AddrType : typing_extensions .TypeAlias = typing .Union [
63
+ typing .Tuple [str , int ],
64
+ str
65
+ ]
66
+ _ParsedSSLType : typing_extensions .TypeAlias = typing .Union [
57
67
ssl_module .SSLContext , compat .Literal [False ]
58
68
]
59
- SSLType = typing .Union [_ParsedSSLType , SSLStringValues , bool ]
60
- HostType = typing .Union [typing .List [str ], str ]
61
- PortListType = typing .Union [
69
+ SSLType : typing_extensions .TypeAlias = typing .Union [
70
+ _ParsedSSLType , _SSLStringValues , bool
71
+ ]
72
+ HostType : typing_extensions .TypeAlias = typing .Union [typing .List [str ], str ]
73
+ PortListType : typing_extensions .TypeAlias = typing .Union [
62
74
typing .List [typing .Union [int , str ]],
63
75
typing .List [int ],
64
76
typing .List [str ],
65
77
]
66
- PortType = typing .Union [
78
+ PortType : typing_extensions . TypeAlias = typing .Union [
67
79
PortListType ,
68
80
int ,
69
81
str
@@ -80,13 +92,13 @@ class SSLMode(enum.IntEnum):
80
92
81
93
@classmethod
82
94
def parse (
83
- cls : typing .Type [_SSLMode ],
84
- sslmode : typing .Union [str , _SSLMode ]
85
- ) -> _SSLMode :
95
+ cls : typing .Type [_SSLModeT ],
96
+ sslmode : typing .Union [str , _SSLModeT ]
97
+ ) -> _SSLModeT :
86
98
if isinstance (sslmode , cls ):
87
99
return sslmode
88
100
return typing .cast (
89
- _SSLMode ,
101
+ _SSLModeT ,
90
102
getattr (cls , typing .cast (str , sslmode ).replace ('-' , '_' ))
91
103
)
92
104
@@ -798,14 +810,14 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
798
810
799
811
@typing .overload
800
812
async def _create_ssl_connection (
801
- protocol_factory : typing .Callable [[], _Protocol ],
813
+ protocol_factory : typing .Callable [[], _ProtocolT ],
802
814
host : str ,
803
815
port : int ,
804
816
* ,
805
817
loop : asyncio .AbstractEventLoop ,
806
818
ssl_context : ssl_module .SSLContext ,
807
819
ssl_is_advisory : typing .Optional [bool ] = False
808
- ) -> _TPTupleType [_Protocol ]:
820
+ ) -> _TPTupleType [_ProtocolT ]:
809
821
...
810
822
811
823
@@ -824,7 +836,7 @@ async def _create_ssl_connection(
824
836
825
837
async def _create_ssl_connection (
826
838
protocol_factory : typing .Union [
827
- typing .Callable [[], _Protocol ],
839
+ typing .Callable [[], _ProtocolT ],
828
840
typing .Callable [[], '_CancelProto' ]
829
841
],
830
842
host : str ,
@@ -886,7 +898,7 @@ async def _create_ssl_connection(
886
898
887
899
try :
888
900
new_tr , pg_proto = typing .cast (
889
- typing .Tuple [asyncio .WriteTransport , _Protocol ],
901
+ typing .Tuple [asyncio .WriteTransport , _ProtocolT ],
890
902
await conn_factory (sock = sock )
891
903
)
892
904
pg_proto .is_ssl = do_ssl_upgrade
@@ -903,9 +915,9 @@ async def _connect_addr(
903
915
timeout : float ,
904
916
params : _ConnectionParameters ,
905
917
config : _ClientConfiguration ,
906
- connection_class : typing .Type [_Connection ],
907
- record_class : typing .Type [_Record ]
908
- ) -> _Connection :
918
+ connection_class : typing .Type [_ConnectionT ],
919
+ record_class : typing .Type [_RecordT ]
920
+ ) -> _ConnectionT :
909
921
assert loop is not None
910
922
911
923
if timeout <= 0 :
@@ -956,22 +968,22 @@ async def __connect_addr(
956
968
addr : AddrType ,
957
969
loop : asyncio .AbstractEventLoop ,
958
970
config : _ClientConfiguration ,
959
- connection_class : typing .Type [_Connection ],
960
- record_class : typing .Type [_Record ],
971
+ connection_class : typing .Type [_ConnectionT ],
972
+ record_class : typing .Type [_RecordT ],
961
973
params_input : _ConnectionParameters ,
962
- ) -> _Connection :
974
+ ) -> _ConnectionT :
963
975
connected = _create_future (loop )
964
976
965
977
proto_factory : typing .Callable [
966
- [], 'protocol.Protocol[_Record ]'
978
+ [], 'protocol.Protocol[_RecordT ]'
967
979
] = lambda : protocol .Protocol (
968
980
addr , connected , params , record_class , loop )
969
981
970
982
if isinstance (addr , str ):
971
983
# UNIX socket
972
984
connector = typing .cast (
973
985
typing .Coroutine [
974
- typing .Any , None , _TPTupleType ['protocol.Protocol[_Record ]' ]
986
+ typing .Any , None , _TPTupleType ['protocol.Protocol[_RecordT ]' ]
975
987
],
976
988
loop .create_unix_connection (proto_factory , addr )
977
989
)
@@ -981,7 +993,7 @@ async def __connect_addr(
981
993
# SSL connection
982
994
connector = typing .cast (
983
995
typing .Coroutine [
984
- typing .Any , None , _TPTupleType ['protocol.Protocol[_Record ]' ]
996
+ typing .Any , None , _TPTupleType ['protocol.Protocol[_RecordT ]' ]
985
997
],
986
998
loop .create_connection (
987
999
proto_factory , * addr , ssl = params .ssl
@@ -995,7 +1007,7 @@ async def __connect_addr(
995
1007
else :
996
1008
connector = typing .cast (
997
1009
typing .Coroutine [
998
- typing .Any , None , _TPTupleType ['protocol.Protocol[_Record ]' ]
1010
+ typing .Any , None , _TPTupleType ['protocol.Protocol[_RecordT ]' ]
999
1011
],
1000
1012
loop .create_connection (proto_factory , * addr )
1001
1013
)
@@ -1053,10 +1065,10 @@ async def _connect(
1053
1065
* ,
1054
1066
loop : typing .Optional [asyncio .AbstractEventLoop ],
1055
1067
timeout : float ,
1056
- connection_class : typing .Type [_Connection ],
1057
- record_class : typing .Type [_Record ],
1068
+ connection_class : typing .Type [_ConnectionT ],
1069
+ record_class : typing .Type [_RecordT ],
1058
1070
** kwargs : typing .Any
1059
- ) -> _Connection :
1071
+ ) -> _ConnectionT :
1060
1072
if loop is None :
1061
1073
loop = asyncio .get_event_loop ()
1062
1074
0 commit comments