From 090db47da31bc76df3bdc526819f4e12fa59ea18 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 25 Mar 2025 12:03:36 +0100 Subject: [PATCH 1/3] Reduce dask.order overhead by removing stripped_dep computation --- distributed/scheduler.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ae64633c50..a7792eb39f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4867,19 +4867,7 @@ async def update_graph( dsk = _cull(dsk, keys) if not internal_priority: - # Removing all non-local keys before calling order() - dsk_keys = set( - dsk - ) # intersection() of sets is much faster than dict_keys - stripped_deps = { - k: v.intersection(dsk_keys) - for k, v in dependencies.items() - if k in dsk_keys - } - - internal_priority = await offload( - dask.order.order, dsk=dsk, dependencies=stripped_deps - ) + internal_priority = await offload(dask.order.order, dsk=dsk) ordering_done = time() logger.debug("Ordering done.") From c04f57cbf07065dc988ed26e8ffe78017803987e Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 25 Mar 2025 14:07:55 +0100 Subject: [PATCH 2/3] Finish scheduler type annotations --- distributed/client.py | 2 +- distributed/preloading.py | 4 +- distributed/scheduler.py | 217 +++++++++++++---------- distributed/tests/test_worker_metrics.py | 2 +- pyproject.toml | 1 + 5 files changed, 130 insertions(+), 96 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index e9102b3ce9..6ded7ac338 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -886,7 +886,7 @@ def keys(self) -> Iterable[Key]: def _meta(self): return [] - def _layer(self): + def _layer(self) -> dict[Key, GraphNode]: dsk: _T_LowLevelGraph = {} if not self.kwargs: diff --git a/distributed/preloading.py b/distributed/preloading.py index 96c3889327..a81a0b7660 100644 --- a/distributed/preloading.py +++ b/distributed/preloading.py @@ -251,8 +251,8 @@ def __len__(self) -> int: def process_preloads( dask_server: Server | Client, - preload: str | list[str], - preload_argv: list[str] | list[list[str]], + preload: str | Sequence[str], + preload_argv: Sequence[str] | Sequence[Sequence[str]], *, file_dir: str | None = None, ) -> PreloadManager: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a7792eb39f..bd336b7f57 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -140,10 +140,14 @@ if TYPE_CHECKING: # TODO import from typing (requires Python >=3.10) # TODO import from typing (requires Python >=3.11) + from typing import TypeVar + from typing_extensions import Self, TypeAlias from dask._expr import Expr + FuncT = TypeVar("FuncT", bound=Callable[..., Any]) + # Not to be confused with distributed.worker_state_machine.TaskStateState TaskStateState: TypeAlias = Literal[ "released", @@ -889,7 +893,7 @@ class Computation: __slots__ = tuple(__annotations__) - def __init__(self): + def __init__(self) -> None: self.start = time() self.groups = set() self.code = SortedSet() @@ -3674,32 +3678,44 @@ class Scheduler(SchedulerState, ServerNode): _workers_removed_total: int _active_graph_updates: int + _starting_nannies: set[str] + worker_plugins: dict[str, bytes] + nanny_plugins: dict[str, bytes] + + client_comms: dict[str, BatchedSend] + stream_comms: dict[str, BatchedSend] + + cumulative_worker_metrics: defaultdict[tuple | str, int] + bandwidth_types: defaultdict[str, float] + bandwidth_workers: defaultdict[tuple[str, str], float] + services: dict + def __init__( self, - 
loop=None, - services=None, - service_kwargs=None, - allowed_failures=None, - extensions=None, - validate=None, - scheduler_file=None, - security=None, - worker_ttl=None, - idle_timeout=None, - interface=None, - host=None, - port=0, - protocol=None, - dashboard_address=None, - dashboard=None, - http_prefix="/", - preload=None, - preload_argv=(), - plugins=(), - contact_address=None, - transition_counter_max=False, - jupyter=False, - **kwargs, + loop: IOLoop | None = None, + services: dict | None = None, + service_kwargs: dict | None = None, + allowed_failures: int | None = None, + extensions: dict | None = None, + validate: bool | None = None, + scheduler_file: str | None = None, + security: dict | Security | None = None, + worker_ttl: float | None = None, + idle_timeout: float | None = None, + interface: str | None = None, + host: str | None = None, + port: int = 0, + protocol: str | None = None, + dashboard_address: str | None = None, + dashboard: bool | None = None, + http_prefix: str | None = "/", + preload: str | Sequence[str] | None = None, + preload_argv: str | Sequence[str] | Sequence[Sequence[str]] = (), + plugins: Sequence[SchedulerPlugin] = (), + contact_address: str | None = None, + transition_counter_max: bool | int = False, + jupyter: bool = False, + **kwargs: Any, ): if dask.config.get("distributed.scheduler.pickle", default=True) is False: raise RuntimeError( @@ -3754,7 +3770,11 @@ def __init__( preload = dask.config.get("distributed.scheduler.preload") if not preload_argv: preload_argv = dask.config.get("distributed.scheduler.preload-argv") - self.preloads = preloading.process_preloads(self, preload, preload_argv) + self.preloads = preloading.process_preloads( + self, + preload, # type: ignore + preload_argv, + ) if isinstance(security, dict): security = Security(**security) @@ -3809,7 +3829,7 @@ def __init__( from jupyter_server.auth import authorized except ImportError: - def authorized(c): + def authorized(c: FuncT) -> FuncT: return c from jupyter_server.base.handlers import JupyterHandler @@ -3820,8 +3840,8 @@ class ShutdownHandler(JupyterHandler): auth_resource = "server" @tornado.web.authenticated - @authorized - async def post(self): + @authorized # type: ignore + async def post(self) -> None: """Shut down the server.""" self.log.info("Shutting down on /api/shutdown request.") @@ -3860,27 +3880,25 @@ async def post(self): self.stream_comms = {} # Task state - tasks = {} + tasks: dict[Key, TaskState] = {} self.generation = 0 self._last_client = None - self._last_time = 0 - unrunnable = {} + self._last_time = 0.0 + unrunnable: dict[TaskState, float] = {} queued = HeapSet(key=operator.attrgetter("priority")) - self.datasets = {} - # Prefix-keyed containers # Client state - clients = {} + clients: dict[str, ClientState] = {} # Worker state workers = SortedDict() - host_info = {} - resources = {} - aliases = {} + host_info: dict[str, dict[str, Any]] = {} + resources: dict[str, dict[str, float]] = {} + aliases: dict[Hashable, str] = {} self._worker_collections = [ workers, @@ -4013,7 +4031,7 @@ async def post(self): pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl * 1000) self.periodic_callbacks["worker-ttl"] = pc - pc = PeriodicCallback(self.check_idle, 250) + pc = PeriodicCallback(self.check_idle, 250) # type: ignore self.periodic_callbacks["idle-timeout"] = pc pc = PeriodicCallback(self._check_no_workers, 250) @@ -4251,19 +4269,17 @@ def del_scheduler_file() -> None: setproctitle(f"dask scheduler [{self.address}]") return self - async def close(self, 
fast=None, close_workers=None, reason="unknown"): + async def close( + self, + timeout: float | None = None, + reason: str = "unknown", + ) -> None: """Send cleanup signal to all coroutines then wait until finished See Also -------- Scheduler.cleanup """ - if fast is not None or close_workers is not None: - warnings.warn( - "The 'fast' and 'close_workers' parameters in Scheduler.close have no " - "effect and will be removed in a future version of distributed.", - FutureWarning, - ) if self.status in (Status.closing, Status.closed): await self.finished() return @@ -4272,7 +4288,7 @@ async def close(self, fast=None, close_workers=None, reason="unknown"): logger.info("Closing scheduler. Reason: %s", reason) setproctitle("dask scheduler [closing]") - async def log_errors(func): + async def log_errors(func: Callable) -> None: try: await func() except Exception: @@ -5295,14 +5311,14 @@ def stimulus_task_finished( def stimulus_task_erred( self, - key=None, - worker=None, - exception=None, - stimulus_id=None, - traceback=None, - run_id=None, - **kwargs, - ): + key: Key, + worker: str, + exception: Any, + stimulus_id: str, + traceback: Any, + run_id: str, + **kwargs: Any, + ) -> RecsMsgs: """Mark that a task has erred on a particular worker""" logger.debug("Stimulus task erred %s, %s", key, worker) @@ -5879,6 +5895,7 @@ def report( # Notify all clients client_keys = list(self.client_comms) elif ts is None: + assert client is not None client_keys = [client] else: # Notify clients interested in key (including `client`) @@ -6301,7 +6318,7 @@ def worker_send(self, worker: str, msg: dict[str, Any]) -> None: stimulus_id=f"worker-send-comm-fail-{time()}", ) - def client_send(self, client, msg): + def client_send(self, client: str, msg: dict) -> None: """Send message to client""" c = self.client_comms.get(client) if c is None: @@ -6352,13 +6369,12 @@ def send_all(self, client_msgs: Msgs, worker_msgs: Msgs) -> None: async def scatter( self, - comm=None, - data=None, - workers=None, - client=None, - broadcast=False, - timeout=2, - ): + data: dict, + workers: Iterable | None, + client: str, + broadcast: bool = False, + timeout: float = 2, + ) -> list[Key]: """Send data out to workers See also @@ -6704,7 +6720,7 @@ async def broadcast( ERROR = object() - async def send_message(addr): + async def send_message(addr: str) -> Any: try: comm = await self.rpc.connect(addr) comm.name = "Scheduler Broadcast" @@ -7198,14 +7214,13 @@ async def _rebalance_move_data( async def replicate( self, - comm=None, - keys=None, - n=None, - workers=None, - branching_factor=2, - delete=True, - stimulus_id=None, - ): + keys: list[Key], + n: int | None = None, + workers: Iterable | None = None, + branching_factor: int = 2, + delete: bool = True, + stimulus_id: str | None = None, + ) -> dict | None: """Replicate data throughout cluster This performs a tree copy of the data throughout the network @@ -7253,6 +7268,7 @@ async def replicate( if delete: del_worker_tasks = defaultdict(set) for ts in tasks: + assert ts.who_has is not None del_candidates = tuple(ts.who_has & workers) if len(del_candidates) > n: for ws in random.sample( @@ -7271,6 +7287,7 @@ async def replicate( ) # Copy not-yet-filled data + gathers: defaultdict[str, dict[Key, list[str]]] while tasks: gathers = defaultdict(dict) for ts in list(tasks): @@ -7312,6 +7329,7 @@ async def replicate( "branching-factor": branching_factor, }, ) + return None @log_errors def workers_to_close( @@ -7414,7 +7432,7 @@ def workers_to_close( limit = sum(limit_bytes.values()) total = 
sum(group_bytes.values()) - def _key(group): + def _key(group: str) -> tuple[bool, int]: is_idle = not any([wws.processing for wws in groups[group]]) bytes = -group_bytes[group] return is_idle, bytes @@ -7772,7 +7790,13 @@ def report_on_key(self, key: Key, *, client: str | None = None) -> None: ... @overload def report_on_key(self, *, ts: TaskState, client: str | None = None) -> None: ... - def report_on_key(self, key=None, *, ts=None, client=None): + def report_on_key( + self, + key: Key | None = None, + *, + ts: TaskState | None = None, + client: str | None = None, + ) -> None: if (ts is None) == (key is None): raise ValueError( # pragma: nocover f"ts and key are mutually exclusive; received {key=!r}, {ts=!r}" @@ -8339,22 +8363,25 @@ def workers_list(self, workers: Iterable[str] | None) -> list[str]: async def get_profile( self, - comm=None, - workers=None, - scheduler=False, - server=False, - merge_workers=True, - start=None, - stop=None, - key=None, - ): + workers: Iterable | None = None, + scheduler: bool = False, + server: bool = False, + merge_workers: bool = True, + start: float | None = None, + stop: float | None = None, + key: Key | None = None, + ) -> dict: if workers is None: workers = self.workers else: workers = set(self.workers) & set(workers) if scheduler: - return profile.get_profile(self.io_loop.profile, start=start, stop=stop) + return profile.get_profile( + self.io_loop.profile, # type: ignore[attr-defined] + start=start, + stop=stop, + ) results = await asyncio.gather( *( @@ -8366,8 +8393,9 @@ async def get_profile( results = [r for r in results if not isinstance(r, Exception)] + response: dict if merge_workers: - response = profile.merge(*results) + response = profile.merge(*results) # type: ignore else: response = dict(zip(workers, results)) return response @@ -8428,7 +8456,7 @@ async def performance_report( ) -> str: stop = time() # Profiles - compute, scheduler, workers = await asyncio.gather( + compute_d, scheduler_d, workers_d = await asyncio.gather( *[ self.get_profile(start=start), self.get_profile(scheduler=True, start=start), @@ -8437,14 +8465,15 @@ async def performance_report( ) from distributed import profile - def profile_to_figure(state): + def profile_to_figure(state: object) -> object: data = profile.plot_data(state) figure, source = profile.plot_figure(data, sizing_mode="stretch_both") return figure compute, scheduler, workers = map( - profile_to_figure, (compute, scheduler, workers) + profile_to_figure, (compute_d, scheduler_d, workers_d) ) + del compute_d, scheduler_d, workers_d # Task stream task_stream = self.get_task_stream(start=start) @@ -8580,7 +8609,9 @@ def profile_to_figure(state): return data - async def get_worker_logs(self, n=None, workers=None, nanny=False): + async def get_worker_logs( + self, n: int | None = None, workers: list | None = None, nanny: bool = False + ) -> dict: results = await self.broadcast( msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny ) @@ -8620,7 +8651,9 @@ def get_events( ) -> tuple[tuple[float, Any], ...] 
| dict[str, tuple[tuple[float, Any], ...]]: return self._broker.get_events(topic) - async def get_worker_monitor_info(self, recent=False, starts=None): + async def get_worker_monitor_info( + self, recent: bool = False, starts: dict | None = None + ) -> dict: if starts is None: starts = {} results = await asyncio.gather( @@ -8842,7 +8875,7 @@ def _refresh_no_workers_since(self, timestamp: float | None = None) -> None: self._no_workers_since = timestamp or monotonic() return - def adaptive_target(self, target_duration=None): + def adaptive_target(self, target_duration: float | None = None) -> int: """Desired number of workers based on the current workload This looks at the current running tasks and memory use, and returns a @@ -8864,7 +8897,7 @@ def adaptive_target(self, target_duration=None): # CPU queued = take(100, concat([self.queued, self.unrunnable.keys()])) - queued_occupancy = 0 + queued_occupancy = 0.0 for ts in queued: queued_occupancy += self._get_prefix_duration(ts.prefix) diff --git a/distributed/tests/test_worker_metrics.py b/distributed/tests/test_worker_metrics.py index 4623cce59d..d52f40b060 100644 --- a/distributed/tests/test_worker_metrics.py +++ b/distributed/tests/test_worker_metrics.py @@ -33,7 +33,7 @@ def get_digests( d = w.digests_total if isinstance(w, Worker) else w.cumulative_worker_metrics digests = { k: v - for k, v in d.items() + for k, v in d.items() # type: ignore if k not in {"latency", "tick-duration", "transfer-bandwidth", "transfer-duration"} and (any(a in k for a in allow) or not allow) diff --git a/pyproject.toml b/pyproject.toml index 3c5bd44721..64b270dfe1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -202,6 +202,7 @@ allow_incomplete_defs = true # Recent or recently overhauled modules featuring stricter validation module = [ "distributed.active_memory_manager", + "distributed.scheduler", "distributed.spans", "distributed.system_monitor", "distributed.worker_memory", From 2fd3de5cc92e57bf5cfb911a8159f96d1aaad541 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 25 Mar 2025 15:18:24 +0100 Subject: [PATCH 3/3] fix tests --- distributed/tests/test_scheduler.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a9f7d25db5..956c886f1f 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1465,7 +1465,7 @@ async def test_story(c, s, a, b): @gen_cluster(client=True, nthreads=[]) async def test_scatter_no_workers(c, s, direct): with pytest.raises(TimeoutError): - await s.scatter(data={"x": 1}, client="alice", timeout=0.1) + await s.scatter(data={"x": 1}, client="alice", timeout=0.1, workers=None) start = time() with pytest.raises(TimeoutError): @@ -4518,12 +4518,6 @@ async def test_worker_state_unique_regardless_of_address(s, w): assert hash(ws1) != ws2 -@gen_cluster(nthreads=[("", 1)]) -async def test_scheduler_close_fast_deprecated(s, w): - with pytest.warns(FutureWarning): - await s.close(fast=True) - - def test_runspec_regression_sync(loop): # https://github.com/dask/distributed/issues/6624 np = pytest.importorskip("numpy")
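
Note on [PATCH 1/3], as an illustration only and not part of the patches: the sketch below replays, on a toy graph, the dependency-stripping pass that the patch deletes from update_graph, together with the assumption the deletion relies on, namely that dask.order.order derives dependencies itself when none are passed. The names toy_dsk and toy_dependencies are invented for this example; the real code operates on the scheduler's culled dsk and its dependencies mapping.

# Toy graph: 'z' depends on 'x' and 'y'; 'w' stands in for a dependency that
# was culled away and is therefore no longer a key of the local graph.
toy_dsk = {"x": 1, "y": 2, "z": ("add", "x", "y")}
toy_dependencies = {"x": set(), "y": set(), "z": {"x", "y", "w"}}

# Old behaviour (removed by the patch): restrict every dependency set to keys
# actually present in the graph before handing it to dask.order.order.
dsk_keys = set(toy_dsk)  # intersection() of sets is much faster than dict_keys
stripped_deps = {
    k: v.intersection(dsk_keys)
    for k, v in toy_dependencies.items()
    if k in dsk_keys
}
assert stripped_deps["z"] == {"x", "y"}

# New behaviour: the scheduler skips this pass entirely and calls
# dask.order.order(dsk=dsk) with no precomputed dependencies argument,
# leaving it to dask.order to compute the dependencies it needs.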