Skip to content

Commit

Permalink
Maintain contextvars.Context in fixtures and tests
Browse files Browse the repository at this point in the history
The approach I've taken here is to maintain a contextvars.Context
instance in a contextvars.ContextVar, copying it from the ambient
context whenever we create a new event loop. The fixture setup
and teardown run within that context, and each test function gets
a copy (as if it were created as a new asyncio.Task from within the
fixture task).

Fixes pytest-dev#127.
  • Loading branch information
bcmills committed Dec 6, 2024
1 parent a1cd861 commit 7733e66
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 7 deletions.
73 changes: 66 additions & 7 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import contextlib
import contextvars
import enum
import functools
import inspect
Expand Down Expand Up @@ -54,6 +55,7 @@
_ScopeName = Literal["session", "package", "module", "class", "function"]
_T = TypeVar("_T")


SimpleFixtureFunction = TypeVar(
"SimpleFixtureFunction", bound=Callable[..., Awaitable[object]]
)
Expand Down Expand Up @@ -318,6 +320,8 @@ def _asyncgen_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
kwargs.pop(event_loop_fixture_id, None)
gen_obj = func(**_add_kwargs(func, kwargs, event_loop, request))

context = _event_loop_context.get(None)

async def setup():
res = await gen_obj.__anext__() # type: ignore[union-attr]
return res
Expand All @@ -335,9 +339,11 @@ async def async_finalizer() -> None:
msg += "Yield only once."
raise ValueError(msg)

event_loop.run_until_complete(async_finalizer())
task = _create_task_in_context(event_loop, async_finalizer(), context)
event_loop.run_until_complete(task)

result = event_loop.run_until_complete(setup())
setup_task = _create_task_in_context(event_loop, setup(), context)
result = event_loop.run_until_complete(setup_task)
request.addfinalizer(finalizer)
return result

Expand All @@ -360,7 +366,10 @@ async def setup():
res = await func(**_add_kwargs(func, kwargs, event_loop, request))
return res

return event_loop.run_until_complete(setup())
task = _create_task_in_context(
event_loop, setup(), _event_loop_context.get(None)
)
return event_loop.run_until_complete(task)

fixturedef.func = _async_fixture_wrapper # type: ignore[misc]

Expand Down Expand Up @@ -584,6 +593,46 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
Session: "session",
}

# _event_loop_context stores the Context in which asyncio tasks on the fixture
# event loop should be run. After fixture setup, individual async test functions
# are run on copies of this context.
_event_loop_context: contextvars.ContextVar[contextvars.Context] = (
contextvars.ContextVar("pytest_asyncio_event_loop_context")
)


@contextlib.contextmanager
def _set_event_loop_context():
"""Set event_loop_context to a copy of the calling thread's current context."""
context = contextvars.copy_context()
token = _event_loop_context.set(context)
try:
yield
finally:
_event_loop_context.reset(token)


def _create_task_in_context(loop, coro, context):
"""
Return an asyncio task that runs the coro in the specified context,
if possible.
This allows fixture setup and teardown to be run as separate asyncio tasks,
while still being able to use context-manager idioms to maintain context
variables and make those variables visible to test functions.
This is only fully supported on Python 3.11 and newer, as it requires
the API added for https://github.com/python/cpython/issues/91150.
On earlier versions, the returned task will use the default context instead.
"""
if context is not None:
try:
return loop.create_task(coro, context=context)
except TypeError:
pass
return loop.create_task(coro)


# A stack used to push package-scoped loops during collection of a package
# and pop those loops during collection of a Module
__package_loop_stack: list[FixtureFunctionMarker | FixtureFunction] = []
Expand Down Expand Up @@ -631,7 +680,8 @@ def scoped_event_loop(
loop = asyncio.new_event_loop()
loop.__pytest_asyncio = True # type: ignore[attr-defined]
asyncio.set_event_loop(loop)
yield loop
with _set_event_loop_context():
yield loop
loop.close()

# @pytest.fixture does not register the fixture anywhere, so pytest doesn't
Expand Down Expand Up @@ -938,9 +988,16 @@ def wrap_in_sync(

@functools.wraps(func)
def inner(*args, **kwargs):
# Give each test its own context based on the loop's main context.
context = _event_loop_context.get(None)
if context is not None:
# We are using our own event loop fixture, so make a new copy of the
# fixture context so that the test won't pollute it.
context = context.copy()

coro = func(*args, **kwargs)
_loop = _get_event_loop_no_warn()
task = asyncio.ensure_future(coro, loop=_loop)
task = _create_task_in_context(_loop, coro, context)
try:
_loop.run_until_complete(task)
except BaseException:
Expand Down Expand Up @@ -1049,7 +1106,8 @@ def event_loop(request: FixtureRequest) -> Iterator[asyncio.AbstractEventLoop]:
# The magic value must be set as part of the function definition, because pytest
# seems to have multiple instances of the same FixtureDef or fixture function
loop.__original_fixture_loop = True # type: ignore[attr-defined]
yield loop
with _set_event_loop_context():
yield loop
loop.close()


Expand All @@ -1062,7 +1120,8 @@ def _session_event_loop(
loop = asyncio.new_event_loop()
loop.__pytest_asyncio = True # type: ignore[attr-defined]
asyncio.set_event_loop(loop)
yield loop
with _set_event_loop_context():
yield loop
loop.close()


Expand Down
36 changes: 36 additions & 0 deletions tests/async_fixtures/test_async_fixtures_contextvars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Regression test for https://github.com/pytest-dev/pytest-asyncio/issues/127:
contextvars were not properly maintained among fixtures and tests.
"""

from __future__ import annotations

import sys
from contextlib import asynccontextmanager
from contextvars import ContextVar

import pytest


@asynccontextmanager
async def context_var_manager():
context_var = ContextVar("context_var")
token = context_var.set("value")
try:
yield context_var
finally:
context_var.reset(token)


@pytest.fixture(scope="function")
async def context_var():
async with context_var_manager() as v:
yield v


@pytest.mark.asyncio
@pytest.mark.xfail(
sys.version_info < (3, 11), reason="requires asyncio Task context support"
)
async def test(context_var):
assert context_var.get() == "value"

0 comments on commit 7733e66

Please sign in to comment.