1
1
import os
2
2
from collections import defaultdict
3
3
from collections .abc import Mapping , Sequence
4
- from typing import Optional
4
+ from typing import TYPE_CHECKING , Optional
5
5
6
6
from dagster ._core .instance import DagsterInstance
7
+ from dagster ._core .run_coordinator .queued_run_coordinator import PoolGranularity
7
8
from dagster ._core .snap .execution_plan_snapshot import ExecutionPlanSnapshot
8
9
from dagster ._core .storage .dagster_run import (
10
+ IN_PROGRESS_RUN_STATUSES ,
9
11
DagsterRun ,
10
12
DagsterRunStatus ,
11
13
RunOpConcurrency ,
12
14
RunRecord ,
13
15
)
14
16
from dagster ._time import get_current_timestamp
15
17
18
+ if TYPE_CHECKING :
19
+ from dagster ._utils .concurrency import ConcurrencyKeyInfo
20
+
16
21
17
22
def compute_run_op_concurrency_info_for_snapshot (
18
23
plan_snapshot : ExecutionPlanSnapshot ,
@@ -23,21 +28,26 @@ def compute_run_op_concurrency_info_for_snapshot(
23
28
root_step_keys = set (
24
29
[step_key for step_key , deps in plan_snapshot .step_deps .items () if not deps ]
25
30
)
26
- pool_counts : Mapping [str , int ] = defaultdict (int )
31
+ root_pool_counts : Mapping [str , int ] = defaultdict (int )
32
+ all_pools : set [str ] = set ()
27
33
has_unconstrained_root_nodes = False
28
34
for step in plan_snapshot .steps :
29
- if step .key not in root_step_keys :
30
- continue
31
- if step .pool is None :
35
+ if step .pool is None and step .key in root_step_keys :
32
36
has_unconstrained_root_nodes = True
37
+ elif step .pool is None :
38
+ continue
39
+ elif step .key in root_step_keys :
40
+ root_pool_counts [step .pool ] += 1
41
+ all_pools .add (step .pool )
33
42
else :
34
- pool_counts [ step .pool ] += 1
43
+ all_pools . add ( step .pool )
35
44
36
- if len (pool_counts ) == 0 :
45
+ if len (all_pools ) == 0 :
37
46
return None
38
47
39
48
return RunOpConcurrency (
40
- root_key_counts = dict (pool_counts ),
49
+ all_pools = all_pools ,
50
+ root_key_counts = dict (root_pool_counts ),
41
51
has_unconstrained_root_nodes = has_unconstrained_root_nodes ,
42
52
)
43
53
@@ -49,12 +59,14 @@ def __init__(
49
59
runs : Sequence [DagsterRun ],
50
60
in_progress_run_records : Sequence [RunRecord ],
51
61
slot_count_offset : int = 0 ,
62
+ pool_granularity : Optional [PoolGranularity ] = None ,
52
63
):
53
64
self ._root_pools_by_run = {}
54
- self ._concurrency_info_by_pool = {}
65
+ self ._concurrency_info_by_key : dict [ str , ConcurrencyKeyInfo ] = {}
55
66
self ._launched_pool_counts = defaultdict (int )
56
67
self ._in_progress_pool_counts = defaultdict (int )
57
68
self ._slot_count_offset = slot_count_offset
69
+ self ._pool_granularity = pool_granularity if pool_granularity else PoolGranularity .RUN
58
70
self ._in_progress_run_ids : set [str ] = set (
59
71
[record .dagster_run .run_id for record in in_progress_run_records ]
60
72
)
@@ -66,89 +78,152 @@ def __init__(
66
78
# priority order
67
79
self ._fetch_concurrency_info (instance , runs )
68
80
69
- # fetch all the outstanding concurrency keys for in-progress runs
81
+ # fetch all the outstanding pools for in-progress runs
70
82
self ._process_in_progress_runs (in_progress_run_records )
71
83
72
84
def _fetch_concurrency_info(self, instance: DagsterInstance, queued_runs: Sequence[DagsterRun]):
    """Prime the concurrency-info cache for every pool touched by the queued runs."""
    # Collect the pools relevant to each queued run: at OP granularity only the
    # root-step pools matter; at any other granularity every pool in the run does.
    pools_to_fetch = set()
    for queued_run in queued_runs:
        concurrency = queued_run.run_op_concurrency
        if not concurrency:
            continue
        if self._pool_granularity == PoolGranularity.OP:
            pools_to_fetch.update(concurrency.root_key_counts.keys())
        else:
            pools_to_fetch.update(concurrency.all_pools or [])

    configured_pools = instance.event_log_storage.get_concurrency_keys()
    for pool in pools_to_fetch:
        if pool is None:
            continue

        # An unconfigured pool picks up the instance-wide default limit before
        # its slot info is read.
        if pool not in configured_pools:
            instance.event_log_storage.initialize_concurrency_limit_to_default(pool)

        self._concurrency_info_by_key[pool] = instance.event_log_storage.get_concurrency_info(
            pool
        )
def _should_allocate_slots_for_in_progress_run(self, record: RunRecord) -> bool:
    """Return whether an in-progress run should count against pool slot budgets.

    At RUN granularity any in-progress run carrying concurrency info holds its
    slots for its whole lifetime.  At OP granularity, slots are held only while
    the run is STARTING or for a grace window after it has STARTED.
    """
    if not record.dagster_run.run_op_concurrency:
        return False

    status = record.dagster_run.status
    if status not in IN_PROGRESS_RUN_STATUSES:
        return False

    if self._pool_granularity == PoolGranularity.RUN:
        # Run-granularity slots are held for the full lifetime of the run.
        return True

    if status == DagsterRunStatus.STARTING:
        return True

    if status != DagsterRunStatus.STARTED or not record.start_time:
        return False

    time_elapsed = get_current_timestamp() - record.start_time
    # Fix: previously this fell off the end and implicitly returned None once
    # the allotted window elapsed; return an explicit bool instead (truthiness
    # for callers is unchanged).
    return time_elapsed < self._started_run_pools_allotted_seconds
def _slot_counts_for_run(self, run: DagsterRun) -> Mapping[str, int]:
    """Map each pool to the number of slots this run occupies, per granularity."""
    concurrency = run.run_op_concurrency
    if not concurrency:
        return {}

    if self._pool_granularity == PoolGranularity.OP:
        # Op granularity: per-pool counts of the run's root steps.
        return dict(concurrency.root_key_counts)

    assert self._pool_granularity == PoolGranularity.RUN
    # Run granularity: exactly one slot per pool used anywhere in the run.
    return dict.fromkeys(concurrency.all_pools or [], 1)
def _process_in_progress_runs(self, in_progress_records: Sequence[RunRecord]):
    """Tally slot counts held by eligible in-progress runs."""
    eligible_runs = (
        record.dagster_run
        for record in in_progress_records
        if self._should_allocate_slots_for_in_progress_run(record)
    )
    for in_progress_run in eligible_runs:
        for pool, count in self._slot_counts_for_run(in_progress_run).items():
            self._in_progress_pool_counts[pool] += count
def is_blocked(self, run: DagsterRun) -> bool:
    """Return True when this queued run should NOT be dequeued because every
    pool it needs is at capacity (per the configured pool granularity)."""
    # if any of the ops in the run can make progress (not blocked by concurrency keys), we
    # should dequeue
    if not run.run_op_concurrency:
        return False

    if (
        self._pool_granularity == PoolGranularity.OP
        and run.run_op_concurrency.has_unconstrained_root_nodes
    ):
        # if the granularity is at the op level and there exists a root node that is not
        # concurrency blocked, we should dequeue.
        return False

    if self._pool_granularity == PoolGranularity.OP:
        # we just need to check all of the root concurrency keys, instead of all the concurrency keys
        # in the run
        for pool in run.run_op_concurrency.root_key_counts.keys():
            if pool not in self._concurrency_info_by_key:
                # there is no concurrency limit set for this key, we should dequeue
                return False

            key_info = self._concurrency_info_by_key[pool]
            # Pending steps belonging to runs already tallied in
            # self._in_progress_pool_counts are excluded here to avoid counting
            # the same occupancy twice.
            unaccounted_occupied_slots = [
                pending_step
                for pending_step in key_info.pending_steps
                if pending_step.run_id not in self._in_progress_run_ids
            ]
            available_count = (
                key_info.slot_count
                - len(unaccounted_occupied_slots)
                - self._launched_pool_counts[pool]
                - self._in_progress_pool_counts[pool]
            )
            # slot_count_offset allows intentional over-subscription of pools.
            if available_count + self._slot_count_offset > 0:
                # there exists a root concurrency key that is not blocked, we should dequeue
                return False

        # if we reached here, then every root concurrency key is blocked, so we should not dequeue
        return True

    else:
        assert self._pool_granularity == PoolGranularity.RUN

        # if the granularity is at the run level, we should check if any of the concurrency
        # keys are blocked
        for pool in run.run_op_concurrency.all_pools or []:
            if pool not in self._concurrency_info_by_key:
                # there is no concurrency limit set for this key
                continue

            key_info = self._concurrency_info_by_key[pool]
            # NOTE(review): unlike the OP branch, pending steps are not counted
            # here — run-level occupancy is tracked purely via the launched /
            # in-progress tallies.
            available_count = (
                key_info.slot_count
                - self._launched_pool_counts[pool]
                - self._in_progress_pool_counts[pool]
            )
            if available_count + self._slot_count_offset <= 0:
                return True

        # if we reached here then there is at least one available slot for every single concurrency key
        # required by this run, so we should dequeue
        return False
def get_blocked_run_debug_info (self , run : DagsterRun ) -> Mapping :
142
216
if not run .run_op_concurrency :
143
217
return {}
144
218
145
219
log_info = {}
146
220
for pool in run .run_op_concurrency .root_key_counts .keys ():
147
- concurrency_info = self ._concurrency_info_by_pool .get (pool )
221
+ concurrency_info = self ._concurrency_info_by_key .get (pool )
148
222
if not concurrency_info :
149
223
continue
150
224
151
225
log_info [pool ] = {
226
+ "granularity" : self ._pool_granularity .value ,
152
227
"slot_count" : concurrency_info .slot_count ,
153
228
"pending_step_count" : len (concurrency_info .pending_steps ),
154
229
"pending_step_run_ids" : list (
@@ -160,8 +235,5 @@ def get_blocked_run_debug_info(self, run: DagsterRun) -> Mapping:
160
235
return log_info
161
236
162
237
def update_counters_with_launched_item(self, run: DagsterRun):
    """Record slot usage for a run that has just been launched."""
    launched_counts = self._slot_counts_for_run(run)
    for pool in launched_counts:
        self._launched_pool_counts[pool] += launched_counts[pool]
0 commit comments