From aec8eb471309f937e8dfc0e1f78e2e972c26a7e3 Mon Sep 17 00:00:00 2001 From: "Bryan C. Mills" Date: Fri, 6 Dec 2024 12:35:29 -0500 Subject: [PATCH] Maintain contextvars.Context in fixtures and tests The approach I've taken here is to maintain a contextvars.Context instance in a contextvars.ContextVar, copying it from the ambient context whenever we create a new event loop. The fixture setup and teardown run within that context, and each test function gets a copy (as if it were created as a new asyncio.Task from within the fixture task). Fixes pytest-dev/pytest-asyncio#127. --- pytest_asyncio/plugin.py | 72 +++++++++++++++++-- .../test_async_fixtures_contextvars.py | 36 ++++++++++ 2 files changed, 101 insertions(+), 7 deletions(-) create mode 100644 tests/async_fixtures/test_async_fixtures_contextvars.py diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index d3e006d8..762c58ae 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -4,6 +4,7 @@ import asyncio import contextlib +import contextvars import enum import functools import inspect @@ -318,6 +319,8 @@ def _asyncgen_fixture_wrapper(request: FixtureRequest, **kwargs: Any): kwargs.pop(event_loop_fixture_id, None) gen_obj = func(**_add_kwargs(func, kwargs, event_loop, request)) + context = _event_loop_context.get(None) + async def setup(): res = await gen_obj.__anext__() # type: ignore[union-attr] return res @@ -335,9 +338,11 @@ async def async_finalizer() -> None: msg += "Yield only once." raise ValueError(msg) - event_loop.run_until_complete(async_finalizer()) + task = _create_task_in_context(event_loop, async_finalizer(), context) + event_loop.run_until_complete(task) - result = event_loop.run_until_complete(setup()) + setup_task = _create_task_in_context(event_loop, setup(), context) + result = event_loop.run_until_complete(setup_task) request.addfinalizer(finalizer) return result @@ -360,7 +365,10 @@ async def setup(): res = await func(**_add_kwargs(func, kwargs, event_loop, request)) return res - return event_loop.run_until_complete(setup()) + task = _create_task_in_context( + event_loop, setup(), _event_loop_context.get(None) + ) + return event_loop.run_until_complete(task) fixturedef.func = _async_fixture_wrapper # type: ignore[misc] @@ -584,6 +592,46 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass( Session: "session", } +# _event_loop_context stores the Context in which asyncio tasks on the fixture +# event loop should be run. After fixture setup, individual async test functions +# are run on copies of this context. +_event_loop_context: contextvars.ContextVar[contextvars.Context] = ( + contextvars.ContextVar("pytest_asyncio_event_loop_context") +) + + +@contextlib.contextmanager +def _set_event_loop_context(): + """Set event_loop_context to a copy of the calling thread's current context.""" + context = contextvars.copy_context() + token = _event_loop_context.set(context) + try: + yield + finally: + _event_loop_context.reset(token) + + +def _create_task_in_context(loop, coro, context): + """ + Return an asyncio task that runs the coro in the specified context, + if possible. + + This allows fixture setup and teardown to be run as separate asyncio tasks, + while still being able to use context-manager idioms to maintain context + variables and make those variables visible to test functions. + + This is only fully supported on Python 3.11 and newer, as it requires + the API added for https://github.com/python/cpython/issues/91150. + On earlier versions, the returned task will use the default context instead. + """ + if context is not None: + try: + return loop.create_task(coro, context=context) + except TypeError: + pass + return loop.create_task(coro) + + # A stack used to push package-scoped loops during collection of a package # and pop those loops during collection of a Module __package_loop_stack: list[FixtureFunctionMarker | FixtureFunction] = [] @@ -631,7 +679,8 @@ def scoped_event_loop( loop = asyncio.new_event_loop() loop.__pytest_asyncio = True # type: ignore[attr-defined] asyncio.set_event_loop(loop) - yield loop + with _set_event_loop_context(): + yield loop loop.close() # @pytest.fixture does not register the fixture anywhere, so pytest doesn't @@ -938,9 +987,16 @@ def wrap_in_sync( @functools.wraps(func) def inner(*args, **kwargs): + # Give each test its own context based on the loop's main context. + context = _event_loop_context.get(None) + if context is not None: + # We are using our own event loop fixture, so make a new copy of the + # fixture context so that the test won't pollute it. + context = context.copy() + coro = func(*args, **kwargs) _loop = _get_event_loop_no_warn() - task = asyncio.ensure_future(coro, loop=_loop) + task = _create_task_in_context(_loop, coro, context) try: _loop.run_until_complete(task) except BaseException: @@ -1049,7 +1105,8 @@ def event_loop(request: FixtureRequest) -> Iterator[asyncio.AbstractEventLoop]: # The magic value must be set as part of the function definition, because pytest # seems to have multiple instances of the same FixtureDef or fixture function loop.__original_fixture_loop = True # type: ignore[attr-defined] - yield loop + with _set_event_loop_context(): + yield loop loop.close() @@ -1062,7 +1119,8 @@ def _session_event_loop( loop = asyncio.new_event_loop() loop.__pytest_asyncio = True # type: ignore[attr-defined] asyncio.set_event_loop(loop) - yield loop + with _set_event_loop_context(): + yield loop loop.close() diff --git a/tests/async_fixtures/test_async_fixtures_contextvars.py b/tests/async_fixtures/test_async_fixtures_contextvars.py new file mode 100644 index 00000000..25bb8106 --- /dev/null +++ b/tests/async_fixtures/test_async_fixtures_contextvars.py @@ -0,0 +1,36 @@ +""" +Regression test for https://github.com/pytest-dev/pytest-asyncio/issues/127: +contextvars were not properly maintained among fixtures and tests. +""" + +from __future__ import annotations + +import sys +from contextlib import asynccontextmanager +from contextvars import ContextVar + +import pytest + + +@asynccontextmanager +async def context_var_manager(): + context_var = ContextVar("context_var") + token = context_var.set("value") + try: + yield context_var + finally: + context_var.reset(token) + + +@pytest.fixture(scope="function") +async def context_var(): + async with context_var_manager() as v: + yield v + + +@pytest.mark.asyncio +@pytest.mark.xfail( + sys.version_info < (3, 11), reason="requires asyncio Task context support" +) +async def test(context_var): + assert context_var.get() == "value"