diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index 24f9ab52e0..0814e28907 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -656,6 +656,8 @@ finished. This code will wait 5 seconds (for the child task to finish), and then return. +.. _child-tasks-and-cancellation: + Child tasks and cancellation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -933,8 +935,8 @@ Will properly log the inner exceptions: .. _task-local-storage: -Task-local storage ------------------- +Context variables support task-local storage +-------------------------------------------- Suppose you're writing a server that responds to network requests, and you log some information about each request as you process it. If the @@ -1020,6 +1022,94 @@ For more information, read the `contextvar docs `__. +More on context variable inheritance +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +As you can see from the example above, a task's context variables are +automatically inherited by any other tasks it starts. To be precise, +each new task gets a *shallow copy* of the context in the task that +spawned it. That means: + +* If the new task changes one of its context variables using `ContextVar.set() + <contextvars.ContextVar.set>`, then that change is not visible in the + task that started it. (This would hardly be "task-local" storage otherwise!) + +* But if the context variable referred to a mutable object (such as a + list or dictionary), and the new task makes a change to that object + (such as by calling ``some_contextvar.get().append(42)``), then that + change *is* visible in the task that started it, as well as in any other + tasks it started. Since that's rather confusing, it's often best to + limit yourself to immutable values (strings, integers, tuples, and so + on) when working with context variables. + +A new task's context is set up as a copy of the context that existed +at the time of the call to :meth:`~Nursery.start_soon` or +:meth:`~Nursery.start` that created it. 
For example, this code: + +.. code-block:: python3 + + some_cvar = contextvars.ContextVar("some_cvar") + + async def print_in_child(tag): + print("In child", tag, "some_cvar has value", some_cvar.get()) + + some_cvar.set(1) + async with trio.open_nursery() as nursery: + nursery.start_soon(print_in_child, 1) + some_cvar.set(2) + nursery.start_soon(print_in_child, 2) + some_cvar.set(3) + print("In parent some_cvar has value", some_cvar.get()) + +will produce output like:: + + In parent some_cvar has value 3 + In child 1 some_cvar has value 1 + In child 2 some_cvar has value 2 + +(If you run it yourself, you might find that the "child 2" line comes +before "child 1", but it will still be the case that child 1 sees value 1 +while child 2 sees value 2.) + +You might wonder why this differs from the behavior of cancel scopes, +which only apply to a new task if they surround the new task's entire +nursery (as explained above in the section on +:ref:`child-tasks-and-cancellation`). The difference is that a cancel +scope has a limited lifetime (it can't cancel anything once you exit +its ``with`` block), while a context variable's value is just a value +(request #42 can keep being request #42 for as long as it likes, +without any cooperation from the task that created it). + +In specialized cases, you might want to provide a task-local value +that's inherited only from the parent nursery, like cancel scopes are. +(For example, maybe you're trying to provide child tasks with access +to a limited-lifetime resource such as a nursery or network +connection, and you only want a task to be able to use the resource if +it's going to remain available for the task's entire lifetime.) Trio +supports this using `TreeVar`, which is like `contextvars.ContextVar` +except for the way that it's inherited by new tasks. (It's a "tree" +variable because it's inherited along the parent-child links that +form the Trio task tree.) + +If the above example used `TreeVar`, then its output would be: + +.. 
code-block:: none + :emphasize-lines: 3 + + In parent some_cvar has value 3 + In child 1 some_cvar has value 1 + In child 2 some_cvar has value 1 + +because child 2 would inherit the value from its parent nursery, rather than +from the environment of the ``start_soon()`` call that creates it. + +.. autoclass:: trio.TreeVar(name, [*, default]) + + .. automethod:: being + :with: + .. automethod:: get_in(task_or_nursery, [default]) + + .. _synchronization: Synchronizing and communicating between tasks diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 36b37e3e37..1da3dbec81 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -256,7 +256,11 @@ anything real. See `#26 Global state: system tasks and run-local variables ================================================== -.. autoclass:: RunVar +.. autoclass:: RunVar(name, [default]) + + .. automethod:: get([default]) + .. automethod:: set + .. automethod:: reset .. autofunction:: spawn_system_task diff --git a/newsfragments/1523.feature.rst b/newsfragments/1523.feature.rst new file mode 100644 index 0000000000..bc11952474 --- /dev/null +++ b/newsfragments/1523.feature.rst @@ -0,0 +1,8 @@ +Added the concept of a "tree variable" (`trio.TreeVar`), which is like +a context variable except that its value in a new task is inherited +from the new task's parent nursery rather than from the task that +spawned it. (The `~trio.TreeVar` behavior matches the existing +:ref:`behavior of cancel scopes <child-tasks-and-cancellation>`.) +This distinction makes tree variables useful for anything that's +naturally inherited along parent/child task relationships, such as a +reference to a resource that has a limited lifetime. 
diff --git a/trio/__init__.py b/trio/__init__.py index 63e74e9da8..1117c8feed 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -32,6 +32,7 @@ BrokenResourceError, EndOfChannel, Nursery, + TreeVar, ) from ._timeouts import ( diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index 2bd0c74e67..02226b2dba 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -69,7 +69,7 @@ from ._unbounded_queue import UnboundedQueue -from ._local import RunVar +from ._local import RunVar, TreeVar from ._thread_cache import start_thread_soon diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 352caa5682..f32ddeb029 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,7 +1,9 @@ -# Runvar implementations +# Implementations of RunVar and TreeVar +import contextvars +from contextlib import contextmanager from . import _run -from .._util import SubclassingDeprecatedIn_v0_15_0 +from .._util import Final, SubclassingDeprecatedIn_v0_15_0 class _RunVarToken: @@ -19,16 +21,17 @@ def __init__(self, var, value): self.redeemed = False +_NO_DEFAULT = object() + + class RunVar(metaclass=SubclassingDeprecatedIn_v0_15_0): """The run-local variant of a context variable. :class:`RunVar` objects are similar to context variable objects, except that they are shared across a single call to :func:`trio.run` rather than a single task. 
- """ - _NO_DEFAULT = object() __slots__ = ("_name", "_default") def __init__(self, name, default=_NO_DEFAULT): @@ -36,25 +39,28 @@ def __init__(self, name, default=_NO_DEFAULT): self._default = default def get(self, default=_NO_DEFAULT): - """Gets the value of this :class:`RunVar` for the current run call.""" + """Gets the value of this `RunVar` for the current call + to :func:`trio.run`.""" try: return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: # contextvars consistency - if default is not self._NO_DEFAULT: + if default is not _NO_DEFAULT: return default - if self._default is not self._NO_DEFAULT: + if self._default is not _NO_DEFAULT: return self._default raise LookupError(self) from None def set(self, value): - """Sets the value of this :class:`RunVar` for this current run - call. + """Sets the value of this `RunVar` for the current call to + :func:`trio.run`. + Returns a token which may be passed to :meth:`reset` to restore + the previous value. """ try: old_value = self.get() @@ -69,9 +75,8 @@ def set(self, value): return token def reset(self, token): - """Resets the value of this :class:`RunVar` to what it was - previously specified by the token. - + """Resets the value of this `RunVar` to the value it had + before the call to :meth:`set` that returned the given *token*. """ if token is None: raise TypeError("token must not be none") @@ -95,3 +100,118 @@ def reset(self, token): def __repr__(self): return "".format(self._name) + + +class TreeVar(metaclass=Final): + """A "tree variable": like a context variable except that its value + in a new task is inherited from the new task's parent nursery rather + than from the new task's spawner. 
+ + `TreeVar` objects support all the same methods and attributes as + `~contextvars.ContextVar` objects + (:meth:`~contextvars.ContextVar.get`, + :meth:`~contextvars.ContextVar.set`, + :meth:`~contextvars.ContextVar.reset`, and + `~contextvars.ContextVar.name`), and they are constructed the same + way. They also provide the additional methods :meth:`being` and + :meth:`get_in`, documented below. + + Accessing or changing the value of a `TreeVar` outside of a Trio + task will raise `RuntimeError`. (Exception: :meth:`get_in` still + works outside of a task, as long as you have a reference to the + task or nursery of interest.) + + .. note:: `TreeVar` values are not directly stored in the + `contextvars.Context`, so you can't use `Context.get() + <contextvars.Context.get>` to access them. If you need the value + in a context other than your own, use :meth:`get_in`. + + """ + + __slots__ = ("_cvar",) + + def __init__(self, name, **default): + self._cvar = contextvars.ContextVar(name, **default) + + @property + def name(self): + """The name of the variable, as passed during construction. Read-only.""" + return self._cvar.name + + def get(self, default=_NO_DEFAULT): + """Gets the value of this `TreeVar` for the current task. + + If this `TreeVar` has no value in the current task, then + :meth:`get` returns the *default* specified as argument to + :meth:`get`, or else the *default* specified when constructing + the `TreeVar`, or else raises `LookupError`. See the + documentation of :meth:`contextvars.ContextVar.get` for more + details. + """ + # This is effectively an inlining for efficiency of: + # return _run.current_task()._tree_context.run(self._cvar.get, default) + try: + return _run.GLOBAL_RUN_CONTEXT.task._tree_context[self._cvar] + except AttributeError: + raise RuntimeError("must be called from async context") from None + except KeyError: + pass + # This will always return the default or raise, because we never give + # self._cvar a value in any context in which we run user code. 
+ if default is _NO_DEFAULT: + return self._cvar.get() + else: + return self._cvar.get(default) + + def set(self, value): + """Sets the value of this `TreeVar` for the current task. The new + value will be inherited by nurseries that are later opened in + this task, so that new tasks can inherit whatever value was + set when their parent nursery was created. + + Returns a token which may be passed to :meth:`reset` to restore + the previous value. + """ + return _run.current_task()._tree_context.run(self._cvar.set, value) + + def reset(self, token): + """Resets the value of this `TreeVar` to the value it had + before the call to :meth:`set` that returned the given *token*. + + The *token* must have been obtained from a call to :meth:`set` on + this same `TreeVar` and in the same task that is now calling + :meth:`reset`. Also, each *token* may only be used in one call to + :meth:`reset`. Violating these conditions will raise `ValueError`. + """ + _run.current_task()._tree_context.run(self._cvar.reset, token) + + @contextmanager + def being(self, value): + """Returns a context manager which sets the value of this `TreeVar` to + *value* upon entry and restores its previous value upon exit. + """ + token = self.set(value) + try: + yield + finally: + self.reset(token) + + def get_in(self, task_or_nursery, default=_NO_DEFAULT): + """Gets the value of this `TreeVar` in the given + `~trio.lowlevel.Task` or `~trio.Nursery`. + + The value in a task is the value that would be returned by a + call to :meth:`~contextvars.ContextVar.get` in that task. The + value in a nursery is the value that would be returned by + :meth:`~contextvars.ContextVar.get` at the beginning of a new + child task started in that nursery. The *default* argument has + the same semantics as it does for :meth:`~contextvars.ContextVar.get`. + """ + # copy() so this works from a different thread too. 
It's a + # cheap and thread-safe operation (just copying one reference) + # since the underlying context data is immutable. + defarg = () if default is _NO_DEFAULT else (default,) + return task_or_nursery._tree_context.copy().run(self._cvar.get, *defarg) + + def __repr__(self): + return f"<TreeVar name={self._cvar.name!r}>" diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 9d6d388ca7..88bd8ef2c9 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -13,7 +13,7 @@ import warnings import enum -from contextvars import copy_context +from contextvars import copy_context, Context from math import inf from time import perf_counter from typing import Callable, TYPE_CHECKING @@ -861,6 +861,7 @@ class Nursery(metaclass=NoPublicConstructor): def __init__(self, parent_task, cancel_scope): self._parent_task = parent_task + self._tree_context = parent_task._tree_context.copy() parent_task._child_nurseries.append(self) # the cancel status that children inherit - we take a snapshot, so it # won't be affected by any changes in the parent. 
@@ -1032,10 +1033,9 @@ async def async_fn(arg1, arg2, \*, task_status=trio.TASK_STATUS_IGNORED): async with open_nursery() as old_nursery: task_status = _TaskStatus(old_nursery, self) thunk = functools.partial(async_fn, task_status=task_status) - task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( - thunk, args, old_nursery, name + GLOBAL_RUN_CONTEXT.runner.spawn_impl( + thunk, args, old_nursery, name, eventual_nursery=self, ) - task._eventual_parent_nursery = self # Wait for either _TaskStatus.started or an exception to # cancel this nursery: # If we get here, then the child either got reparented or exited @@ -1067,6 +1067,9 @@ class Task(metaclass=NoPublicConstructor): context = attr.ib() _counter = attr.ib(init=False, factory=itertools.count().__next__) + # Contextvars context that contains TreeVar values + _tree_context = attr.ib(init=False) + # Invariant: # - for unscheduled tasks, _next_send_fn and _next_send are both None # - for scheduled tasks, _next_send_fn(_next_send) resumes the task; @@ -1093,6 +1096,13 @@ class Task(metaclass=NoPublicConstructor): _cancel_points = attr.ib(default=0) _schedule_points = attr.ib(default=0) + def __attrs_post_init__(self): + if self._parent_nursery is None: + self._tree_context = Context() + else: + parent = self._eventual_parent_nursery or self._parent_nursery + self._tree_context = parent._tree_context.copy() + def __repr__(self): return "".format(self.name, id(self)) @@ -1407,8 +1417,9 @@ def reschedule(self, task, next_send=_NO_SEND): if self.instruments: self.instrument("task_scheduled", task) - def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False): - + def spawn_impl( + self, async_fn, args, nursery, name, *, eventual_nursery=None, system_task=False + ): ###### # Make sure the nursery is in working order ###### @@ -1454,7 +1465,12 @@ async def python_wrapper(orig_coro): # Set up the Task object ###### task = Task._create( - coro=coro, parent_nursery=nursery, runner=self, name=name, context=context, + 
coro=coro, + parent_nursery=nursery, + runner=self, + name=name, + context=context, + eventual_parent_nursery=eventual_nursery, ) self.tasks.add(task) diff --git a/trio/_core/tests/test_local.py b/trio/_core/tests/test_local.py index 7f403168ea..b43f8406fe 100644 --- a/trio/_core/tests/test_local.py +++ b/trio/_core/tests/test_local.py @@ -1,4 +1,5 @@ import pytest +from functools import partial from ... import _core @@ -113,3 +114,130 @@ async def get_token(): with pytest.raises(RuntimeError): t1.reset(token) + + +async def test_treevar(): + tv1 = _core.TreeVar("tv1") + tv2 = _core.TreeVar("tv2", default=None) + assert tv1.name == "tv1" + assert "TreeVar name='tv2'" in repr(tv2) + + with pytest.raises(LookupError): + tv1.get() + assert tv2.get() is None + assert tv1.get(42) == 42 + assert tv2.get(42) == 42 + + NOTHING = object() + + async def should_be(val1, val2, new1=NOTHING): + assert tv1.get(NOTHING) == val1 + assert tv2.get(NOTHING) == val2 + if new1 is not NOTHING: + tv1.set(new1) + + tok1 = tv1.set(10) + async with _core.open_nursery() as outer: + tok2 = tv1.set(15) + with tv2.being(20): + assert tv2.get_in(_core.current_task()) == 20 + async with _core.open_nursery() as inner: + tv1.reset(tok2) + outer.start_soon(should_be, 10, NOTHING, 100) + inner.start_soon(should_be, 15, 20, 200) + await _core.wait_all_tasks_blocked() + assert tv1.get_in(_core.current_task()) == 10 + await should_be(10, 20, 300) + assert tv1.get_in(inner) == 15 + assert tv1.get_in(outer) == 10 + assert tv1.get_in(_core.current_task()) == 300 + assert tv2.get_in(inner) == 20 + assert tv2.get_in(outer) is None + assert tv2.get_in(_core.current_task()) == 20 + tv1.reset(tok1) + await should_be(NOTHING, 20) + assert tv1.get_in(inner) == 15 + assert tv1.get_in(outer) == 10 + with pytest.raises(LookupError): + assert tv1.get_in(_core.current_task()) + assert tv2.get() is None + assert tv2.get_in(_core.current_task()) is None + + +async def test_treevar_follows_eventual_parent(): + tv1 = 
_core.TreeVar("tv1") + + def trivial_abort(_): + return _core.Abort.SUCCEEDED # pragma: no cover + + async def manage_target(task_status): + assert tv1.get() == "source nursery" + with tv1.being("target nursery"): + assert tv1.get() == "target nursery" + async with _core.open_nursery() as target_nursery: + with tv1.being("target nested child"): + assert tv1.get() == "target nested child" + task_status.started(target_nursery) + await _core.wait_task_rescheduled(trivial_abort) + assert tv1.get() == "target nested child" + assert tv1.get() == "target nursery" + assert tv1.get() == "target nursery" + assert tv1.get() == "source nursery" + + async def verify(value, *, task_status=_core.TASK_STATUS_IGNORED): + assert tv1.get() == value + task_status.started() + assert tv1.get() == value + + with tv1.being("source nursery"): + async with _core.open_nursery() as source_nursery: + with tv1.being("source->target start call"): + target_nursery = await source_nursery.start(manage_target) + with tv1.being("verify task"): + source_nursery.start_soon(verify, "source nursery") + target_nursery.start_soon(verify, "target nursery") + await source_nursery.start(verify, "source nursery") + await target_nursery.start(verify, "target nursery") + _core.reschedule(target_nursery.parent_task) + + +async def test_treevar_token_bound_to_task_that_obtained_it(): + tv1 = _core.TreeVar("tv1") + token = None + + async def get_token(): + nonlocal token + token = tv1.set(10) + try: + await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) + finally: + tv1.reset(token) + with pytest.raises(LookupError): + tv1.get() + with pytest.raises(LookupError): + tv1.get_in(_core.current_task()) + + async with _core.open_nursery() as nursery: + nursery.start_soon(get_token) + await _core.wait_all_tasks_blocked() + assert token is not None + with pytest.raises(ValueError, match="different Context"): + tv1.reset(token) + assert tv1.get_in(list(nursery.child_tasks)[0]) == 10 + 
nursery.cancel_scope.cancel() + + +def test_treevar_outside_run(): + async def run_sync(fn, *args): + return fn(*args) + + tv1 = _core.TreeVar("tv1", default=10) + for operation in ( + tv1.get, + partial(tv1.get, 20), + partial(tv1.set, 30), + lambda: tv1.reset(_core.run(run_sync, tv1.set, 10)), + tv1.being(40).__enter__, + ): + with pytest.raises(RuntimeError, match="must be called from async context"): + operation()