Skip to content

Commit 64e3c7d

Browse files
committed
Add SignalManager
1 parent 1b2e7dc commit 64e3c7d

File tree

4 files changed

+143
-9
lines changed

4 files changed

+143
-9
lines changed

metaflow/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class and related decorators.
147147
# Runner API
148148
if sys.version_info >= (3, 7):
149149
from .runner.metaflow_runner import Runner
150+
from .runner.signal_manager import SignalManager
150151
from .runner.nbrun import NBRunner
151152
from .runner.deployer import Deployer
152153
from .runner.nbdeploy import NBDeployer

metaflow/runner/metaflow_runner.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .utils import handle_timeout, async_handle_timeout, clear_and_set_os_environ
1212
from .subprocess_manager import CommandManager, SubprocessManager
13+
from .signal_manager import SignalManager
1314

1415

1516
class ExecutingRun(object):
@@ -231,6 +232,7 @@ def __init__(
231232
env: Optional[Dict] = None,
232233
cwd: Optional[str] = None,
233234
file_read_timeout: int = 3600,
235+
signal_manager: Optional[SignalManager] = None,
234236
**kwargs
235237
):
236238
# these imports are required here and not at the top
@@ -257,7 +259,7 @@ def __init__(
257259

258260
self.cwd = cwd
259261
self.file_read_timeout = file_read_timeout
260-
self.spm = SubprocessManager()
262+
self.spm = SubprocessManager(signal_manager=signal_manager)
261263
self.top_level_kwargs = kwargs
262264
self.api = MetaflowAPI.from_cli(self.flow_file, start)
263265

metaflow/runner/signal_manager.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import asyncio
2+
import signal
3+
from typing import NewType, Mapping, Set, Callable, Optional
4+
5+
from types import FrameType
from typing import Dict

# Signature shared by every handler registered with `SignalManager`: it
# receives the signal number and the interrupted stack frame. The frame is
# None when the handler is dispatched from an asyncio event loop or from
# `trigger_signal` without an explicit frame.
SignalHandler = NewType(
    "SignalHandler", Callable[[int, Optional[FrameType]], None]
)


class SignalManager:
    """
    A context manager for managing signal handlers.

    Several handlers can be registered for the same signal; they are all
    called when the signal is delivered. When used as a context manager,
    every hooked signal is unhooked when the context is exited. Restoring
    the handler that was previously installed only works for signals in a
    synchronous context (ie. hooked by `signal`).

    Parameters
    ----------
    hook_signals : bool
        If True, the signal manager will overwrite any existing signal handlers
        in either `asyncio` or `signal`. If you already have any signal
        handling in place, you can set this to False and use `trigger_signal`
        to trigger metaflow-related signal handlers.
    event_loop : Optional[asyncio.AbstractEventLoop]
        The event loop to use for handling signals.
        If None, the current running event loop is used, if any.
    """

    hook_signals: bool
    event_loop: Optional[asyncio.AbstractEventLoop]
    # Declared without a value on purpose: the dictionaries are created per
    # instance in __init__ so that two managers never share handlers.
    signal_map: Dict[int, Set[SignalHandler]]
    replaced_signals: Dict[int, SignalHandler]

    def __init__(
        self,
        hook_signals: bool = True,
        event_loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self.hook_signals = hook_signals
        # signal number -> handlers registered through add_signal_handler
        self.signal_map = {}
        # signal number -> handler that was installed before we hooked it
        self.replaced_signals = {}

        if event_loop is not None:
            self.event_loop = event_loop
        else:
            try:
                self.event_loop = asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running: fall back to the `signal` module.
                self.event_loop = None

    def __enter__(self):
        """Enter the context and return the manager itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Unhook every managed signal and forget all registered handlers."""
        # Iterate over snapshots because the dictionaries are modified while
        # the handlers are being removed.
        for sig in list(self.signal_map):
            self._maybe_remove_signal_handler(sig)
        self.signal_map.clear()

        # Restore anything that is still recorded as replaced (for example if
        # `hook_signals` was switched off after a signal had been hooked).
        for sig in list(self.replaced_signals):
            self._restore_signal(sig, self.replaced_signals.pop(sig))

    @staticmethod
    def _restore_signal(sig, previous):
        """Re-install `previous` as the handler of `sig` in `signal`."""
        # `signal.signal` reports None when the previous handler was not
        # installed from Python; None cannot be passed back, so fall back to
        # the default action.
        signal.signal(sig, signal.SIG_DFL if previous is None else previous)

    def _handle_signal(self, signum, frame):
        """Call every handler registered for `signum`."""
        # Snapshot the handlers so that one of them may add or remove
        # handlers while being called; a signal without handlers is a no-op.
        for handler in tuple(self.signal_map.get(signum, ())):
            handler(signum, frame)

    def _maybe_add_signal_handler(self, sig):
        """Hook `sig` in `signal` or the event loop if hooking is enabled."""
        if not self.hook_signals:
            return

        if self.event_loop is None:
            replaced = signal.signal(sig, self._handle_signal)
            self.replaced_signals[sig] = replaced
        else:
            # Event loop callbacks receive no arguments, hence no frame.
            self.event_loop.add_signal_handler(
                sig, lambda: self._handle_signal(sig, None)
            )

    def _maybe_remove_signal_handler(self, sig: int):
        """Unhook `sig` from `signal` or the event loop if hooking is enabled."""
        if not self.hook_signals:
            return

        if self.event_loop is None:
            # Only restore signals that were actually hooked by this manager.
            if sig in self.replaced_signals:
                self._restore_signal(sig, self.replaced_signals.pop(sig))
        else:
            self.event_loop.remove_signal_handler(sig)

    def add_signal_handler(self, sig: int, handler: SignalHandler):
        """
        Add a signal handler for the given signal.

        The signal is hooked the first time a handler is added for it
        (unless `hook_signals` is False).

        Parameters
        ----------
        sig: int
            The signal to handle.
        handler: SignalHandler
            The handler to call when the signal is received.
        """
        if sig not in self.signal_map:
            self.signal_map[sig] = set()
            self._maybe_add_signal_handler(sig)

        self.signal_map[sig].add(handler)

    def remove_signal_handler(self, sig: int, handler: SignalHandler):
        """
        Remove a signal handler for the given signal.

        Nothing happens if `sig` is not being handled or if `handler` was
        never added for it. When the last handler of a signal is removed,
        the signal is unhooked so that it is no longer intercepted.

        Parameters
        ----------
        sig: int
            The signal to handle.
        handler: SignalHandler
            The handler to remove.
        """
        handlers = self.signal_map.get(sig)
        if handlers is None:
            return

        handlers.discard(handler)
        if not handlers:
            # Nothing left to call: stop intercepting the signal, otherwise
            # it would be silently swallowed by an empty dispatcher.
            del self.signal_map[sig]
            self._maybe_remove_signal_handler(sig)

    def trigger_signal(self, sig: int, frame=None):
        """
        Trigger a signal handler for the given signal.

        Parameters
        ----------
        sig : int
            The signal to handle.
        frame : Optional[FrameType]
            The frame to pass to the signal handler.
            Only used in a synchronous context.
        """
        self._handle_signal(sig, frame)

metaflow/runner/subprocess_manager.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import threading
1010
from typing import Callable, Dict, Iterator, List, Optional, Tuple
1111

12+
from .signal_manager import SignalManager
13+
1214

1315
def kill_process_and_descendants(pid, termination_timeout):
1416
# TODO: there's a race condition that new descendants might
@@ -73,17 +75,17 @@ class SubprocessManager(object):
7375
CommandManager objects, each of which manages an individual subprocess.
7476
"""
7577

76-
def __init__(self):
78+
def __init__(self, signal_manager: SignalManager):
7779
self.commands: Dict[int, CommandManager] = {}
80+
self.signal_manager = signal_manager or SignalManager()
7881

79-
try:
80-
loop = asyncio.get_running_loop()
81-
loop.add_signal_handler(
82+
if self.signal_manager.event_loop is not None:
83+
self.signal_manager.add_signal_handler(
8284
signal.SIGINT,
83-
lambda: asyncio.create_task(self._async_handle_sigint()),
85+
lambda s, f: asyncio.create_task(self._async_handle_sigint()),
8486
)
85-
except RuntimeError:
86-
signal.signal(signal.SIGINT, self._handle_sigint)
87+
else:
88+
self.signal_manager.add_signal_handler(signal.SIGINT, self._handle_sigint)
8789

8890
async def _async_handle_sigint(self):
8991
pids = [
@@ -193,7 +195,8 @@ def get(self, pid: int) -> Optional["CommandManager"]:
193195
return self.commands.get(pid, None)
194196

195197
def cleanup(self) -> None:
196-
"""Clean up log files for all running subprocesses."""
198+
"""Clean up signal handler and log files for all running subprocesses."""
199+
self.signal_manager.remove_signal_handler(signal.SIGINT, self.signal_handler)
197200

198201
for v in self.commands.values():
199202
v.cleanup()

0 commit comments

Comments
 (0)