Skip to content

Commit 50fd3ff

Browse files
authored
Add NannyPlugins (#5118)
This is like WorkerPlugin, but allows for code to run before the Worker starts up. Unfortunately this requires the Nanny to check in with the scheduler before starting the Worker. In principle this should be fast, but it does delay the common case for the uncommon case. This PR includes an Environ nanny-plugin. If we go with this I think that we should move over PipInstall. We might also move over UploadFile and make a new UploadDirectory
1 parent cf1e412 commit 50fd3ff

File tree

8 files changed

+316
-23
lines changed

8 files changed

+316
-23
lines changed

distributed/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,15 @@
2121
)
2222
from .core import Status, connect, rpc
2323
from .deploy import Adaptive, LocalCluster, SpecCluster, SSHCluster
24-
from .diagnostics.plugin import PipInstall, SchedulerPlugin, WorkerPlugin
24+
from .diagnostics.plugin import (
25+
Environ,
26+
NannyPlugin,
27+
PipInstall,
28+
SchedulerPlugin,
29+
UploadDirectory,
30+
UploadFile,
31+
WorkerPlugin,
32+
)
2533
from .diagnostics.progressbar import progress
2634
from .event import Event
2735
from .lock import Lock

distributed/client.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@
5959
connect,
6060
rpc,
6161
)
62-
from .diagnostics.plugin import UploadFile, WorkerPlugin, _get_worker_plugin_name
62+
from .diagnostics.plugin import (
63+
NannyPlugin,
64+
UploadFile,
65+
WorkerPlugin,
66+
_get_worker_plugin_name,
67+
)
6368
from .metrics import time
6469
from .objects import HasWhat, SchedulerInfo, WhoHas
6570
from .protocol import to_serialize
@@ -4089,18 +4094,21 @@ def register_worker_callbacks(self, setup=None):
40894094
"""
40904095
return self.register_worker_plugin(_WorkerSetupPlugin(setup))
40914096

4092-
async def _register_worker_plugin(self, plugin=None, name=None):
4093-
responses = await self.scheduler.register_worker_plugin(
4094-
plugin=dumps(plugin, protocol=4), name=name
4095-
)
4097+
async def _register_worker_plugin(self, plugin=None, name=None, nanny=None):
4098+
if nanny or nanny is None and isinstance(plugin, NannyPlugin):
4099+
method = self.scheduler.register_nanny_plugin
4100+
else:
4101+
method = self.scheduler.register_worker_plugin
4102+
4103+
responses = await method(plugin=dumps(plugin, protocol=4), name=name)
40964104
for response in responses.values():
40974105
if response["status"] == "error":
40984106
exc = response["exception"]
40994107
tb = response["traceback"]
41004108
raise exc.with_traceback(tb)
41014109
return responses
41024110

4103-
def register_worker_plugin(self, plugin=None, name=None, **kwargs):
4111+
def register_worker_plugin(self, plugin=None, name=None, nanny=None, **kwargs):
41044112
"""
41054113
Registers a lifecycle worker plugin for all current and future workers.
41064114
@@ -4124,12 +4132,14 @@ def register_worker_plugin(self, plugin=None, name=None, **kwargs):
41244132
41254133
Parameters
41264134
----------
4127-
plugin : WorkerPlugin
4128-
The plugin object to pass to the workers
4135+
plugin : WorkerPlugin or NannyPlugin
4136+
The plugin object to register.
41294137
name : str, optional
41304138
A name for the plugin.
41314139
Registering a plugin with the same name will have no effect.
41324140
If plugin has no name attribute a random name is used.
4141+
nanny : bool, optional
4142+
Whether to register the plugin with workers or nannies.
41334143
**kwargs : optional
41344144
If you pass a class as the plugin, instead of a class instance, then the
41354145
class will be instantiated with any extra keyword arguments.
@@ -4174,10 +4184,15 @@ class will be instantiated with any extra keyword arguments.
41744184

41754185
assert name
41764186

4177-
return self.sync(self._register_worker_plugin, plugin=plugin, name=name)
4187+
return self.sync(
4188+
self._register_worker_plugin, plugin=plugin, name=name, nanny=nanny
4189+
)
41784190

4179-
async def _unregister_worker_plugin(self, name):
4180-
responses = await self.scheduler.unregister_worker_plugin(name=name)
4191+
async def _unregister_worker_plugin(self, name, nanny=None):
4192+
if nanny:
4193+
responses = await self.scheduler.unregister_nanny_plugin(name=name)
4194+
else:
4195+
responses = await self.scheduler.unregister_worker_plugin(name=name)
41814196

41824197
for response in responses.values():
41834198
if response["status"] == "error":
@@ -4186,7 +4201,7 @@ async def _unregister_worker_plugin(self, name):
41864201
raise exc.with_traceback(tb)
41874202
return responses
41884203

4189-
def unregister_worker_plugin(self, name):
4204+
def unregister_worker_plugin(self, name, nanny=None):
41904205
"""Unregisters a lifecycle worker plugin
41914206
41924207
This unregisters an existing worker plugin. As part of the unregistration process
@@ -4220,7 +4235,7 @@ def unregister_worker_plugin(self, name):
42204235
--------
42214236
register_worker_plugin
42224237
"""
4223-
return self.sync(self._unregister_worker_plugin, name=name)
4238+
return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny)
42244239

42254240

42264241
class _WorkerSetupPlugin(WorkerPlugin):

distributed/diagnostics/plugin.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import subprocess
55
import sys
66
import uuid
7+
import zipfile
78

8-
from dask.utils import funcname
9+
from dask.utils import funcname, tmpfile
910

1011
logger = logging.getLogger(__name__)
1112

@@ -175,6 +176,41 @@ def release_key(self, key, state, cause, reason, report):
175176
"""
176177

177178

179+
class NannyPlugin:
180+
"""Interface to extend the Nanny
181+
182+
A worker plugin enables custom code to run at different stages of the Workers'
183+
lifecycle. A nanny plugin does the same thing, but benefits from being able
184+
to run code before the worker is started, or to restart the worker if
185+
necessary.
186+
187+
To implement a plugin implement some of the methods of this class and register
188+
the plugin to your client in order to have it attached to every existing and
189+
future nanny by passing ``nanny=True`` to
190+
:meth:`Client.register_worker_plugin<distributed.Client.register_worker_plugin>`.
191+
192+
The ``restart`` attribute is used to control whether or not a running ``Worker``
193+
needs to be restarted when registering the plugin.
194+
195+
See Also
196+
--------
197+
WorkerPlugin
198+
SchedulerPlugin
199+
"""
200+
201+
restart = False
202+
203+
def setup(self, nanny):
204+
"""
205+
Run when the plugin is attached to a nanny. This happens when the plugin is registered
206+
and attached to existing nannies, or when a nanny is created after the plugin has been
207+
registered.
208+
"""
209+
210+
def teardown(self, nanny):
211+
"""Run when the nanny to which the plugin is attached to is closed"""
212+
213+
178214
def _get_worker_plugin_name(plugin) -> str:
179215
"""Returns the worker plugin name. If plugin has no name attribute
180216
a random name is used."""
@@ -289,3 +325,82 @@ async def setup(self, worker):
289325
comm=None, filename=self.filename, data=self.data, load=True
290326
)
291327
assert len(self.data) == response["nbytes"]
328+
329+
330+
class Environ(NannyPlugin):
331+
restart = True
332+
333+
def __init__(self, environ={}):
334+
self.environ = {k: str(v) for k, v in environ.items()}
335+
336+
async def setup(self, nanny):
337+
nanny.env.update(self.environ)
338+
339+
340+
class UploadDirectory(NannyPlugin):
341+
"""A NannyPlugin to upload a local file to workers.
342+
343+
Parameters
344+
----------
345+
path: str
346+
A path to the directory to upload
347+
348+
Examples
349+
--------
350+
>>> from distributed.diagnostics.plugin import UploadDirectory
351+
>>> client.register_worker_plugin(UploadDirectory("/path/to/directory"), nanny=True) # doctest: +SKIP
352+
"""
353+
354+
def __init__(
355+
self,
356+
path,
357+
restart=False,
358+
update_path=False,
359+
skip_words=(".git", ".github", ".pytest_cache", "tests", "docs"),
360+
skip=(lambda fn: os.path.splitext(fn)[1] == ".pyc",),
361+
):
362+
"""
363+
Initialize the plugin by reading in the data from the given file.
364+
"""
365+
path = os.path.expanduser(path)
366+
self.path = os.path.split(path)[-1]
367+
self.restart = restart
368+
self.update_path = update_path
369+
370+
self.name = "upload-directory-" + os.path.split(path)[-1]
371+
372+
with tmpfile(extension="zip") as fn:
373+
with zipfile.ZipFile(fn, "w", zipfile.ZIP_DEFLATED) as z:
374+
for root, dirs, files in os.walk(path):
375+
for file in files:
376+
filename = os.path.join(root, file)
377+
if any(predicate(filename) for predicate in skip):
378+
continue
379+
dirs = filename.split(os.sep)
380+
if any(word in dirs for word in skip_words):
381+
continue
382+
383+
archive_name = os.path.relpath(
384+
os.path.join(root, file), os.path.join(path, "..")
385+
)
386+
z.write(filename, archive_name)
387+
388+
with open(fn, "rb") as f:
389+
self.data = f.read()
390+
391+
async def setup(self, nanny):
392+
fn = os.path.join(nanny.local_directory, f"tmp-{str(uuid.uuid4())}.zip")
393+
with open(fn, "wb") as f:
394+
f.write(self.data)
395+
396+
import zipfile
397+
398+
with zipfile.ZipFile(fn) as z:
399+
z.extractall(path=nanny.local_directory)
400+
401+
if self.update_path:
402+
path = os.path.join(nanny.local_directory, self.path)
403+
if path not in sys.path:
404+
sys.path.insert(0, path)
405+
406+
os.remove(fn)

distributed/nanny.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import warnings
99
import weakref
1010
from contextlib import suppress
11+
from inspect import isawaitable
1112
from multiprocessing.queues import Empty
1213
from time import sleep as sync_sleep
1314

@@ -22,15 +23,18 @@
2223
from . import preloading
2324
from .comm import get_address_host, unparse_host_port
2425
from .comm.addressing import address_from_user_args
25-
from .core import CommClosedError, RPCClosed, Status, coerce_to_address
26+
from .core import CommClosedError, RPCClosed, Status, coerce_to_address, error_message
27+
from .diagnostics.plugin import _get_worker_plugin_name
2628
from .node import ServerNode
2729
from .process import AsyncProcess
2830
from .proctitle import enable_proctitle_on_children
31+
from .protocol import pickle
2932
from .security import Security
3033
from .utils import (
3134
TimeoutError,
3235
get_ip,
3336
json_load_robust,
37+
log_errors,
3438
mp_context,
3539
parse_ports,
3640
silence_logging,
@@ -110,14 +114,14 @@ def __init__(
110114

111115
if local_directory is None:
112116
local_directory = dask.config.get("temporary-directory") or os.getcwd()
113-
if not os.path.exists(local_directory):
114-
os.makedirs(local_directory)
115117
self._original_local_dir = local_directory
116118
local_directory = os.path.join(local_directory, "dask-worker-space")
117119
else:
118120
self._original_local_dir = local_directory
119121

120122
self.local_directory = local_directory
123+
if not os.path.exists(self.local_directory):
124+
os.makedirs(self.local_directory, exist_ok=True)
121125

122126
self.preload = preload
123127
if self.preload is None:
@@ -205,8 +209,12 @@ def __init__(
205209
"terminate": self.close,
206210
"close_gracefully": self.close_gracefully,
207211
"run": self.run,
212+
"plugin_add": self.plugin_add,
213+
"plugin_remove": self.plugin_remove,
208214
}
209215

216+
self.plugins = {}
217+
210218
super().__init__(
211219
handlers=handlers, io_loop=self.loop, connection_args=self.connection_args
212220
)
@@ -300,6 +308,10 @@ async def start(self):
300308
for preload in self.preloads:
301309
await preload.start()
302310

311+
msg = await self.scheduler.register_nanny()
312+
for name, plugin in msg["nanny-plugins"].items():
313+
await self.plugin_add(plugin=plugin, name=name)
314+
303315
logger.info(" Start Nanny at: %r", self.address)
304316
response = await self.instantiate()
305317
if response == Status.running:
@@ -390,6 +402,47 @@ async def instantiate(self, comm=None) -> Status:
390402
raise
391403
return result
392404

405+
async def plugin_add(self, comm=None, plugin=None, name=None):
406+
with log_errors(pdb=False):
407+
if isinstance(plugin, bytes):
408+
plugin = pickle.loads(plugin)
409+
410+
if name is None:
411+
name = _get_worker_plugin_name(plugin)
412+
413+
assert name
414+
415+
self.plugins[name] = plugin
416+
417+
logger.info("Starting Nanny plugin %s" % name)
418+
if hasattr(plugin, "setup"):
419+
try:
420+
result = plugin.setup(nanny=self)
421+
if isawaitable(result):
422+
result = await result
423+
except Exception as e:
424+
msg = error_message(e)
425+
return msg
426+
if getattr(plugin, "restart", False):
427+
await self.restart()
428+
429+
return {"status": "OK"}
430+
431+
async def plugin_remove(self, comm=None, name=None):
432+
with log_errors(pdb=False):
433+
logger.info(f"Removing Nanny plugin {name}")
434+
try:
435+
plugin = self.plugins.pop(name)
436+
if hasattr(plugin, "teardown"):
437+
result = plugin.teardown(nanny=self)
438+
if isawaitable(result):
439+
result = await result
440+
except Exception as e:
441+
msg = error_message(e)
442+
return msg
443+
444+
return {"status": "OK"}
445+
393446
async def restart(self, comm=None, timeout=30, executor_wait=True):
394447
async def _():
395448
if self.process is not None:
@@ -514,6 +567,14 @@ async def close(self, comm=None, timeout=5, report=None):
514567
for preload in self.preloads:
515568
await preload.teardown()
516569

570+
teardowns = [
571+
plugin.teardown(self)
572+
for plugin in self.plugins.values()
573+
if hasattr(plugin, "teardown")
574+
]
575+
576+
await asyncio.gather(*[td for td in teardowns if isawaitable(td)])
577+
517578
self.stop()
518579
try:
519580
if self.process is not None:

0 commit comments

Comments
 (0)