Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maintain contextvars between fixtures and tests #1008

Merged
merged 6 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 79 additions & 3 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import contextlib
import contextvars
import enum
import functools
import inspect
Expand Down Expand Up @@ -322,6 +323,12 @@ async def setup():
res = await gen_obj.__anext__() # type: ignore[union-attr]
return res

context = contextvars.copy_context()
setup_task = _create_task_in_context(event_loop, setup(), context)
result = event_loop.run_until_complete(setup_task)

reset_contextvars = _apply_contextvar_changes(context)

def finalizer() -> None:
"""Yield again, to finalize."""

Expand All @@ -335,9 +342,11 @@ async def async_finalizer() -> None:
msg += "Yield only once."
raise ValueError(msg)

event_loop.run_until_complete(async_finalizer())
task = _create_task_in_context(event_loop, async_finalizer(), context)
event_loop.run_until_complete(task)
if reset_contextvars is not None:
reset_contextvars()

result = event_loop.run_until_complete(setup())
request.addfinalizer(finalizer)
return result

Expand All @@ -360,7 +369,23 @@ async def setup():
res = await func(**_add_kwargs(func, kwargs, event_loop, request))
return res

return event_loop.run_until_complete(setup())
context = contextvars.copy_context()
setup_task = _create_task_in_context(event_loop, setup(), context)
result = event_loop.run_until_complete(setup_task)

# Copy the context vars modified by the setup task into the current
# context, and (if needed) add a finalizer to reset them.
#
# Note that this is slightly different from the behavior of a non-async
# fixture, which would rely on the fixture author to add a finalizer
# to reset the variables. In this case, the author of the fixture can't
# write such a finalizer because they have no way to capture the Context
# in which the setup function was run, so we need to do it for them.
reset_contextvars = _apply_contextvar_changes(context)
if reset_contextvars is not None:
request.addfinalizer(reset_contextvars)

return result

fixturedef.func = _async_fixture_wrapper # type: ignore[misc]

Expand All @@ -385,6 +410,57 @@ def _get_event_loop_fixture_id_for_async_fixture(
return event_loop_fixture_id


def _create_task_in_context(loop, coro, context):
bcmills marked this conversation as resolved.
Show resolved Hide resolved
"""
Return an asyncio task that runs the coro in the specified context,
if possible.

This allows fixture setup and teardown to be run as separate asyncio tasks,
while still being able to use context-manager idioms to maintain context
variables and make those variables visible to test functions.

This is only fully supported on Python 3.11 and newer, as it requires
the API added for https://github.com/python/cpython/issues/91150.
On earlier versions, the returned task will use the default context instead.
"""
try:
return loop.create_task(coro, context=context)
except TypeError:
return loop.create_task(coro)


def _apply_contextvar_changes(
context: contextvars.Context,
) -> Callable[[], None] | None:
"""
Copy contextvar changes from the given context to the current context.

If any contextvars were modified by the fixture, return a finalizer that
will restore them.
"""
context_tokens = []
for var in context:
try:
if var.get() is context.get(var):
# This variable is not modified, so leave it as-is.
continue
except LookupError:
# This variable isn't yet set in the current context at all.
pass
token = var.set(context.get(var))
context_tokens.append((var, token))

if not context_tokens:
return None

def restore_contextvars():
while context_tokens:
(var, token) = context_tokens.pop()
var.reset(token)

return restore_contextvars


class PytestAsyncioFunction(Function):
"""Base class for all test functions managed by pytest-asyncio."""

Expand Down
73 changes: 73 additions & 0 deletions tests/async_fixtures/test_async_fixtures_contextvars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Regression test for https://github.com/pytest-dev/pytest-asyncio/issues/127:
contextvars were not properly maintained among fixtures and tests.
"""

from __future__ import annotations

import sys
from contextlib import contextmanager
from contextvars import ContextVar

import pytest

_context_var = ContextVar("context_var")


@contextmanager
def context_var_manager(value):
token = _context_var.set(value)
try:
yield
finally:
_context_var.reset(token)


@pytest.fixture(scope="function")
bcmills marked this conversation as resolved.
Show resolved Hide resolved
async def no_var_fixture():
with pytest.raises(LookupError):
_context_var.get()
yield
with pytest.raises(LookupError):
_context_var.get()


@pytest.fixture(scope="function")
async def var_fixture_1(no_var_fixture):
with context_var_manager("value1"):
yield


@pytest.fixture(scope="function")
async def var_nop_fixture(var_fixture_1):
with context_var_manager(_context_var.get()):
yield


@pytest.fixture(scope="function")
def var_fixture_2(var_nop_fixture):
assert _context_var.get() == "value1"
with context_var_manager("value2"):
yield


@pytest.fixture(scope="function")
async def var_fixture_3(var_fixture_2):
assert _context_var.get() == "value2"
with context_var_manager("value3"):
yield


@pytest.fixture(scope="function")
async def var_fixture_4(var_fixture_3, request):
assert _context_var.get() == "value3"
_context_var.set("value4")
# Rely on fixture teardown to reset the context var.


@pytest.mark.asyncio
@pytest.mark.xfail(
sys.version_info < (3, 11), reason="requires asyncio Task context support"
)
async def test(var_fixture_4):
assert _context_var.get() == "value4"
Loading