diff --git a/ipykernel/_version.py b/ipykernel/_version.py index 5907d150..b4c5b1da 100644 --- a/ipykernel/_version.py +++ b/ipykernel/_version.py @@ -1,6 +1,8 @@ """ store the current version info of the server. """ +from __future__ import annotations + import re # Version string must appear intact for hatch versioning diff --git a/ipykernel/embed.py b/ipykernel/embed.py index 3e4abd39..ad22e2a1 100644 --- a/ipykernel/embed.py +++ b/ipykernel/embed.py @@ -55,3 +55,4 @@ def embed_kernel(module=None, local_ns=None, **kwargs): app.kernel.user_ns = local_ns app.shell.set_completer_frame() # type:ignore[union-attr] app.start() + app.close() diff --git a/ipykernel/inprocess/channels.py b/ipykernel/inprocess/channels.py index 4c01c5bc..a886f6c8 100644 --- a/ipykernel/inprocess/channels.py +++ b/ipykernel/inprocess/channels.py @@ -2,6 +2,7 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations from jupyter_client.channelsabc import HBChannelABC diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index d8171017..62a29a2f 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -398,8 +398,8 @@ def fileno(self): """ Things like subprocess will peak and write to the fileno() of stderr/stdout. """ - if getattr(self, "_original_stdstream_copy", None) is not None: - return self._original_stdstream_copy + if getattr(self, "_original_stdstream_fd", None) is not None: + return self._original_stdstream_fd msg = "fileno" raise io.UnsupportedOperation(msg) @@ -527,10 +527,7 @@ def __init__( # echo on the _copy_ we made during # this is the actual terminal FD now echo = io.TextIOWrapper( - io.FileIO( - self._original_stdstream_copy, - "w", - ) + io.FileIO(self._original_stdstream_copy, "w", closefd=False) ) self.echo = echo else: @@ -595,9 +592,10 @@ def close(self): self._should_watch = False # thread won't wake unless there's something to read # writing something after _should_watch will not be echoed - os.write(self._original_stdstream_fd, b"\0") - if self.watch_fd_thread is not None: + if self.watch_fd_thread is not None and self.watch_fd_thread.is_alive(): + os.write(self._original_stdstream_fd, b"\0") self.watch_fd_thread.join() + self.echo = None # restore original FDs os.dup2(self._original_stdstream_copy, self._original_stdstream_fd) os.close(self._original_stdstream_copy) diff --git a/ipykernel/kernelapp.py b/ipykernel/kernelapp.py index 66b750b2..453cd99e 100644 --- a/ipykernel/kernelapp.py +++ b/ipykernel/kernelapp.py @@ -151,6 +151,11 @@ class IPKernelApp(BaseIPythonApplication, InteractiveShellApp, ConnectionFileMix _ports = Dict() + _original_io = Any() + _log_map = Any() + _io_modified = Bool(False) + _blackhole = Any() + subcommands = { "install": ( "ipykernel.kernelspec.InstallIPythonKernelSpecApp", @@ -470,41 +475,53 @@ def log_connection_info(self): def init_blackhole(self): """redirects stdout/stderr to devnull if necessary""" + self._save_io() if self.no_stdout or self.no_stderr: - blackhole = open(os.devnull, "w") # noqa: SIM115 + # keep reference around so that it would not accidentally close the pipe fds + self._blackhole = open(os.devnull, "w") # noqa: SIM115 if self.no_stdout: - sys.stdout = sys.__stdout__ = blackhole # type:ignore[misc] + if sys.stdout is not None: + sys.stdout.flush() + sys.stdout = self._blackhole if self.no_stderr: - sys.stderr = sys.__stderr__ = blackhole # type:ignore[misc] + if sys.stderr is not None: + sys.stderr.flush() + sys.stderr = self._blackhole def init_io(self): """Redirect input streams and set a display hook.""" + self._save_io() if self.outstream_class: outstream_factory = import_item(str(self.outstream_class)) - if sys.stdout is not None: - sys.stdout.flush() - e_stdout = None if self.quiet else sys.__stdout__ - e_stderr = None if self.quiet else sys.__stderr__ + e_stdout = None if self.quiet else sys.stdout + e_stderr = None if self.quiet else sys.stderr if not self.capture_fd_output: outstream_factory = partial(outstream_factory, watchfd=False) + if sys.stdout is not None: + sys.stdout.flush() sys.stdout = outstream_factory(self.session, self.iopub_thread, "stdout", echo=e_stdout) + if sys.stderr is not None: sys.stderr.flush() sys.stderr = outstream_factory(self.session, self.iopub_thread, "stderr", echo=e_stderr) + if hasattr(sys.stderr, "_original_stdstream_copy"): for handler in self.log.handlers: - if isinstance(handler, StreamHandler) and (handler.stream.buffer.fileno() == 2): + if ( + isinstance(handler, StreamHandler) + and (buffer := getattr(handler.stream, "buffer", None)) + and (fileno := getattr(buffer, "fileno", None)) + and fileno() == sys.stderr._original_stdstream_fd # type:ignore[attr-defined] + ): self.log.debug("Seeing logger to stderr, rerouting to raw filedescriptor.") - - handler.stream = TextIOWrapper( - FileIO( - sys.stderr._original_stdstream_copy, - "w", - ) + io_wrapper = TextIOWrapper( + FileIO(sys.stderr._original_stdstream_copy, "w", closefd=False) ) + self._log_map[id(io_wrapper)] = handler.stream + handler.stream = io_wrapper if self.displayhook_class: displayhook_factory = import_item(str(self.displayhook_class)) self.displayhook = displayhook_factory(self.session, self.iopub_socket) @@ -512,14 +529,39 @@ def init_io(self): self.patch_io() + def _save_io(self): + if not self._io_modified: + self._original_io = sys.stdout, sys.stderr, sys.displayhook + self._log_map = {} + self._io_modified = True + def reset_io(self): """restore original io restores state after init_io """ - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - sys.displayhook = sys.__displayhook__ + if not self._io_modified: + return + stdout, stderr, displayhook = sys.stdout, sys.stderr, sys.displayhook + sys.stdout, sys.stderr, sys.displayhook = self._original_io + self._original_io = None + self._io_modified = False + if finish_displayhook := getattr(displayhook, "finish_displayhook", None): + finish_displayhook() + if hasattr(stderr, "_original_stdstream_copy"): + for handler in self.log.handlers: + if orig_stream := self._log_map.get(id(handler.stream)): + self.log.debug("Seeing modified logger, rerouting back to stderr") + handler.stream = orig_stream + self._log_map = None + if self.outstream_class: + outstream_factory = import_item(str(self.outstream_class)) + if isinstance(stderr, outstream_factory): + stderr.close() + if isinstance(stdout, outstream_factory): + stdout.close() + if self._blackhole: + self._blackhole.close() def patch_io(self): """Patch important libraries that can't handle sys.stdout forwarding""" diff --git a/ipykernel/pickleutil.py b/ipykernel/pickleutil.py index 4ffa5262..15fc0e67 100644 --- a/ipykernel/pickleutil.py +++ b/ipykernel/pickleutil.py @@ -2,6 +2,8 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import copy import pickle import sys diff --git a/ipykernel/thread.py b/ipykernel/thread.py index 40509ece..b3b33d32 100644 --- a/ipykernel/thread.py +++ b/ipykernel/thread.py @@ -1,4 +1,6 @@ """Base class for threads.""" +from __future__ import annotations + import typing as t from threading import Event, Thread diff --git a/tests/test_kernelapp.py b/tests/test_kernelapp.py index 05f6e557..ec91687f 100644 --- a/tests/test_kernelapp.py +++ b/tests/test_kernelapp.py @@ -31,6 +31,7 @@ def test_blackhole(): app.no_stderr = True app.no_stdout = True app.init_blackhole() + app.close() def test_start_app():