diff --git a/databases/core.py b/databases/core.py index efa59471..28042009 100644 --- a/databases/core.py +++ b/databases/core.py @@ -385,15 +385,19 @@ async def commit(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() - await self._transaction.commit() - await self._connection.__aexit__() + try: + await self._transaction.commit() + finally: + await self._connection.__aexit__() async def rollback(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() - await self._transaction.rollback() - await self._connection.__aexit__() + try: + await self._transaction.rollback() + finally: + await self._connection.__aexit__() class _EmptyNetloc(str): diff --git a/tests/test_databases.py b/tests/test_databases.py index a7545e31..60f17349 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -543,6 +543,33 @@ async def test_transaction_rollback(database_url): assert len(results) == 0 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_doesnt_leak_when_commit_fails(database_url): + """ + Ensure that transaction doesn't leak the connection when the commit or rollback + fails + """ + + async with Database(database_url) as database: + with pytest.raises(Exception) as excinfo: + async with database.connection() as connection: + await connection.execute( + """ + CREATE TABLE test (id integer PRIMARY KEY INITIALLY DEFERRED); + """ + ) + async with connection.transaction(): + await connection.execute("insert into test (id) values (1)") + await connection.execute("insert into test (id) values (1)") + + # During transaction.commit() postgres will raise this exception: + # asyncpg.exceptions.UniqueViolationError: duplicate key value violates unique constraint "test_pkey" + # DETAIL: Key (id)=(1) already exists. + assert "unique constraint" in str(excinfo.value) + assert connection._connection_counter == 0 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_transaction_commit_low_level(database_url):