Commit 4e2eb36

add support for run-granularity op concurrency

1 parent 3efd17a commit 4e2eb36

File tree

10 files changed: +259 −71 lines changed

python_modules/dagster/dagster/_core/instance/config.py (+7)
@@ -158,6 +158,13 @@ def validate_concurrency_config(dagster_config_dict: Mapping[str, Any]):
             [],
             None,
         )
+    granularity = concurrency_config.get("pools", {}).get("granularity")
+    if granularity and granularity not in ["run", "op"]:
+        raise DagsterInvalidConfigError(
+            f"Found value `{granularity}` for `granularity`. Expected value 'run' or 'op'.",
+            [],
+            None,
+        )
 
     default_concurrency_limit = check.opt_inst(
         pluck_config_value(concurrency_config, ["pools", "default_limit"]), int
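
A minimal sketch of what the new validation accepts and rejects, assuming the instance config nests pools under a top-level `concurrency` key (the surrounding dict shape is inferred, not shown in this hunk):

    # hypothetical config dicts passed to validate_concurrency_config
    valid = {"concurrency": {"pools": {"granularity": "run", "default_limit": 10}}}
    invalid = {"concurrency": {"pools": {"granularity": "job"}}}

    validate_concurrency_config(valid)    # passes: "run" is an allowed value
    validate_concurrency_config(invalid)  # raises DagsterInvalidConfigError
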
@@ -1,18 +1,23 @@
 import os
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from dagster._core.instance import DagsterInstance
+from dagster._core.run_coordinator.queued_run_coordinator import PoolGranularity
 from dagster._core.snap.execution_plan_snapshot import ExecutionPlanSnapshot
 from dagster._core.storage.dagster_run import (
+    IN_PROGRESS_RUN_STATUSES,
     DagsterRun,
     DagsterRunStatus,
     RunOpConcurrency,
     RunRecord,
 )
 from dagster._time import get_current_timestamp
 
+if TYPE_CHECKING:
+    from dagster._utils.concurrency import ConcurrencyKeyInfo
+
 
 def compute_run_op_concurrency_info_for_snapshot(
     plan_snapshot: ExecutionPlanSnapshot,
@@ -23,21 +28,26 @@ def compute_run_op_concurrency_info_for_snapshot(
     root_step_keys = set(
         [step_key for step_key, deps in plan_snapshot.step_deps.items() if not deps]
     )
-    pool_counts: Mapping[str, int] = defaultdict(int)
+    root_pool_counts: Mapping[str, int] = defaultdict(int)
+    all_pools: set[str] = set()
     has_unconstrained_root_nodes = False
     for step in plan_snapshot.steps:
-        if step.key not in root_step_keys:
-            continue
-        if step.pool is None:
+        if step.pool is None and step.key in root_step_keys:
             has_unconstrained_root_nodes = True
+        elif step.pool is None:
+            continue
+        elif step.key in root_step_keys:
+            root_pool_counts[step.pool] += 1
+            all_pools.add(step.pool)
         else:
-            pool_counts[step.pool] += 1
+            all_pools.add(step.pool)
 
-    if len(pool_counts) == 0:
+    if len(all_pools) == 0:
         return None
 
     return RunOpConcurrency(
-        root_key_counts=dict(pool_counts),
+        all_pools=all_pools,
+        root_key_counts=dict(root_pool_counts),
         has_unconstrained_root_nodes=has_unconstrained_root_nodes,
     )
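
To make the new bookkeeping concrete, here is a sketch of how the loop above classifies a toy plan (the step keys, pool names, and SimpleNamespace stand-ins for the snapshot objects are hypothetical):

    from types import SimpleNamespace

    steps = [
        SimpleNamespace(key="extract", pool="db"),      # root step with a pool
        SimpleNamespace(key="transform", pool=None),    # root step, unconstrained
        SimpleNamespace(key="load", pool="warehouse"),  # downstream step with a pool
    ]
    step_deps = {"extract": [], "transform": [], "load": ["extract"]}
    root_step_keys = {k for k, deps in step_deps.items() if not deps}

    # walking the loop yields:
    #   root_pool_counts             -> {"db": 1}            (pools counted per root step)
    #   all_pools                    -> {"db", "warehouse"}  (every pool used by the run)
    #   has_unconstrained_root_nodes -> True                 (because of "transform")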

@@ -49,12 +59,14 @@ def __init__(
         runs: Sequence[DagsterRun],
         in_progress_run_records: Sequence[RunRecord],
         slot_count_offset: int = 0,
+        pool_granularity: Optional[PoolGranularity] = None,
     ):
         self._root_pools_by_run = {}
-        self._concurrency_info_by_pool = {}
+        self._concurrency_info_by_key: dict[str, ConcurrencyKeyInfo] = {}
         self._launched_pool_counts = defaultdict(int)
         self._in_progress_pool_counts = defaultdict(int)
         self._slot_count_offset = slot_count_offset
+        self._pool_granularity = pool_granularity if pool_granularity else PoolGranularity.RUN
         self._in_progress_run_ids: set[str] = set(
             [record.dagster_run.run_id for record in in_progress_run_records]
         )
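
PoolGranularity itself is imported from queued_run_coordinator and not shown in this diff; given that the config validation accepts "run" or "op" and that the debug info below reads self._pool_granularity.value, it is presumably a small string-valued enum, sketched here as an assumption:

    from enum import Enum

    class PoolGranularity(Enum):
        OP = "op"    # enforce pool limits per op/step
        RUN = "run"  # enforce pool limits per run

    # per __init__ above, a None pool_granularity falls back to PoolGranularity.RUN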
@@ -66,89 +78,152 @@ def __init__(
         # priority order
         self._fetch_concurrency_info(instance, runs)
 
-        # fetch all the outstanding concurrency keys for in-progress runs
+        # fetch all the outstanding pools for in-progress runs
        self._process_in_progress_runs(in_progress_run_records)
 
     def _fetch_concurrency_info(self, instance: DagsterInstance, queued_runs: Sequence[DagsterRun]):
-        # fetch all the concurrency slot information for the root concurrency keys of all the queued
-        # runs
-        all_run_pools = set()
+        # fetch all the concurrency slot information for all the queued runs
+        all_pools = set()
 
         configured_pools = instance.event_log_storage.get_concurrency_keys()
         for run in queued_runs:
             if run.run_op_concurrency:
-                all_run_pools.update(run.run_op_concurrency.root_key_counts.keys())
-
-        for key in all_run_pools:
-            if key is None:
+                # if using run granularity, consider all the concurrency keys required by the run
+                # if using op granularity, consider only the root keys
+                run_pools = (
+                    run.run_op_concurrency.root_key_counts.keys()
+                    if self._pool_granularity == PoolGranularity.OP
+                    else run.run_op_concurrency.all_pools or []
+                )
+                all_pools.update(run_pools)
+
+        for pool in all_pools:
+            if pool is None:
                 continue
 
-            if key not in configured_pools:
-                instance.event_log_storage.initialize_concurrency_limit_to_default(key)
+            if pool not in configured_pools:
+                instance.event_log_storage.initialize_concurrency_limit_to_default(pool)
 
-            self._concurrency_info_by_pool[key] = instance.event_log_storage.get_concurrency_info(
-                key
+            self._concurrency_info_by_key[pool] = instance.event_log_storage.get_concurrency_info(
+                pool
             )
 
-    def _should_allocate_slots_for_root_pools(self, record: RunRecord):
+    def _should_allocate_slots_for_in_progress_run(self, record: RunRecord):
+        if not record.dagster_run.run_op_concurrency:
+            return False
+
         status = record.dagster_run.status
+        if status not in IN_PROGRESS_RUN_STATUSES:
+            return False
+
+        if self._pool_granularity == PoolGranularity.RUN:
+            return True
+
         if status == DagsterRunStatus.STARTING:
             return True
+
         if status != DagsterRunStatus.STARTED or not record.start_time:
             return False
+
         time_elapsed = get_current_timestamp() - record.start_time
         if time_elapsed < self._started_run_pools_allotted_seconds:
             return True
 
+    def _slot_counts_for_run(self, run: DagsterRun) -> Mapping[str, int]:
+        if not run.run_op_concurrency:
+            return {}
+
+        if self._pool_granularity == PoolGranularity.OP:
+            return {**run.run_op_concurrency.root_key_counts}
+
+        else:
+            assert self._pool_granularity == PoolGranularity.RUN
+            return {pool: 1 for pool in run.run_op_concurrency.all_pools or []}
+
     def _process_in_progress_runs(self, in_progress_records: Sequence[RunRecord]):
         for record in in_progress_records:
-            if (
-                self._should_allocate_slots_for_root_pools(record)
-                and record.dagster_run.run_op_concurrency
-            ):
-                for (
-                    pool,
-                    count,
-                ) in record.dagster_run.run_op_concurrency.root_key_counts.items():
-                    self._in_progress_pool_counts[pool] += count
+            if not self._should_allocate_slots_for_in_progress_run(record):
+                continue
+
+            for pool, count in self._slot_counts_for_run(record.dagster_run).items():
+                self._in_progress_pool_counts[pool] += count
 
     def is_blocked(self, run: DagsterRun) -> bool:
         # if any of the ops in the run can make progress (not blocked by concurrency keys), we
         # should dequeue
-        if not run.run_op_concurrency or run.run_op_concurrency.has_unconstrained_root_nodes:
-            # if there exists a root node that is not concurrency blocked, we should dequeue.
+        if not run.run_op_concurrency:
             return False
 
-        for pool in run.run_op_concurrency.root_key_counts.keys():
-            if pool not in self._concurrency_info_by_pool:
-                # there is no concurrency limit set for this key, we should dequeue
-                return False
-
-            key_info = self._concurrency_info_by_pool[pool]
-            available_count = (
-                key_info.slot_count
-                - len(key_info.pending_steps)
-                - self._launched_pool_counts[pool]
-                - self._in_progress_pool_counts[pool]
-            )
-            if available_count > -1 * self._slot_count_offset:
-                # there exists a root concurrency key that is not blocked, we should dequeue
-                return False
+        if (
+            self._pool_granularity == PoolGranularity.OP
+            and run.run_op_concurrency.has_unconstrained_root_nodes
+        ):
+            # if the granularity is at the op level and there exists a root node that is not
+            # concurrency blocked, we should dequeue.
+            return False
 
-        # if we reached here, then every root concurrency key is blocked, so we should not dequeue
-        return True
+        if self._pool_granularity == PoolGranularity.OP:
+            # we just need to check all of the root concurrency keys, instead of all the
+            # concurrency keys in the run
+            for pool in run.run_op_concurrency.root_key_counts.keys():
+                if pool not in self._concurrency_info_by_key:
+                    # there is no concurrency limit set for this key, we should dequeue
+                    return False
+
+                key_info = self._concurrency_info_by_key[pool]
+                unaccounted_occupied_slots = [
+                    pending_step
+                    for pending_step in key_info.pending_steps
+                    if pending_step.run_id not in self._in_progress_run_ids
+                ]
+                available_count = (
+                    key_info.slot_count
+                    - len(unaccounted_occupied_slots)
+                    - self._launched_pool_counts[pool]
+                    - self._in_progress_pool_counts[pool]
+                )
+                if available_count + self._slot_count_offset > 0:
+                    # there exists a root concurrency key that is not blocked, we should dequeue
+                    return False
+
+            # if we reached here, then every root concurrency key is blocked, so we should not dequeue
+            return True
+
+        else:
+            assert self._pool_granularity == PoolGranularity.RUN
+
+            # if the granularity is at the run level, we should check if any of the concurrency
+            # keys are blocked
+            for pool in run.run_op_concurrency.all_pools or []:
+                if pool not in self._concurrency_info_by_key:
+                    # there is no concurrency limit set for this key
+                    continue
+
+                key_info = self._concurrency_info_by_key[pool]
+                available_count = (
+                    key_info.slot_count
+                    - self._launched_pool_counts[pool]
+                    - self._in_progress_pool_counts[pool]
+                )
+                if available_count + self._slot_count_offset <= 0:
+                    return True
+
+            # if we reached here, then there is at least one available slot for every concurrency
+            # key required by this run, so we should dequeue
+            return False
 
     def get_blocked_run_debug_info(self, run: DagsterRun) -> Mapping:
         if not run.run_op_concurrency:
             return {}
 
         log_info = {}
         for pool in run.run_op_concurrency.root_key_counts.keys():
-            concurrency_info = self._concurrency_info_by_pool.get(pool)
+            concurrency_info = self._concurrency_info_by_key.get(pool)
             if not concurrency_info:
                 continue
 
             log_info[pool] = {
+                "granularity": self._pool_granularity.value,
                 "slot_count": concurrency_info.slot_count,
                 "pending_step_count": len(concurrency_info.pending_steps),
                 "pending_step_run_ids": list(
@@ -160,8 +235,5 @@ def get_blocked_run_debug_info(self, run: DagsterRun) -> Mapping:
         return log_info
 
     def update_counters_with_launched_item(self, run: DagsterRun):
-        if not run.run_op_concurrency:
-            return
-        for pool, count in run.run_op_concurrency.root_key_counts.items():
-            if pool:
-                self._launched_pool_counts[pool] += count
+        for pool, count in self._slot_counts_for_run(run).items():
+            self._launched_pool_counts[pool] += count
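
Both the launched-run and in-progress counters now route through _slot_counts_for_run, so the same run charges its pools differently by granularity. A small illustration with hypothetical values:

    # given a run whose run_op_concurrency has:
    #   root_key_counts = {"db": 2}
    #   all_pools       = {"db", "warehouse"}
    #
    # op granularity  -> {"db": 2}                  (one slot per root step)
    # run granularity -> {"db": 1, "warehouse": 1}  (one slot per pool for the whole run)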

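A worked example of the dequeue arithmetic in is_blocked above, with hypothetical numbers for a single pool under run granularity:

    slot_count = 4         # configured pool limit
    launched = 2           # queued runs launched against the pool this iteration
    in_progress = 2        # in-progress runs already holding the pool
    slot_count_offset = 0

    available_count = slot_count - launched - in_progress  # 4 - 2 - 2 = 0
    blocked = available_count + slot_count_offset <= 0     # True: the run stays queued

    # under op granularity, pending steps whose run_id is already tracked in
    # self._in_progress_run_ids are filtered out first (unaccounted_occupied_slots),
    # so a slot is not double-counted via both pending_steps and the run counters.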