diff --git a/docs/reference/changelog.rst b/docs/reference/changelog.rst index d29fa184..6f300ba7 100644 --- a/docs/reference/changelog.rst +++ b/docs/reference/changelog.rst @@ -7,6 +7,8 @@ Changelog - Deprecated: Added warning when asyncio test requests async ``@pytest.fixture`` in strict mode. This will become an error in a future version of flake8-asyncio. `#979 `_ - Updates the error message about `pytest.mark.asyncio`'s `scope` keyword argument to say `loop_scope` instead. `#1004 `_ - Verbose log displays correct parameter name: asyncio_default_fixture_loop_scope `#990 `_ +- Propagates `contextvars` set in async fixtures to other fixtures and tests on Python 3.11 and above. `#1008 `_ + 0.24.0 (2024-08-22) =================== diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index d3e006d8..12ead10f 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -4,6 +4,7 @@ import asyncio import contextlib +import contextvars import enum import functools import inspect @@ -13,6 +14,7 @@ from collections.abc import ( AsyncIterator, Awaitable, + Coroutine as AbstractCoroutine, Generator, Iterable, Iterator, @@ -322,6 +324,12 @@ async def setup(): res = await gen_obj.__anext__() # type: ignore[union-attr] return res + context = contextvars.copy_context() + setup_task = _create_task_in_context(event_loop, setup(), context) + result = event_loop.run_until_complete(setup_task) + + reset_contextvars = _apply_contextvar_changes(context) + def finalizer() -> None: """Yield again, to finalize.""" @@ -335,9 +343,11 @@ async def async_finalizer() -> None: msg += "Yield only once." raise ValueError(msg) - event_loop.run_until_complete(async_finalizer()) + task = _create_task_in_context(event_loop, async_finalizer(), context) + event_loop.run_until_complete(task) + if reset_contextvars is not None: + reset_contextvars() - result = event_loop.run_until_complete(setup()) request.addfinalizer(finalizer) return result @@ -360,7 +370,23 @@ async def setup(): res = await func(**_add_kwargs(func, kwargs, event_loop, request)) return res - return event_loop.run_until_complete(setup()) + context = contextvars.copy_context() + setup_task = _create_task_in_context(event_loop, setup(), context) + result = event_loop.run_until_complete(setup_task) + + # Copy the context vars modified by the setup task into the current + # context, and (if needed) add a finalizer to reset them. + # + # Note that this is slightly different from the behavior of a non-async + # fixture, which would rely on the fixture author to add a finalizer + # to reset the variables. In this case, the author of the fixture can't + # write such a finalizer because they have no way to capture the Context + # in which the setup function was run, so we need to do it for them. + reset_contextvars = _apply_contextvar_changes(context) + if reset_contextvars is not None: + request.addfinalizer(reset_contextvars) + + return result fixturedef.func = _async_fixture_wrapper # type: ignore[misc] @@ -385,6 +411,61 @@ def _get_event_loop_fixture_id_for_async_fixture( return event_loop_fixture_id +def _create_task_in_context( + loop: asyncio.AbstractEventLoop, + coro: AbstractCoroutine[Any, Any, _T], + context: contextvars.Context, +) -> asyncio.Task[_T]: + """ + Return an asyncio task that runs the coro in the specified context, + if possible. + + This allows fixture setup and teardown to be run as separate asyncio tasks, + while still being able to use context-manager idioms to maintain context + variables and make those variables visible to test functions. + + This is only fully supported on Python 3.11 and newer, as it requires + the API added for https://github.com/python/cpython/issues/91150. + On earlier versions, the returned task will use the default context instead. + """ + try: + return loop.create_task(coro, context=context) + except TypeError: + return loop.create_task(coro) + + +def _apply_contextvar_changes( + context: contextvars.Context, +) -> Callable[[], None] | None: + """ + Copy contextvar changes from the given context to the current context. + + If any contextvars were modified by the fixture, return a finalizer that + will restore them. + """ + context_tokens = [] + for var in context: + try: + if var.get() is context.get(var): + # This variable is not modified, so leave it as-is. + continue + except LookupError: + # This variable isn't yet set in the current context at all. + pass + token = var.set(context.get(var)) + context_tokens.append((var, token)) + + if not context_tokens: + return None + + def restore_contextvars(): + while context_tokens: + (var, token) = context_tokens.pop() + var.reset(token) + + return restore_contextvars + + class PytestAsyncioFunction(Function): """Base class for all test functions managed by pytest-asyncio.""" diff --git a/tests/async_fixtures/test_async_fixtures_contextvars.py b/tests/async_fixtures/test_async_fixtures_contextvars.py new file mode 100644 index 00000000..ff79e17e --- /dev/null +++ b/tests/async_fixtures/test_async_fixtures_contextvars.py @@ -0,0 +1,247 @@ +""" +Regression test for https://github.com/pytest-dev/pytest-asyncio/issues/127: +contextvars were not properly maintained among fixtures and tests. +""" + +from __future__ import annotations + +import sys +from textwrap import dedent + +import pytest +from pytest import Pytester + +_prelude = dedent( + """ + import pytest + import pytest_asyncio + from contextlib import contextmanager + from contextvars import ContextVar + + _context_var = ContextVar("context_var") + + @contextmanager + def context_var_manager(value): + token = _context_var.set(value) + try: + yield + finally: + _context_var.reset(token) +""" +) + + +def test_var_from_sync_generator_propagates_to_async(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest.fixture + def var_fixture(): + with context_var_manager("value"): + yield + + @pytest_asyncio.fixture + async def check_var_fixture(var_fixture): + assert _context_var.get() == "value" + + @pytest.mark.asyncio + async def test(check_var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) + + +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_from_async_generator_propagates_to_sync(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def var_fixture(): + with context_var_manager("value"): + yield + + @pytest.fixture + def check_var_fixture(var_fixture): + assert _context_var.get() == "value" + + @pytest.mark.asyncio + async def test(check_var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) + + +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_from_async_fixture_propagates_to_sync(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def var_fixture(): + _context_var.set("value") + # Rely on async fixture teardown to reset the context var. + + @pytest.fixture + def check_var_fixture(var_fixture): + assert _context_var.get() == "value" + + def test(check_var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) + + +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_from_generator_reset_before_previous_fixture_cleanup(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def no_var_fixture(): + with pytest.raises(LookupError): + _context_var.get() + yield + with pytest.raises(LookupError): + _context_var.get() + + @pytest_asyncio.fixture + async def var_fixture(no_var_fixture): + with context_var_manager("value"): + yield + + @pytest.mark.asyncio + async def test(var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) + + +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_from_fixture_reset_before_previous_fixture_cleanup(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def no_var_fixture(): + with pytest.raises(LookupError): + _context_var.get() + yield + with pytest.raises(LookupError): + _context_var.get() + + @pytest_asyncio.fixture + async def var_fixture(no_var_fixture): + _context_var.set("value") + # Rely on async fixture teardown to reset the context var. + + @pytest.mark.asyncio + async def test(var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) + + +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_previous_value_restored_after_fixture(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def var_fixture_1(): + with context_var_manager("value1"): + yield + assert _context_var.get() == "value1" + + @pytest_asyncio.fixture + async def var_fixture_2(var_fixture_1): + with context_var_manager("value2"): + yield + assert _context_var.get() == "value2" + + @pytest.mark.asyncio + async def test(var_fixture_2): + assert _context_var.get() == "value2" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) + + +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_set_to_existing_value_ok(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def var_fixture(): + with context_var_manager("value"): + yield + + @pytest_asyncio.fixture + async def same_var_fixture(var_fixture): + with context_var_manager(_context_var.get()): + yield + + @pytest.mark.asyncio + async def test(same_var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1)