diff --git a/ipykernel/displayhook.py b/ipykernel/displayhook.py index 5f42a1445..d5e136748 100644 --- a/ipykernel/displayhook.py +++ b/ipykernel/displayhook.py @@ -29,6 +29,7 @@ def __init__(self, session, pub_socket): self._parent_header: ContextVar[dict[str, Any]] = ContextVar("parent_header") self._parent_header.set({}) + self._parent_header_global = {} def get_execution_count(self): """This method is replaced in kernelapp""" @@ -57,11 +58,16 @@ def __call__(self, obj): @property def parent_header(self): - return self._parent_header.get() + try: + return self._parent_header.get() + except LookupError: + return self._parent_header_global def set_parent(self, parent): """Set the parent header.""" - self._parent_header.set(extract_header(parent)) + parent_header = extract_header(parent) + self._parent_header.set(parent_header) + self._parent_header_global = parent_header class ZMQShellDisplayHook(DisplayHook): @@ -83,11 +89,16 @@ def __init__(self, *args, **kwargs): @property def parent_header(self): - return self._parent_header.get() + try: + return self._parent_header.get() + except LookupError: + return self._parent_header_global def set_parent(self, parent): - """Set the parent for outbound messages.""" - self._parent_header.set(extract_header(parent)) + """Set the parent header.""" + parent_header = extract_header(parent) + self._parent_header.set(parent_header) + self._parent_header_global = parent_header def start_displayhook(self): """Start the display hook.""" diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index d9c90af4f..0a2115f3b 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -456,8 +456,6 @@ def __init__( "parent_header" ) self._parent_header.set({}) - self._thread_to_parent = {} - self._thread_to_parent_header = {} self._parent_header_global = {} self._master_pid = os.getpid() self._flush_pending = False @@ -512,21 +510,11 @@ def __init__( @property def parent_header(self): try: - # asyncio-specific + # asyncio or thread-specific return self._parent_header.get() except LookupError: - try: - # thread-specific - identity = threading.current_thread().ident - # retrieve the outermost (oldest ancestor, - # discounting the kernel thread) thread identity - while identity in self._thread_to_parent: - identity = self._thread_to_parent[identity] - # use the header of the oldest ancestor - return self._thread_to_parent_header[identity] - except KeyError: - # global (fallback) - return self._parent_header_global + # global (fallback) + return self._parent_header_global @parent_header.setter def parent_header(self, value): diff --git a/ipykernel/ipkernel.py b/ipykernel/ipkernel.py index 3e3927cc5..b8508ad68 100644 --- a/ipykernel/ipkernel.py +++ b/ipykernel/ipkernel.py @@ -4,7 +4,6 @@ import asyncio import builtins -import gc import getpass import os import signal @@ -17,7 +16,6 @@ import comm from IPython.core import release from IPython.utils.tokenutil import line_at_cursor, token_at_cursor -from jupyter_client.session import extract_header from traitlets import Any, Bool, HasTraits, Instance, List, Type, default, observe, observe_compat from zmq.eventloop.zmqstream import ZMQStream @@ -25,7 +23,6 @@ from .comm.manager import CommManager from .compiler import XCachingCompiler from .eventloops import _use_appnope -from .iostream import OutStream from .kernelbase import Kernel as KernelBase from .kernelbase import _accepts_parameters from .zmqshell import ZMQInteractiveShell @@ -167,14 +164,6 @@ def __init__(self, **kwargs): appnope.nope() - self._new_threads_parent_header = {} - self._initialize_thread_hooks() - - if hasattr(gc, "callbacks"): - # while `gc.callbacks` exists since Python 3.3, pypy does not - # implement it even as of 3.9. - gc.callbacks.append(self._clean_thread_parent_frames) - help_links = List( [ { @@ -374,8 +363,6 @@ def _dummy_context_manager(self, *args): async def execute_request(self, stream, ident, parent): """Override for cell output - cell reconciliation.""" - parent_header = extract_header(parent) - self._associate_new_top_level_threads_with(parent_header) await super().execute_request(stream, ident, parent) async def do_execute( @@ -750,83 +737,6 @@ def do_clear(self): self.shell.reset(False) return dict(status="ok") - def _associate_new_top_level_threads_with(self, parent_header): - """Store the parent header to associate it with new top-level threads""" - self._new_threads_parent_header = parent_header - - def _initialize_thread_hooks(self): - """Store thread hierarchy and thread-parent_header associations.""" - stdout = self._stdout - stderr = self._stderr - kernel_thread_ident = threading.get_ident() - kernel = self - _threading_Thread_run = threading.Thread.run - _threading_Thread__init__ = threading.Thread.__init__ - - def run_closure(self: threading.Thread): - """Wrap the `threading.Thread.start` to intercept thread identity. - - This is needed because there is no "start" hook yet, but there - might be one in the future: https://bugs.python.org/issue14073 - - This is a no-op if the `self._stdout` and `self._stderr` are not - sub-classes of `OutStream`. - """ - - try: - parent = self._ipykernel_parent_thread_ident # type:ignore[attr-defined] - except AttributeError: - return - for stream in [stdout, stderr]: - if isinstance(stream, OutStream): - if parent == kernel_thread_ident: - stream._thread_to_parent_header[self.ident] = ( - kernel._new_threads_parent_header - ) - else: - stream._thread_to_parent[self.ident] = parent - _threading_Thread_run(self) - - def init_closure(self: threading.Thread, *args, **kwargs): - _threading_Thread__init__(self, *args, **kwargs) - self._ipykernel_parent_thread_ident = threading.get_ident() # type:ignore[attr-defined] - - threading.Thread.__init__ = init_closure # type:ignore[method-assign] - threading.Thread.run = run_closure # type:ignore[method-assign] - - def _clean_thread_parent_frames( - self, phase: t.Literal["start", "stop"], info: dict[str, t.Any] - ): - """Clean parent frames of threads which are no longer running. - This is meant to be invoked by garbage collector callback hook. - - The implementation enumerates the threads because there is no "exit" hook yet, - but there might be one in the future: https://bugs.python.org/issue14073 - - This is a no-op if the `self._stdout` and `self._stderr` are not - sub-classes of `OutStream`. - """ - # Only run before the garbage collector starts - if phase != "start": - return - active_threads = {thread.ident for thread in threading.enumerate()} - for stream in [self._stdout, self._stderr]: - if isinstance(stream, OutStream): - thread_to_parent_header = stream._thread_to_parent_header - for identity in list(thread_to_parent_header.keys()): - if identity not in active_threads: - try: - del thread_to_parent_header[identity] - except KeyError: - pass - thread_to_parent = stream._thread_to_parent - for identity in list(thread_to_parent.keys()): - if identity not in active_threads: - try: - del thread_to_parent[identity] - except KeyError: - pass - # This exists only for backwards compatibility - use IPythonKernel instead diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index b815b934b..931b7aac3 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -17,7 +17,7 @@ import uuid import warnings from collections.abc import Mapping -from contextvars import ContextVar +from contextvars import Context, ContextVar, copy_context from datetime import datetime from functools import partial from signal import SIGINT, SIGTERM, Signals, default_int_handler, signal @@ -72,6 +72,8 @@ " ipykernel 6.0 (2021). {target} does not seem to return an awaitable" ) +T = t.TypeVar("T") + def _accepts_parameters(meth, param_names): parameters = inspect.signature(meth).parameters @@ -201,6 +203,7 @@ def _default_ident(self): _control_parent_ident: bytes = b"" _shell_parent: ContextVar[dict[str, Any]] _shell_parent_ident: ContextVar[bytes] + _shell_context: Context # Kept for backward-compatibility, accesses _control_parent_ident and _shell_parent_ident, # see https://github.com/jupyterlab/jupyterlab/issues/17785 _parent_ident: Mapping[str, bytes] @@ -320,13 +323,14 @@ def __init__(self, **kwargs): self._shell_parent.set({}) self._shell_parent_ident = ContextVar("shell_parent_ident") self._shell_parent_ident.set(b"") + self._shell_context = copy_context() # For backward compatibility so that _parent_ident["shell"] and _parent_ident["control"] # work as they used to for ipykernel >= 7 self._parent_ident = LazyDict( { "control": lambda: self._control_parent_ident, - "shell": lambda: self._shell_parent_ident.get(), + "shell": lambda: self._get_shell_context_var(self._shell_parent_ident), } ) @@ -768,6 +772,8 @@ def set_parent(self, ident, parent, channel="shell"): else: self._shell_parent_ident.set(ident) self._shell_parent.set(parent) + # preserve the last call to set_parent + self._shell_context = copy_context() def get_parent(self, channel=None): """Get the parent request associated with a channel. @@ -794,7 +800,20 @@ def get_parent(self, channel=None): if channel == "control": return self._control_parent - return self._shell_parent.get() + + return self._get_shell_context_var(self._shell_parent) + + def _get_shell_context_var(self, var: ContextVar[T]) -> T: + """Lookup a ContextVar, falling back on the shell context + + Allows for user-launched Threads to still resolve to the shell's main context + + necessary for e.g. display from threads. + """ + try: + return var.get() + except LookupError: + return self._shell_context[var] def send_response( self, @@ -1455,7 +1474,7 @@ def getpass(self, prompt="", stream=None): ) return self._input_request( prompt, - self._shell_parent_ident.get(), + self._get_shell_context_var(self._shell_parent_ident), self.get_parent("shell"), password=True, ) @@ -1472,7 +1491,7 @@ def raw_input(self, prompt=""): raise StdinNotImplementedError(msg) return self._input_request( str(prompt), - self._shell_parent_ident.get(), + self._get_shell_context_var(self._shell_parent_ident), self.get_parent("shell"), password=False, ) diff --git a/ipykernel/zmqshell.py b/ipykernel/zmqshell.py index ba707d481..39e2f1381 100644 --- a/ipykernel/zmqshell.py +++ b/ipykernel/zmqshell.py @@ -73,14 +73,20 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._parent_header = contextvars.ContextVar("parent_header") self._parent_header.set({}) + self._parent_header_global = {} @property def parent_header(self): - return self._parent_header.get() + try: + return self._parent_header.get() + except LookupError: + return self._parent_header_global def set_parent(self, parent): """Set the parent for outbound messages.""" - self._parent_header.set(extract_header(parent)) + parent_header = extract_header(parent) + self._parent_header.set(parent_header) + self._parent_header_global = parent_header def _flush_streams(self): """flush IO Streams prior to display""" @@ -698,11 +704,23 @@ def set_next_input(self, text, replace=False): @property def parent_header(self): - return self._parent_header.get() + try: + return self._parent_header.get() + except LookupError: + return self._parent_header_global + + @parent_header.setter + def parent_header(self, value): + self._parent_header_global = value + self._parent_header.set(value) def set_parent(self, parent): - """Set the parent header for associating output with its triggering input""" - self._parent_header.set(parent) + """Set the parent header for associating output with its triggering input + + When called from a thread, sets the thread-local value, which persists + until the next call from this thread. + """ + self.parent_header = parent self.displayhook.set_parent(parent) # type:ignore[attr-defined] self.display_pub.set_parent(parent) # type:ignore[attr-defined] if hasattr(self, "_data_pub"): @@ -713,7 +731,12 @@ def set_parent(self, parent): sys.stderr.set_parent(parent) def get_parent(self): - """Get the parent header.""" + """Get the parent header. + + If set_parent has never been called from the current thread, + the value from the last call to set_parent from _any_ thread will be used + (typically the currently running cell). + """ return self.parent_header def init_magics(self): diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 3f05f75dd..2c5fe2b58 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -58,36 +58,119 @@ def test_simple_print(): _check_master(kc, expected=True) -def test_print_to_correct_cell_from_thread(): - """should print to the cell that spawned the thread, not a subsequently run cell""" - iterations = 5 - interval = 0.25 - code = f"""\ - from threading import Thread - from time import sleep +def collect_outputs(get_iopub_msg, parent_msg_id, timeout=5): + """Collect outputs until we get an idle message - def thread_target(): - for i in range({iterations}): - print(i, end='', flush=True) - sleep({interval}) + Returns list of complete output messages. + """ + while True: + msg = get_iopub_msg(timeout=timeout) + msg_type = msg["msg_type"] + content = msg["content"] + + if ( + msg["parent_header"]["msg_id"] == parent_msg_id + and msg_type == "status" + and content["execution_state"] == "idle" + ): + # idle message signals end of output + break + elif msg["msg_type"] in {"stream", "display_data"}: + yield msg + elif msg["msg_type"] == "error": + tb = "\n".join(msg["content"]["traceback"]) + raise RuntimeError(f"Error during execution: {tb}") + else: + # other output, ignored + print(msg["msg_type"]) + + +@pytest.mark.parametrize("explicit_parent", [True, False]) +def test_print_to_correct_cell_from_thread(explicit_parent: bool): + """should print to the current cell unless + + get_ipython().set_parent sets the thread-local value, + which supersedes the default. - Thread(target=thread_target).start() """ + code = f"""\ + from threading import Event, Thread + from time import sleep + from IPython.display import display + + explicit_parent = {explicit_parent} + parent = get_ipython().get_parent() + + cell_start_event = Event() + cell_end_event = Event() + + def thread_target(): + if explicit_parent: + get_ipython().set_parent(parent) + + print("before", flush=True) + display(1) + cell_start_event.wait(timeout=10) + cell_start_event.clear() + + print("during", flush=True) + display(2) + cell_end_event.set() + cell_start_event.wait(timeout=10) + cell_start_event.clear() + print("after", flush=True) + display(3) + + thread = Thread(target=thread_target) + thread.start() + """ + outputs = {} + + def add_output(msg): + parent_id = msg["parent_header"]["msg_id"] + if parent_id not in outputs: + outputs[parent_id] = { + "stdout": "", + "stderr": "", + "display_data": [], + } + cell_outputs = outputs[parent_id] + msg_type = msg["header"]["msg_type"] + content = msg["content"] + if msg_type == "stream": + cell_outputs[content["name"]] += content["text"] + else: + cell_outputs[msg_type].append(msg["content"]["data"]["text/plain"]) + with kernel() as kc: thread_msg_id = kc.execute(code) - _ = kc.execute("pass") - - received = 0 - while received < iterations: - msg = kc.get_iopub_msg(timeout=interval * 2) - if msg["msg_type"] != "stream": - continue - content = msg["content"] - assert content["name"] == "stdout" - assert content["text"] == str(received) - # this is crucial as the parent header decides to which cell the output goes - assert msg["parent_header"]["msg_id"] == thread_msg_id - received += 1 + for msg in collect_outputs(kc.get_iopub_msg, thread_msg_id): + add_output(msg) + + next_cell_msg_id = kc.execute("cell_start_event.set()\ncell_end_event.wait(timeout=10)") + for msg in collect_outputs(kc.get_iopub_msg, next_cell_msg_id): + add_output(msg) + + last_cell_msg_id = kc.execute("cell_start_event.set()\nthread.join()") + for msg in collect_outputs(kc.get_iopub_msg, last_cell_msg_id): + add_output(msg) + print(outputs) + if explicit_parent: + # assert next_cell_msg_id not in outputs + # assert last_cell_msg_id not in outputs + thread_cell_output = outputs[thread_msg_id] + assert thread_cell_output["stdout"] == "before\nduring\nafter\n" + assert thread_cell_output["display_data"] == ["1", "2", "3"] + else: + thread_cell_output = outputs[thread_msg_id] + assert thread_cell_output["stdout"] == "before\n" + assert thread_cell_output["display_data"] == ["1"] + next_cell_output = outputs[next_cell_msg_id] + assert next_cell_output["stdout"] == "during\n" + assert next_cell_output["display_data"] == ["2"] + last_cell_output = outputs[last_cell_msg_id] + assert last_cell_output["stdout"] == "after\n" + assert last_cell_output["display_data"] == ["3"] def test_print_to_correct_cell_from_child_thread(): @@ -98,7 +181,10 @@ def test_print_to_correct_cell_from_child_thread(): from threading import Thread from time import sleep + parent = get_ipython().get_parent() + def child_target(): + get_ipython().set_parent(parent) for i in range({iterations}): print(i, end='', flush=True) sleep({interval})