Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement connection service file functionality #1223

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 127 additions & 8 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import asyncio
import configparser
import collections
from collections.abc import Callable
import enum
Expand Down Expand Up @@ -87,6 +88,9 @@ class SSLNegotiation(compat.StrEnum):
PGPASSFILE = '.pgpass'


PG_SERVICEFILE = '.pg_service.conf'


def _read_password_file(passfile: pathlib.Path) \
-> typing.List[typing.Tuple[str, ...]]:

Expand Down Expand Up @@ -268,7 +272,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
password, passfile, database, ssl, service,
direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
# `auth_hosts` is the version of host information for the purposes
Expand All @@ -281,6 +285,28 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if dsn:
parsed = urllib.parse.urlparse(dsn)

query = None
if parsed.query:
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
for key, val in query.items():
if isinstance(val, list):
query[key] = val[-1]

if 'service' in query:
val = query.pop('service')
if not service and val:
service = val

connection_service_file = os.getenv('PGSERVICEFILE')
if connection_service_file is None:
homedir = compat.get_pg_home_directory()
if homedir:
connection_service_file = homedir / PG_SERVICEFILE
else:
connection_service_file = None
else:
connection_service_file = pathlib.Path(connection_service_file)

if parsed.scheme not in {'postgresql', 'postgres'}:
raise exceptions.ClientConfigurationError(
'invalid DSN: scheme is expected to be either '
Expand Down Expand Up @@ -315,11 +341,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if password is None and dsn_password:
password = urllib.parse.unquote(dsn_password)

if parsed.query:
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
for key, val in query.items():
if isinstance(val, list):
query[key] = val[-1]
if query:

if 'port' in query:
val = query.pop('port')
Expand Down Expand Up @@ -406,12 +428,108 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if gsslib is None:
gsslib = val

if 'service' in query:
val = query.pop('service')
if service is None:
service = val

if query:
if server_settings is None:
server_settings = query
else:
server_settings = {**query, **server_settings}

if connection_service_file is not None and service is not None:
pg_service = configparser.ConfigParser()
pg_service.read(connection_service_file)
if service in pg_service.sections():
service_params = pg_service[service]
if 'port' in service_params:
val = service_params.pop('port')
if not port and val:
port = [int(p) for p in val.split(',')]

if 'host' in service_params:
val = service_params.pop('host')
if not host and val:
host, port = _parse_hostlist(val, port)

if 'dbname' in service_params:
val = service_params.pop('dbname')
if database is None:
database = val

if 'database' in service_params:
val = service_params.pop('database')
if database is None:
database = val

if 'user' in service_params:
val = service_params.pop('user')
if user is None:
user = val

if 'password' in service_params:
val = service_params.pop('password')
if password is None:
password = val

if 'passfile' in service_params:
val = service_params.pop('passfile')
if passfile is None:
passfile = val

if 'sslmode' in service_params:
val = service_params.pop('sslmode')
if ssl is None:
ssl = val

if 'sslcert' in service_params:
sslcert = service_params.pop('sslcert')

if 'sslkey' in service_params:
sslkey = service_params.pop('sslkey')

if 'sslrootcert' in service_params:
sslrootcert = service_params.pop('sslrootcert')

if 'sslnegotiation' in service_params:
sslnegotiation = service_params.pop('sslnegotiation')

if 'sslcrl' in service_params:
sslcrl = service_params.pop('sslcrl')

if 'sslpassword' in service_params:
sslpassword = service_params.pop('sslpassword')

if 'ssl_min_protocol_version' in service_params:
ssl_min_protocol_version = service_params.pop(
'ssl_min_protocol_version'
)

if 'ssl_max_protocol_version' in service_params:
ssl_max_protocol_version = service_params.pop(
'ssl_max_protocol_version'
)

if 'target_session_attrs' in service_params:
dsn_target_session_attrs = service_params.pop(
'target_session_attrs'
)
if target_session_attrs is None:
target_session_attrs = dsn_target_session_attrs

if 'krbsrvname' in service_params:
val = service_params.pop('krbsrvname')
if krbsrvname is None:
krbsrvname = val

if 'gsslib' in service_params:
val = service_params.pop('gsslib')
if gsslib is None:
gsslib = val
if not service:
service = os.environ.get('PGSERVICE')
if not host:
hostspec = os.environ.get('PGHOST')
if hostspec:
Expand Down Expand Up @@ -724,7 +842,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
target_session_attrs, krbsrvname, gsslib,
service):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
Expand Down Expand Up @@ -754,7 +873,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
direct_tls=direct_tls, database=database,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
krbsrvname=krbsrvname, gsslib=gsslib, service=service)

config = _ClientConfiguration(
command_timeout=command_timeout,
Expand Down
6 changes: 6 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2074,6 +2074,7 @@ async def _do_execute(
async def connect(dsn=None, *,
host=None, port=None,
user=None, password=None, passfile=None,
service=None,
database=None,
loop=None,
timeout=60,
Expand Down Expand Up @@ -2183,6 +2184,10 @@ async def connect(dsn=None, *,
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
on Windows).

:param service:
The name of the postgres connection service stored in the postgres
connection service file.

:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
Expand Down Expand Up @@ -2428,6 +2433,7 @@ async def connect(dsn=None, *,
user=user,
password=password,
passfile=passfile,
service=service,
ssl=ssl,
direct_tls=direct_tls,
database=database,
Expand Down
111 changes: 109 additions & 2 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,8 @@ def run_testcase(self, testcase):
env = testcase.get('env', {})
test_env = {'PGHOST': None, 'PGPORT': None,
'PGUSER': None, 'PGPASSWORD': None,
'PGDATABASE': None, 'PGSSLMODE': None}
'PGDATABASE': None, 'PGSSLMODE': None,
'PGSERVICE': None, }
test_env.update(env)

dsn = testcase.get('dsn')
Expand All @@ -1132,6 +1133,7 @@ def run_testcase(self, testcase):
target_session_attrs = testcase.get('target_session_attrs')
krbsrvname = testcase.get('krbsrvname')
gsslib = testcase.get('gsslib')
service = testcase.get('service')

expected = testcase.get('result')
expected_error = testcase.get('error')
Expand All @@ -1157,7 +1159,7 @@ def run_testcase(self, testcase):
direct_tls=direct_tls,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
krbsrvname=krbsrvname, gsslib=gsslib, service=service)

params = {
k: v for k, v in params._asdict().items()
Expand Down Expand Up @@ -1236,6 +1238,111 @@ def test_connect_params(self):
for testcase in self.TESTS:
self.run_testcase(testcase)

def test_connect_connection_service_file(self):
connection_service_file = tempfile.NamedTemporaryFile(
'w+t', delete=False)
connection_service_file.write(textwrap.dedent('''
[test_service_dbname]
port=5433
host=somehost
dbname=test_dbname
user=admin
password=test_password
target_session_attrs=primary
krbsrvname=fakekrbsrvname
gsslib=sspi

[test_service_database]
port=5433
host=somehost
database=test_dbname
user=admin
password=test_password
target_session_attrs=primary
krbsrvname=fakekrbsrvname
gsslib=sspi
'''))
connection_service_file.close()
os.chmod(connection_service_file.name, stat.S_IWUSR | stat.S_IRUSR)
try:
# Test connection service file with dbname
self.run_testcase({
'dsn': 'postgresql://?service=test_service_dbname',
'env': {
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
# Test connection service file with database
self.run_testcase({
'dsn': 'postgresql://?service=test_service_database',
'env': {
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
# Test that envvars are overridden by service file
self.run_testcase({
'dsn': 'postgresql://?service=test_service_dbname',
'env': {
'PGUSER': 'user',
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
# Test that dsn params overwrite service file
self.run_testcase({
'dsn': 'postgresql://?service={}&dbname={}'.format(
"test_service_dbname", "test_dbname_dsn"
),
'env': {
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname_dsn',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
finally:
os.unlink(connection_service_file.name)

def test_connect_pgpass_regular(self):
passfile = tempfile.NamedTemporaryFile('w+t', delete=False)
passfile.write(textwrap.dedent(R'''
Expand Down