From 76d63b1f3d6deaf6737cee07d97d426ed52bfab4 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Mon, 27 Jan 2025 14:54:54 +0000 Subject: [PATCH] Add keyword to allow disabling config forwarding in SSHCluster --- distributed/deploy/ssh.py | 63 ++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 481745d139e..40b545782a2 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -74,6 +74,7 @@ def __init__( # type: ignore[no-untyped-def] worker_module="deprecated", worker_class="distributed.Nanny", remote_python=None, + forward_config=True, loop=None, name=None, ): @@ -92,6 +93,7 @@ def __init__( # type: ignore[no-untyped-def] self.kwargs = copy.copy(kwargs) self.name = name self.remote_python = remote_python + self.forward_config = forward_config if kwargs.get("nprocs") is not None and kwargs.get("n_workers") is not None: raise ValueError( "Both nprocs and n_workers were specified. Use n_workers only." @@ -135,21 +137,24 @@ async def start(self): self.connection = await asyncssh.connect(self.address, **self.connect_options) - result = await self.connection.run("uname") - if result.exit_status == 0: - set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format( - dask.config.serialize(dask.config.global_config) - ) - else: - result = await self.connection.run("cmd /c ver") + if self.forward_config: + result = await self.connection.run("uname") if result.exit_status == 0: - set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format( + set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format( dask.config.serialize(dask.config.global_config) ) else: - raise Exception( - "Worker failed to set DASK_INTERNAL_INHERIT_CONFIG variable " - ) + result = await self.connection.run("cmd /c ver") + if result.exit_status == 0: + set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format( + dask.config.serialize(dask.config.global_config) + ) + else: + raise Exception( + "Worker failed to set DASK_INTERNAL_INHERIT_CONFIG variable " + ) + else: + set_env = "" if not self.remote_python: self.remote_python = sys.executable @@ -175,7 +180,7 @@ async def start(self): } ), ] - ) + ).strip() self.proc = await self.connection.create_process(cmd) @@ -214,6 +219,7 @@ def __init__( connect_options: dict, kwargs: dict, remote_python: str | None = None, + forward_config: bool = True, ): super().__init__() @@ -221,6 +227,7 @@ def __init__( self.kwargs = kwargs self.connect_options = connect_options self.remote_python = remote_python or sys.executable + self.forward_config = forward_config async def start(self): try: @@ -235,21 +242,24 @@ async def start(self): self.connection = await asyncssh.connect(self.address, **self.connect_options) - result = await self.connection.run("uname") - if result.exit_status == 0: - set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format( - dask.config.serialize(dask.config.global_config) - ) - else: - result = await self.connection.run("cmd /c ver") + if self.forward_config: + result = await self.connection.run("uname") if result.exit_status == 0: - set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format( + set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format( dask.config.serialize(dask.config.global_config) ) else: - raise Exception( - "Scheduler failed to set DASK_INTERNAL_INHERIT_CONFIG variable " - ) + result = await self.connection.run("cmd /c ver") + if result.exit_status == 0: + set_env = "set DASK_INTERNAL_INHERIT_CONFIG={} &&".format( + dask.config.serialize(dask.config.global_config) + ) + else: + raise Exception( + "Scheduler failed to set DASK_INTERNAL_INHERIT_CONFIG variable " + ) + else: + set_env = "" cmd = " ".join( [ @@ -260,7 +270,7 @@ async def start(self): "--spec", "'%s'" % dumps({"cls": "distributed.Scheduler", "opts": self.kwargs}), ] - ) + ).strip() self.proc = await self.connection.create_process(cmd) # We watch stderr in order to get the address, then we return @@ -304,6 +314,7 @@ def SSHCluster( worker_module: str = "deprecated", worker_class: str = "distributed.Nanny", remote_python: str | list[str] | None = None, + forward_config: bool = True, **kwargs: Any, ) -> SpecCluster: """Deploy a Dask cluster using SSH @@ -344,6 +355,8 @@ def SSHCluster( The python class to use to create the worker(s). remote_python Path to Python on remote nodes. + forward_config + Forward the local Dask configuration to the remote nodes. Examples -------- @@ -443,6 +456,7 @@ def SSHCluster( "remote_python": ( remote_python[0] if isinstance(remote_python, list) else remote_python ), + "forward_config": forward_config, }, } workers = { @@ -462,6 +476,7 @@ def SSHCluster( if isinstance(remote_python, list) else remote_python ), + "forward_config": forward_config, }, } for i, host in enumerate(hosts[1:])