diff --git a/channels/utils.py b/channels/utils.py index 72cd9ca3..dcd8e15b 100644 --- a/channels/utils.py +++ b/channels/utils.py @@ -33,12 +33,25 @@ async def await_many_dispatch(consumer_callables, dispatch): """ Given a set of consumer callables, awaits on them all and passes results from them to the dispatch awaitable as they come in. + If a dispatch awaitable raises an exception, + this coroutine will fail with that exception. """ # Call all callables, and ensure all return types are Futures tasks = [ asyncio.ensure_future(consumer_callable()) for consumer_callable in consumer_callables ] + + dispatch_tasks = [] + fut = asyncio.Future() # For child task to report an exception + tasks.append(fut) + + def on_dispatch_task_complete(task): + dispatch_tasks.remove(task) + exc = task.exception() + if exc and not isinstance(exc, asyncio.CancelledError) and not fut.done(): + fut.set_exception(exc) + try: while True: # Wait for any of them to complete @@ -46,9 +59,16 @@ async def await_many_dispatch(consumer_callables, dispatch): # Find the completed one(s), yield results, and replace them for i, task in enumerate(tasks): if task.done(): - result = task.result() - await dispatch(result) - tasks[i] = asyncio.ensure_future(consumer_callables[i]()) + if task == fut: + exc = fut.exception() # Child task has reported an exception + if exc: + raise exc + else: + result = task.result() + task = asyncio.create_task(dispatch(result)) + dispatch_tasks.append(task) + task.add_done_callback(on_dispatch_task_complete) + tasks[i] = asyncio.ensure_future(consumer_callables[i]()) finally: # Make sure we clean up tasks on exit for task in tasks: @@ -57,3 +77,15 @@ async def await_many_dispatch(consumer_callables, dispatch): await task except asyncio.CancelledError: pass + if dispatch_tasks: + """ + This may be needed if the consumer task running this coroutine + is cancelled and one of the subtasks raises an exception after cancellation. + """ + done, pending = await asyncio.wait(dispatch_tasks) + for task in done: + exc = task.exception() + if exc and not isinstance(exc, asyncio.CancelledError): + raise exc + if not fut.done(): + fut.set_result(None)