|
1 | 1 | import asyncio
|
| 2 | +import sys |
2 | 3 | import threading
|
3 | 4 | from types import TracebackType
|
4 | 5 | from typing import (
|
|
29 | 30 | anyio = None # type: ignore
|
30 | 31 |
|
31 | 32 |
|
| 33 | +if sys.version_info >= (3, 11): # pragma: nocover |
| 34 | + import asyncio as asyncio_timeout |
| 35 | + |
| 36 | + anyio_shield = None |
| 37 | +else: # pragma: nocover |
| 38 | + import async_timeout as asyncio_timeout |
| 39 | + |
| 40 | + if anyio is None: # pragma: nocover |
| 41 | + raise RuntimeError("Running in Python<3.11 requires anyio") |
| 42 | + anyio_shield = anyio.CancelScope |
| 43 | + |
| 44 | + |
32 | 45 | AsyncBackend = Literal["asyncio", "trio"]
|
33 | 46 |
|
34 | 47 |
|
@@ -163,9 +176,11 @@ async def wait(self, timeout: Optional[float] = None) -> None:
|
163 | 176 | with trio.fail_after(timeout_or_inf):
|
164 | 177 | await event.wait()
|
165 | 178 | else:
|
166 |
| - asyncio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout} |
| 179 | + asyncio_exc_map: ExceptionMapping = { |
| 180 | + asyncio.exceptions.TimeoutError: PoolTimeout |
| 181 | + } |
167 | 182 | with map_exceptions(asyncio_exc_map):
|
168 |
| - async with asyncio.timeout(timeout): |
| 183 | + async with asyncio_timeout.timeout(timeout): |
169 | 184 | await event.wait()
|
170 | 185 |
|
171 | 186 |
|
@@ -217,17 +232,20 @@ async def shield(shielded: Callable[[], Coroutine[Any, Any, None]]) -> None:
|
217 | 232 | if current_async_backend() == "trio":
|
218 | 233 | with trio.CancelScope(shield=True):
|
219 | 234 | await shielded()
|
| 235 | + elif sys.version_info < (3, 11): # pragma: nocover |
| 236 | + with anyio_shield(shield=True): |
| 237 | + await shielded() |
220 | 238 | else:
|
221 |
| - await AsyncShieldCancellation._asyncio_shield(shielded) |
| 239 | + await AsyncShieldCancellation._asyncio_shield(shielded) # pragma: nocover |
222 | 240 |
|
223 | 241 | @staticmethod
|
224 | 242 | async def _asyncio_shield(
|
225 | 243 | shielded: Callable[[], Coroutine[Any, Any, None]],
|
226 |
| - ) -> None: |
| 244 | + ) -> None: # pragma: nocover |
227 | 245 | inner_task = asyncio.create_task(shielded())
|
228 | 246 | try:
|
229 | 247 | await asyncio.shield(inner_task)
|
230 |
| - except asyncio.CancelledError: |
| 248 | + except (asyncio.exceptions.CancelledError, asyncio.CancelledError): |
231 | 249 | # Let the inner_task to complete as it was shielded from the cancellation
|
232 | 250 | await inner_task
|
233 | 251 |
|
|
0 commit comments