7
7
8
8
import asyncio
9
9
import collections
10
+ import functools
10
11
import getpass
11
12
import os
12
13
import pathlib
13
14
import platform
14
15
import re
15
16
import socket
17
+ import ssl as ssl_module
16
18
import stat
17
19
import struct
18
20
import time
32
34
'password' ,
33
35
'database' ,
34
36
'ssl' ,
37
+ 'ssl_is_advisory' ,
35
38
'connect_timeout' ,
36
39
'server_settings' ,
37
40
])
@@ -208,6 +211,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
208
211
if passfile is None :
209
212
passfile = val
210
213
214
+ if 'sslmode' in query :
215
+ val = query .pop ('sslmode' )
216
+ if ssl is None :
217
+ ssl = val
218
+
211
219
if query :
212
220
if server_settings is None :
213
221
server_settings = query
@@ -303,6 +311,47 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
303
311
raise ValueError (
304
312
'could not determine the database address to connect to' )
305
313
314
+ if ssl is None :
315
+ ssl = os .getenv ('PGSSLMODE' )
316
+
317
+ # ssl_is_advisory is only allowed to come from the sslmode parameter.
318
+ ssl_is_advisory = None
319
+ if isinstance (ssl , str ):
320
+ SSLMODES = {
321
+ 'disable' : 0 ,
322
+ 'allow' : 1 ,
323
+ 'prefer' : 2 ,
324
+ 'require' : 3 ,
325
+ 'verify-ca' : 4 ,
326
+ 'verify-full' : 5 ,
327
+ }
328
+ try :
329
+ sslmode = SSLMODES [ssl ]
330
+ except KeyError :
331
+ modes = ', ' .join (SSLMODES .keys ())
332
+ raise ValueError ('`sslmode` parameter must be one of ' + modes )
333
+
334
+ # sslmode 'allow' is currently handled as 'prefer' because we're
335
+ # missing the "retry with SSL" behavior for 'allow', but do have the
336
+ # "retry without SSL" behavior for 'prefer'.
337
+ # Not changing 'allow' to 'prefer' here would be effectively the same
338
+ # as changing 'allow' to 'disable'.
339
+ if sslmode == SSLMODES ['allow' ]:
340
+ sslmode = SSLMODES ['prefer' ]
341
+
342
+ # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
343
+ # Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
344
+ if sslmode <= SSLMODES ['allow' ]:
345
+ ssl = False
346
+ ssl_is_advisory = sslmode >= SSLMODES ['allow' ]
347
+ else :
348
+ ssl = ssl_module .create_default_context ()
349
+ ssl .check_hostname = sslmode >= SSLMODES ['verify-full' ]
350
+ ssl .verify_mode = ssl_module .CERT_REQUIRED
351
+ if sslmode <= SSLMODES ['require' ]:
352
+ ssl .verify_mode = ssl_module .CERT_NONE
353
+ ssl_is_advisory = sslmode <= SSLMODES ['prefer' ]
354
+
306
355
if ssl :
307
356
for addr in addrs :
308
357
if isinstance (addr , str ):
@@ -321,7 +370,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
321
370
322
371
params = _ConnectionParameters (
323
372
user = user , password = password , database = database , ssl = ssl ,
324
- connect_timeout = connect_timeout , server_settings = server_settings )
373
+ ssl_is_advisory = ssl_is_advisory , connect_timeout = connect_timeout ,
374
+ server_settings = server_settings )
325
375
326
376
return addrs , params
327
377
@@ -384,11 +434,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
384
434
385
435
if isinstance (addr , str ):
386
436
# UNIX socket
387
- assert params .ssl is None
437
+ assert not params .ssl
388
438
connector = loop .create_unix_connection (proto_factory , addr )
389
439
elif params .ssl :
390
440
connector = _create_ssl_connection (
391
- proto_factory , * addr , loop = loop , ssl_context = params .ssl )
441
+ proto_factory , * addr , loop = loop , ssl_context = params .ssl ,
442
+ ssl_is_advisory = params .ssl_is_advisory )
392
443
else :
393
444
connector = loop .create_connection (proto_factory , * addr )
394
445
@@ -435,7 +486,12 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
435
486
raise last_error
436
487
437
488
438
- async def _get_ssl_ready_socket (host , port , * , loop ):
489
+ async def _negotiate_ssl_connection (host , port , conn_factory , * , loop , ssl ,
490
+ server_hostname , ssl_is_advisory = False ):
491
+ # Note: ssl_is_advisory only affects behavior when the server does not
492
+ # accept SSLRequests. If the SSLRequest is accepted but either the SSL
493
+ # negotiation fails or the PostgreSQL user isn't permitted to use SSL,
494
+ # there's nothing that would attempt to reconnect with a non-SSL socket.
439
495
reader , writer = await asyncio .open_connection (host , port , loop = loop )
440
496
441
497
tr = writer .transport
@@ -448,44 +504,55 @@ async def _get_ssl_ready_socket(host, port, *, loop):
448
504
resp = await reader .readexactly (1 )
449
505
450
506
if resp == b'S' :
451
- return sock .dup ()
507
+ conn_factory = functools .partial (
508
+ conn_factory , ssl = ssl , server_hostname = server_hostname )
509
+ elif (ssl_is_advisory and
510
+ ssl .verify_mode == ssl_module .CERT_NONE and
511
+ resp == b'N' ):
512
+ # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
513
+ # since the only way to get ssl_is_advisory is from sslmode=prefer
514
+ # (or sslmode=allow). But be extra sure to disallow insecure
515
+ # connections when the ssl context asks for real security.
516
+ pass
452
517
else :
453
518
raise ConnectionError (
454
519
'PostgreSQL server at "{}:{}" rejected SSL upgrade' .format (
455
520
host , port ))
521
+
522
+ sock = sock .dup () # Must come before tr.close()
456
523
finally :
457
524
tr .close ()
458
525
459
-
460
- async def _create_ssl_connection (protocol_factory , host , port , * ,
461
- loop , ssl_context ):
462
- sock = await _get_ssl_ready_socket (host , port , loop = loop )
463
526
try :
464
- return await loop .create_connection (
465
- protocol_factory , sock = sock , ssl = ssl_context ,
466
- server_hostname = host )
527
+ return await conn_factory (sock = sock ) # Must come after tr.close()
467
528
except Exception :
468
529
sock .close ()
469
530
raise
470
531
471
532
533
+ async def _create_ssl_connection (protocol_factory , host , port , * ,
534
+ loop , ssl_context , ssl_is_advisory = False ):
535
+ return await _negotiate_ssl_connection (
536
+ host , port ,
537
+ functools .partial (loop .create_connection , protocol_factory ),
538
+ loop = loop ,
539
+ ssl = ssl_context ,
540
+ server_hostname = host ,
541
+ ssl_is_advisory = ssl_is_advisory )
542
+
543
+
472
544
async def _open_connection (* , loop , addr , params : _ConnectionParameters ):
473
545
if isinstance (addr , str ):
474
546
r , w = await asyncio .open_unix_connection (addr , loop = loop )
475
547
else :
476
548
if params .ssl :
477
- sock = await _get_ssl_ready_socket (* addr , loop = loop )
478
-
479
- try :
480
- r , w = await asyncio .open_connection (
481
- sock = sock ,
482
- loop = loop ,
483
- ssl = params .ssl ,
484
- server_hostname = addr [0 ])
485
- except Exception :
486
- sock .close ()
487
- raise
488
-
549
+ r , w = await _negotiate_ssl_connection (
550
+ * addr ,
551
+ functools .partial (asyncio .open_connection , loop = loop ),
552
+ loop = loop ,
553
+ ssl = params .ssl ,
554
+ server_hostname = addr [0 ],
555
+ ssl_is_advisory = params .ssl_is_advisory )
489
556
else :
490
557
r , w = await asyncio .open_connection (* addr , loop = loop )
491
558
_set_nodelay (_get_socket (w .transport ))
0 commit comments