Skip to content

Commit d72bd90

Browse files
committed
thread pool info to remote representation
1 parent 9d20984 commit d72bd90

File tree

5 files changed

+83
-0
lines changed

5 files changed

+83
-0
lines changed

python_modules/dagster/dagster/_core/definitions/asset_graph.py

+10
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ def tags(self) -> Mapping[str, str]:
8585
def kinds(self) -> AbstractSet[str]:
8686
return self._spec.kinds or set()
8787

88+
@property
89+
def pools(self) -> Optional[set[str]]:
90+
if not self.assets_def.computation:
91+
return None
92+
return set(
93+
op_def.pool
94+
for op_def in self.assets_def.computation.node_def.iterate_op_defs()
95+
if op_def.pool
96+
)
97+
8898
@property
8999
def owners(self) -> Sequence[str]:
90100
return self._spec.owners

python_modules/dagster/dagster/_core/definitions/base_asset_graph.py

+4
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ def metadata(self) -> ArbitraryMetadataMapping: ...
135135
@abstractmethod
136136
def tags(self) -> Mapping[str, str]: ...
137137

138+
@property
139+
@abstractmethod
140+
def pools(self) -> Optional[set[str]]: ...
141+
138142
@property
139143
@abstractmethod
140144
def owners(self) -> Sequence[str]: ...

python_modules/dagster/dagster/_core/definitions/remote_asset_graph.py

+15
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ def group_name(self) -> str:
8080
def metadata(self) -> ArbitraryMetadataMapping:
8181
return self.resolve_to_singular_repo_scoped_node().asset_node_snap.metadata
8282

83+
@property
84+
def pools(self) -> Optional[set[str]]:
85+
return self.resolve_to_singular_repo_scoped_node().pools
86+
8387
@property
8488
def tags(self) -> Mapping[str, str]:
8589
return self.resolve_to_singular_repo_scoped_node().asset_node_snap.tags or {}
@@ -187,6 +191,10 @@ def auto_materialize_policy(self) -> Optional[AutoMaterializePolicy]:
187191
def auto_observe_interval_minutes(self) -> Optional[float]:
188192
return self.asset_node_snap.auto_observe_interval_minutes
189193

194+
@property
195+
def pools(self) -> Optional[set[str]]:
196+
return self.asset_node_snap.pools
197+
190198

191199
@whitelist_for_serdes
192200
@record
@@ -268,6 +276,13 @@ def is_external(self) -> bool:
268276
def is_executable(self) -> bool:
269277
return any(node.asset_node.is_executable for node in self.repo_scoped_asset_infos)
270278

279+
@cached_property
280+
def pools(self) -> Optional[set[str]]:
281+
pools = set()
282+
for info in self.repo_scoped_asset_infos:
283+
pools.update(info.asset_node.pools or set())
284+
return pools
285+
271286
@property
272287
def partition_mappings(self) -> Mapping[AssetKey, PartitionMapping]:
273288
if self.is_materializable:

python_modules/dagster/dagster/_core/remote_representation/external_data.py

+11
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,7 @@ class AssetNodeSnap(IHaveNew):
13951395
parent_edges: Sequence[AssetParentEdgeSnap]
13961396
child_edges: Sequence[AssetChildEdgeSnap]
13971397
execution_type: AssetExecutionType
1398+
pools: set[str]
13981399
compute_kind: Optional[str]
13991400
op_name: Optional[str]
14001401
op_names: Sequence[str]
@@ -1428,6 +1429,7 @@ def __new__(
14281429
parent_edges: Sequence[AssetParentEdgeSnap],
14291430
child_edges: Sequence[AssetChildEdgeSnap],
14301431
execution_type: Optional[AssetExecutionType] = None,
1432+
pools: Optional[set[str]] = None,
14311433
compute_kind: Optional[str] = None,
14321434
op_name: Optional[str] = None,
14331435
op_names: Optional[Sequence[str]] = None,
@@ -1503,6 +1505,7 @@ def __new__(
15031505
parent_edges=parent_edges or [],
15041506
child_edges=child_edges or [],
15051507
compute_kind=compute_kind,
1508+
pools=pools or set(),
15061509
op_name=op_name,
15071510
op_names=op_names or [],
15081511
code_version=code_version,
@@ -1663,6 +1666,12 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode
16631666
graph_name = (
16641667
root_node_handle.name if root_node_handle != output_handle.node_handle else None
16651668
)
1669+
op_defs = [
1670+
cast(OpDefinition, job_def.graph.get_node(node_handle).definition)
1671+
for node_handle in node_handles
1672+
if isinstance(job_def.graph.get_node(node_handle).definition, OpDefinition)
1673+
]
1674+
pools = {op_def.pool for op_def in op_defs if op_def.pool}
16661675
op_names = sorted([str(handle) for handle in node_handles])
16671676
op_name = graph_name or next(iter(op_names), None) or node_def.name
16681677
job_names = sorted([jd.name for jd in job_defs_by_asset_key[key]])
@@ -1680,6 +1689,7 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode
16801689

16811690
else:
16821691
graph_name = None
1692+
pools = set()
16831693
op_names = []
16841694
op_name = None
16851695
job_names = []
@@ -1718,6 +1728,7 @@ def asset_node_snaps_from_repo(repo: RepositoryDefinition) -> Sequence[AssetNode
17181728
],
17191729
execution_type=asset_node.execution_type,
17201730
compute_kind=compute_kind,
1731+
pools=pools,
17211732
op_name=op_name,
17221733
op_names=op_names,
17231734
code_version=asset_node.code_version,

python_modules/dagster/dagster_tests/asset_defs_tests/test_external_asset_graph.py

+43
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dagster import (
99
AssetIn,
1010
AssetKey,
11+
AssetSpec,
1112
DagsterInstance,
1213
DailyPartitionsDefinition,
1314
Definitions,
@@ -16,6 +17,9 @@
1617
StaticPartitionMapping,
1718
StaticPartitionsDefinition,
1819
asset,
20+
graph_asset,
21+
multi_asset,
22+
op,
1923
)
2024
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
2125
from dagster._core.definitions.backfill_policy import BackfillPolicy
@@ -364,3 +368,42 @@ def test_dup_node_detection(instance):
364368
_ = _make_context(
365369
instance, ["dup_observation_defs_a", "dup_observation_defs_b"]
366370
).asset_graph
371+
372+
373+
@asset(pool="foo")
374+
def my_asset():
375+
pass
376+
377+
378+
@op(pool="bar")
379+
def my_op():
380+
pass
381+
382+
383+
@graph_asset
384+
def my_graph_asset():
385+
return my_op()
386+
387+
388+
@multi_asset(
389+
specs=[
390+
AssetSpec("multi_asset_1"),
391+
AssetSpec("multi_asset_2"),
392+
],
393+
pool="baz",
394+
)
395+
def my_multi_asset():
396+
pass
397+
398+
399+
concurrency_assets = Definitions(assets=[my_asset, my_graph_asset, my_multi_asset])
400+
401+
402+
def test_pool_snap(instance) -> None:
403+
context = _make_context(instance, ["concurrency_assets"])
404+
asset_graph = context.asset_graph
405+
assert asset_graph
406+
assert asset_graph.get(AssetKey("my_asset")).pools == {"foo"}
407+
assert asset_graph.get(AssetKey("my_graph_asset")).pools == {"bar"}
408+
assert asset_graph.get(AssetKey("multi_asset_1")).pools == {"baz"}
409+
assert asset_graph.get(AssetKey("multi_asset_2")).pools == {"baz"}

0 commit comments

Comments
 (0)