Skip to content

Commit b94dd8e

Browse files
authored
PYTHON-4745 - Test behavior of async task cancellation (mongodb#2136)
1 parent 7a7ffa6 commit b94dd8e

File tree

11 files changed

+155
-9
lines changed

11 files changed

+155
-9
lines changed

pymongo/asynchronous/change_stream.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,8 @@ async def try_next(self) -> Optional[_DocumentType]:
391391
if not _resumable(exc) and not exc.timeout:
392392
await self.close()
393393
raise
394-
except Exception:
394+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
395+
except BaseException:
395396
await self.close()
396397
raise
397398

pymongo/asynchronous/client_session.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,8 @@ async def callback(session, custom_arg, custom_kwarg=None):
697697
)
698698
try:
699699
ret = await callback(self)
700-
except Exception as exc:
700+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
701+
except BaseException as exc:
701702
if self.in_transaction:
702703
await self.abort_transaction()
703704
if (

pymongo/asynchronous/cursor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1126,7 +1126,8 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
11261126
self._killed = True
11271127
await self.close()
11281128
raise
1129-
except Exception:
1129+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
1130+
except BaseException:
11301131
await self.close()
11311132
raise
11321133

pymongo/asynchronous/pool.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ async def command(
559559
)
560560
except (OperationFailure, NotPrimaryError):
561561
raise
562-
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
562+
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
563563
except BaseException as error:
564564
self._raise_connection_failure(error)
565565

@@ -576,6 +576,7 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
576576

577577
try:
578578
await async_sendall(self.conn, message)
579+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
579580
except BaseException as error:
580581
self._raise_connection_failure(error)
581582

@@ -586,6 +587,7 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O
586587
"""
587588
try:
588589
return await receive_message(self, request_id, self.max_message_size)
590+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
589591
except BaseException as error:
590592
self._raise_connection_failure(error)
591593

@@ -1269,6 +1271,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
12691271

12701272
try:
12711273
sock = await _configured_socket(self.address, self.opts)
1274+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
12721275
except BaseException as error:
12731276
async with self.lock:
12741277
self.active_contexts.discard(tmp_context)
@@ -1308,6 +1311,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
13081311
handler.contribute_socket(conn, completed_handshake=False)
13091312

13101313
await conn.authenticate()
1314+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
13111315
except BaseException:
13121316
async with self.lock:
13131317
self.active_contexts.discard(conn.cancel_context)
@@ -1369,6 +1373,7 @@ async def checkout(
13691373
async with self.lock:
13701374
self.active_contexts.add(conn.cancel_context)
13711375
yield conn
1376+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
13721377
except BaseException:
13731378
# Exception in caller. Ensure the connection gets returned.
13741379
# Note that when pinned is True, the session owns the
@@ -1515,6 +1520,7 @@ async def _get_conn(
15151520
async with self._max_connecting_cond:
15161521
self._pending -= 1
15171522
self._max_connecting_cond.notify()
1523+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
15181524
except BaseException:
15191525
if conn:
15201526
# We checked out a socket but authentication failed.

pymongo/periodic_executor.py

+2
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ async def _run(self) -> None:
100100
if not await self._target():
101101
self._stopped = True
102102
break
103+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
103104
except BaseException:
104105
self._stopped = True
105106
raise
@@ -232,6 +233,7 @@ def _run(self) -> None:
232233
if not self._target():
233234
self._stopped = True
234235
break
236+
# Catch KeyboardInterrupt, etc. and cleanup.
235237
except BaseException:
236238
with self._lock:
237239
self._stopped = True

pymongo/synchronous/change_stream.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,8 @@ def try_next(self) -> Optional[_DocumentType]:
389389
if not _resumable(exc) and not exc.timeout:
390390
self.close()
391391
raise
392-
except Exception:
392+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
393+
except BaseException:
393394
self.close()
394395
raise
395396

pymongo/synchronous/client_session.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,8 @@ def callback(session, custom_arg, custom_kwarg=None):
694694
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
695695
try:
696696
ret = callback(self)
697-
except Exception as exc:
697+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
698+
except BaseException as exc:
698699
if self.in_transaction:
699700
self.abort_transaction()
700701
if (

pymongo/synchronous/cursor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,8 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
11241124
self._killed = True
11251125
self.close()
11261126
raise
1127-
except Exception:
1127+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
1128+
except BaseException:
11281129
self.close()
11291130
raise
11301131

pymongo/synchronous/pool.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ def command(
559559
)
560560
except (OperationFailure, NotPrimaryError):
561561
raise
562-
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
562+
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
563563
except BaseException as error:
564564
self._raise_connection_failure(error)
565565

@@ -576,6 +576,7 @@ def send_message(self, message: bytes, max_doc_size: int) -> None:
576576

577577
try:
578578
sendall(self.conn, message)
579+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
579580
except BaseException as error:
580581
self._raise_connection_failure(error)
581582

@@ -586,6 +587,7 @@ def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
586587
"""
587588
try:
588589
return receive_message(self, request_id, self.max_message_size)
590+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
589591
except BaseException as error:
590592
self._raise_connection_failure(error)
591593

@@ -1263,6 +1265,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect
12631265

12641266
try:
12651267
sock = _configured_socket(self.address, self.opts)
1268+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
12661269
except BaseException as error:
12671270
with self.lock:
12681271
self.active_contexts.discard(tmp_context)
@@ -1302,6 +1305,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect
13021305
handler.contribute_socket(conn, completed_handshake=False)
13031306

13041307
conn.authenticate()
1308+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
13051309
except BaseException:
13061310
with self.lock:
13071311
self.active_contexts.discard(conn.cancel_context)
@@ -1363,6 +1367,7 @@ def checkout(
13631367
with self.lock:
13641368
self.active_contexts.add(conn.cancel_context)
13651369
yield conn
1370+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
13661371
except BaseException:
13671372
# Exception in caller. Ensure the connection gets returned.
13681373
# Note that when pinned is True, the session owns the
@@ -1509,6 +1514,7 @@ def _get_conn(
15091514
with self._max_connecting_cond:
15101515
self._pending -= 1
15111516
self._max_connecting_cond.notify()
1517+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
15121518
except BaseException:
15131519
if conn:
15141520
# We checked out a socket but authentication failed.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2025-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test that async cancellation performed by users clean up resources correctly."""
16+
from __future__ import annotations
17+
18+
import asyncio
19+
import sys
20+
from test.utils import async_get_pool, delay, one
21+
22+
sys.path[0:0] = [""]
23+
24+
from test.asynchronous import AsyncIntegrationTest, async_client_context, connected
25+
26+
27+
class TestAsyncCancellation(AsyncIntegrationTest):
28+
async def test_async_cancellation_closes_connection(self):
29+
pool = await async_get_pool(self.client)
30+
await self.client.db.test.insert_one({"x": 1})
31+
self.addAsyncCleanup(self.client.db.test.delete_many, {})
32+
33+
conn = one(pool.conns)
34+
35+
async def task():
36+
await self.client.db.test.find_one({"$where": delay(0.2)})
37+
38+
task = asyncio.create_task(task())
39+
40+
await asyncio.sleep(0.1)
41+
42+
task.cancel()
43+
with self.assertRaises(asyncio.CancelledError):
44+
await task
45+
46+
self.assertTrue(conn.closed)
47+
48+
@async_client_context.require_transactions
49+
async def test_async_cancellation_aborts_transaction(self):
50+
await self.client.db.test.insert_one({"x": 1})
51+
self.addAsyncCleanup(self.client.db.test.delete_many, {})
52+
53+
session = self.client.start_session()
54+
55+
async def callback(session):
56+
await self.client.db.test.find_one({"$where": delay(0.2)}, session=session)
57+
58+
async def task():
59+
await session.with_transaction(callback)
60+
61+
task = asyncio.create_task(task())
62+
63+
await asyncio.sleep(0.1)
64+
65+
task.cancel()
66+
with self.assertRaises(asyncio.CancelledError):
67+
await task
68+
69+
self.assertFalse(session.in_transaction)
70+
71+
@async_client_context.require_failCommand_blockConnection
72+
async def test_async_cancellation_closes_cursor(self):
73+
await self.client.db.test.insert_many([{"x": 1}, {"x": 2}])
74+
self.addAsyncCleanup(self.client.db.test.delete_many, {})
75+
76+
cursor = self.client.db.test.find({}, batch_size=1)
77+
await cursor.next()
78+
79+
# Make sure getMore commands block
80+
fail_command = {
81+
"configureFailPoint": "failCommand",
82+
"mode": "alwaysOn",
83+
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
84+
}
85+
86+
async def task():
87+
async with self.fail_point(fail_command):
88+
await cursor.next()
89+
90+
task = asyncio.create_task(task())
91+
92+
await asyncio.sleep(0.1)
93+
94+
task.cancel()
95+
with self.assertRaises(asyncio.CancelledError):
96+
await task
97+
98+
self.assertTrue(cursor._killed)
99+
100+
@async_client_context.require_change_streams
101+
@async_client_context.require_failCommand_blockConnection
102+
async def test_async_cancellation_closes_change_stream(self):
103+
self.addAsyncCleanup(self.client.db.test.delete_many, {})
104+
change_stream = await self.client.db.test.watch(batch_size=2)
105+
106+
# Make sure getMore commands block
107+
fail_command = {
108+
"configureFailPoint": "failCommand",
109+
"mode": "alwaysOn",
110+
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
111+
}
112+
113+
async def task():
114+
async with self.fail_point(fail_command):
115+
await self.client.db.test.insert_many([{"x": 1}, {"x": 2}])
116+
await change_stream.next()
117+
118+
task = asyncio.create_task(task())
119+
120+
await asyncio.sleep(0.1)
121+
122+
task.cancel()
123+
with self.assertRaises(asyncio.CancelledError):
124+
await task
125+
126+
self.assertTrue(change_stream._closed)

tools/synchro.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@
169169

170170
def async_only_test(f: str) -> bool:
171171
"""Return True for async tests that should not be converted to sync."""
172-
return f in ["test_locks.py", "test_concurrency.py"]
172+
return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"]
173173

174174

175175
test_files = [

0 commit comments

Comments
 (0)