Skip to content

Commit 0304288

Browse files
percontationelprans
authored andcommitted
Allow 'sslmode' in DSNs
Basic support for libpq-like handling of 'sslmode' parameter. It uses default ssl contexts, rather than reading ~/.postgres files or other 'ssl*' parameters like libpq would.
1 parent a7eaf2b commit 0304288

File tree

2 files changed

+205
-34
lines changed

2 files changed

+205
-34
lines changed

asyncpg/connect_utils.py

+91-24
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77

88
import asyncio
99
import collections
10+
import functools
1011
import getpass
1112
import os
1213
import pathlib
1314
import platform
1415
import re
1516
import socket
17+
import ssl as ssl_module
1618
import stat
1719
import struct
1820
import time
@@ -32,6 +34,7 @@
3234
'password',
3335
'database',
3436
'ssl',
37+
'ssl_is_advisory',
3538
'connect_timeout',
3639
'server_settings',
3740
])
@@ -208,6 +211,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
208211
if passfile is None:
209212
passfile = val
210213

214+
if 'sslmode' in query:
215+
val = query.pop('sslmode')
216+
if ssl is None:
217+
ssl = val
218+
211219
if query:
212220
if server_settings is None:
213221
server_settings = query
@@ -303,6 +311,47 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
303311
raise ValueError(
304312
'could not determine the database address to connect to')
305313

314+
if ssl is None:
315+
ssl = os.getenv('PGSSLMODE')
316+
317+
# ssl_is_advisory is only allowed to come from the sslmode parameter.
318+
ssl_is_advisory = None
319+
if isinstance(ssl, str):
320+
SSLMODES = {
321+
'disable': 0,
322+
'allow': 1,
323+
'prefer': 2,
324+
'require': 3,
325+
'verify-ca': 4,
326+
'verify-full': 5,
327+
}
328+
try:
329+
sslmode = SSLMODES[ssl]
330+
except KeyError:
331+
modes = ', '.join(SSLMODES.keys())
332+
raise ValueError('`sslmode` parameter must be one of ' + modes)
333+
334+
# sslmode 'allow' is currently handled as 'prefer' because we're
335+
# missing the "retry with SSL" behavior for 'allow', but do have the
336+
# "retry without SSL" behavior for 'prefer'.
337+
# Not changing 'allow' to 'prefer' here would be effectively the same
338+
# as changing 'allow' to 'disable'.
339+
if sslmode == SSLMODES['allow']:
340+
sslmode = SSLMODES['prefer']
341+
342+
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
343+
# Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
344+
if sslmode <= SSLMODES['allow']:
345+
ssl = False
346+
ssl_is_advisory = sslmode >= SSLMODES['allow']
347+
else:
348+
ssl = ssl_module.create_default_context()
349+
ssl.check_hostname = sslmode >= SSLMODES['verify-full']
350+
ssl.verify_mode = ssl_module.CERT_REQUIRED
351+
if sslmode <= SSLMODES['require']:
352+
ssl.verify_mode = ssl_module.CERT_NONE
353+
ssl_is_advisory = sslmode <= SSLMODES['prefer']
354+
306355
if ssl:
307356
for addr in addrs:
308357
if isinstance(addr, str):
@@ -321,7 +370,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
321370

322371
params = _ConnectionParameters(
323372
user=user, password=password, database=database, ssl=ssl,
324-
connect_timeout=connect_timeout, server_settings=server_settings)
373+
ssl_is_advisory=ssl_is_advisory, connect_timeout=connect_timeout,
374+
server_settings=server_settings)
325375

326376
return addrs, params
327377

@@ -384,11 +434,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
384434

385435
if isinstance(addr, str):
386436
# UNIX socket
387-
assert params.ssl is None
437+
assert not params.ssl
388438
connector = loop.create_unix_connection(proto_factory, addr)
389439
elif params.ssl:
390440
connector = _create_ssl_connection(
391-
proto_factory, *addr, loop=loop, ssl_context=params.ssl)
441+
proto_factory, *addr, loop=loop, ssl_context=params.ssl,
442+
ssl_is_advisory=params.ssl_is_advisory)
392443
else:
393444
connector = loop.create_connection(proto_factory, *addr)
394445

@@ -435,7 +486,12 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
435486
raise last_error
436487

437488

438-
async def _get_ssl_ready_socket(host, port, *, loop):
489+
async def _negotiate_ssl_connection(host, port, conn_factory, *, loop, ssl,
490+
server_hostname, ssl_is_advisory=False):
491+
# Note: ssl_is_advisory only affects behavior when the server does not
492+
# accept SSLRequests. If the SSLRequest is accepted but either the SSL
493+
# negotiation fails or the PostgreSQL user isn't permitted to use SSL,
494+
# there's nothing that would attempt to reconnect with a non-SSL socket.
439495
reader, writer = await asyncio.open_connection(host, port, loop=loop)
440496

441497
tr = writer.transport
@@ -448,44 +504,55 @@ async def _get_ssl_ready_socket(host, port, *, loop):
448504
resp = await reader.readexactly(1)
449505

450506
if resp == b'S':
451-
return sock.dup()
507+
conn_factory = functools.partial(
508+
conn_factory, ssl=ssl, server_hostname=server_hostname)
509+
elif (ssl_is_advisory and
510+
ssl.verify_mode == ssl_module.CERT_NONE and
511+
resp == b'N'):
512+
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
513+
# since the only way to get ssl_is_advisory is from sslmode=prefer
514+
# (or sslmode=allow). But be extra sure to disallow insecure
515+
# connections when the ssl context asks for real security.
516+
pass
452517
else:
453518
raise ConnectionError(
454519
'PostgreSQL server at "{}:{}" rejected SSL upgrade'.format(
455520
host, port))
521+
522+
sock = sock.dup() # Must come before tr.close()
456523
finally:
457524
tr.close()
458525

459-
460-
async def _create_ssl_connection(protocol_factory, host, port, *,
461-
loop, ssl_context):
462-
sock = await _get_ssl_ready_socket(host, port, loop=loop)
463526
try:
464-
return await loop.create_connection(
465-
protocol_factory, sock=sock, ssl=ssl_context,
466-
server_hostname=host)
527+
return await conn_factory(sock=sock) # Must come after tr.close()
467528
except Exception:
468529
sock.close()
469530
raise
470531

471532

533+
async def _create_ssl_connection(protocol_factory, host, port, *,
534+
loop, ssl_context, ssl_is_advisory=False):
535+
return await _negotiate_ssl_connection(
536+
host, port,
537+
functools.partial(loop.create_connection, protocol_factory),
538+
loop=loop,
539+
ssl=ssl_context,
540+
server_hostname=host,
541+
ssl_is_advisory=ssl_is_advisory)
542+
543+
472544
async def _open_connection(*, loop, addr, params: _ConnectionParameters):
473545
if isinstance(addr, str):
474546
r, w = await asyncio.open_unix_connection(addr, loop=loop)
475547
else:
476548
if params.ssl:
477-
sock = await _get_ssl_ready_socket(*addr, loop=loop)
478-
479-
try:
480-
r, w = await asyncio.open_connection(
481-
sock=sock,
482-
loop=loop,
483-
ssl=params.ssl,
484-
server_hostname=addr[0])
485-
except Exception:
486-
sock.close()
487-
raise
488-
549+
r, w = await _negotiate_ssl_connection(
550+
*addr,
551+
functools.partial(asyncio.open_connection, loop=loop),
552+
loop=loop,
553+
ssl=params.ssl,
554+
server_hostname=addr[0],
555+
ssl_is_advisory=params.ssl_is_advisory)
489556
else:
490557
r, w = await asyncio.open_connection(*addr, loop=loop)
491558
_set_nodelay(_get_socket(w.transport))

0 commit comments

Comments
 (0)