diff --git a/synchronicity/exceptions.py b/synchronicity/exceptions.py index 2b11a26..9f15709 100644 --- a/synchronicity/exceptions.py +++ b/synchronicity/exceptions.py @@ -1,6 +1,10 @@ import asyncio +class SynchronizerShutdown(Exception): + pass + + class UserCodeException(Exception): """This is used to wrap and unwrap exceptions in "user code". diff --git a/synchronicity/synchronizer.py b/synchronicity/synchronizer.py index 8ea3aa1..f35161c 100644 --- a/synchronicity/synchronizer.py +++ b/synchronicity/synchronizer.py @@ -1,6 +1,7 @@ import asyncio import atexit import collections.abc +import concurrent.futures import contextlib import functools import inspect @@ -17,7 +18,7 @@ from .async_wrap import wraps_by_interface from .callback import Callback -from .exceptions import UserCodeException, unwrap_coro_exception, wrap_coro_exception +from .exceptions import SynchronizerShutdown, UserCodeException, unwrap_coro_exception, wrap_coro_exception from .interface import Interface _BUILTIN_ASYNC_METHODS = { @@ -100,6 +101,8 @@ def should_have_aio_interface(func): class Synchronizer: """Helps you offer a blocking (synchronous) interface to asynchronous code.""" + _stopping: asyncio.Event + def __init__( self, multiwrap_warning=False, @@ -313,7 +316,15 @@ def _run_function_sync(self, coro, interface, original_func): coro = self._wrap_check_async_leakage(coro) loop = self._get_loop(start=True) fut = asyncio.run_coroutine_threadsafe(coro, loop) - value = fut.result() + try: + value = fut.result() + except concurrent.futures.CancelledError: + if not self._loop or self._stopping.is_set(): + # this allows differentiate between wrapped code raising concurrent.futures.CancelledError + # and synchronicity itself raising it due to the event loop shutting down while waiting + # for the synchronizer. + raise SynchronizerShutdown() + raise if getattr(original_func, self._output_translation_attr, True): return self._translate_out(value, interface) @@ -332,12 +343,18 @@ async def _run_function_async(self, coro, interface, original_func): coro = wrap_coro_exception(coro) coro = self._wrap_check_async_leakage(coro) loop = self._get_loop(start=True) + if self._is_inside_loop(): value = await coro else: - c_fut = asyncio.run_coroutine_threadsafe(coro, loop) - a_fut = asyncio.wrap_future(c_fut) - value = await a_fut + try: + c_fut = asyncio.run_coroutine_threadsafe(coro, loop) + a_fut = asyncio.wrap_future(c_fut) + value = await a_fut + except asyncio.CancelledError: + if not self._loop or self._stopping.is_set(): + raise SynchronizerShutdown() + raise if getattr(original_func, self._output_translation_attr, True): return self._translate_out(value, interface) diff --git a/test/shutdown_test.py b/test/shutdown_test.py index fadff2c..b81f50b 100644 --- a/test/shutdown_test.py +++ b/test/shutdown_test.py @@ -1,7 +1,14 @@ +import asyncio import os +import pytest import signal import subprocess import sys +import threading +import time + +import synchronicity +from synchronicity.exceptions import SynchronizerShutdown def test_shutdown(): @@ -21,3 +28,57 @@ def test_shutdown(): assert p.stdout.readline() == b"exiting\n" stderr_content = p.stderr.read() assert b"Traceback" not in stderr_content + + +def test_shutdown_raises_shutdown_error(): + s = synchronicity.Synchronizer() + + @s.create_blocking + async def wrapped(): + await asyncio.sleep(10) + + def shut_down_soon(): + s._get_loop(start=True) # ensure loop is running + time.sleep(0.1) + s._close_loop() + + t = threading.Thread(target=shut_down_soon) + t.start() + + with pytest.raises(SynchronizerShutdown): + wrapped() + + t.join() + + +@pytest.mark.asyncio +async def test_shutdown_raises_shutdown_error_async(): + s = synchronicity.Synchronizer() + + @s.create_blocking + async def wrapped(): + await asyncio.sleep(10) + + @s.create_blocking + async def supercall(): + try: + # loop-internal calls should propagate the CancelledError + return await wrapped.aio() + except asyncio.CancelledError: + raise # expected + except BaseException: + raise Exception("asyncio.CancelledError is expected internally") + + def shut_down_soon(): + s._get_loop(start=True) # ensure loop is running + time.sleep(0.1) + s._close_loop() + + t = threading.Thread(target=shut_down_soon) + t.start() + + with pytest.raises(SynchronizerShutdown): + # calls from outside of the synchronizer loop should get the SynchronizerShutdown + await supercall.aio() + + t.join()