diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index da616523..20b1acbd 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -38,6 +38,7 @@ 'statement_cache_size', 'max_cached_statement_lifetime', 'max_cacheable_statement_size', + 'max_consecutive_exceptions', ]) @@ -210,6 +211,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, database, timeout, command_timeout, statement_cache_size, max_cached_statement_lifetime, max_cacheable_statement_size, + max_consecutive_exceptions, ssl, server_settings): local_vars = locals() @@ -245,7 +247,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, database, command_timeout=command_timeout, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, - max_cacheable_statement_size=max_cacheable_statement_size,) + max_cacheable_statement_size=max_cacheable_statement_size, + max_consecutive_exceptions=max_consecutive_exceptions,) return addrs, params, config diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 9efd5233..bf697b0b 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -44,7 +44,7 @@ class Connection(metaclass=ConnectionMeta): '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', '_stmt_exclusive_section', '_config', '_params', '_addr', - '_log_listeners', '_cancellations') + '_log_listeners', '_cancellations', '_consecutive_exceptions') def __init__(self, protocol, transport, loop, addr: (str, int) or str, @@ -97,6 +97,7 @@ def __init__(self, protocol, transport, loop, # Used for `con.fetchval()`, `con.fetch()`, `con.fetchrow()`, # `con.execute()`, and `con.executemany()`. self._stmt_exclusive_section = _Atomic() + self._consecutive_exceptions = 0 async def add_listener(self, channel, callback): """Add a listener for Postgres notifications. 
@@ -1331,6 +1332,7 @@ async def _do_execute(self, query, executor, timeout, retry=True): # It is not possible to recover (the statement is already done at # the server's side), the only way is to drop our caches and # reraise the exception to the caller. + # await self.reload_schema_state() raise except exceptions.InvalidCachedStatementError: @@ -1362,9 +1364,21 @@ async def _do_execute(self, query, executor, timeout, retry=True): else: return await self._do_execute( query, executor, timeout, retry=False) + except: + await self._maybe_close_bad_connection() + raise + self._consecutive_exceptions = 0 return result, stmt + async def _maybe_close_bad_connection(self): + if self._config.max_consecutive_exceptions > 0: + self._consecutive_exceptions += 1 + + if self._consecutive_exceptions > \ + self._config.max_consecutive_exceptions: + await self.close() + async def connect(dsn=None, *, host=None, port=None, @@ -1375,6 +1389,7 @@ async def connect(dsn=None, *, statement_cache_size=100, max_cached_statement_lifetime=300, max_cacheable_statement_size=1024 * 15, + max_consecutive_exceptions=0, command_timeout=None, ssl=None, connection_class=Connection, @@ -1431,6 +1446,11 @@ async def connect(dsn=None, *, default). Pass ``0`` to allow all statements to be cached regardless of their size. + :param int max_consecutive_exceptions: + the maximum number of consecutive exceptions that may be raised by a + single connection before that connection is assumed corrupt (e.g. + pointing to an old DB after a failover). Pass ``0`` to disable. + :param float command_timeout: the default timeout for operations on this connection (the default is ``None``: no timeout). @@ -1495,7 +1515,8 @@ class of the returned connection object. 
Must be a subclass of command_timeout=command_timeout, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, - max_cacheable_statement_size=max_cacheable_statement_size) + max_cacheable_statement_size=max_cacheable_statement_size, + max_consecutive_exceptions=max_consecutive_exceptions) class _StatementCacheEntry: