From d54f10cf0edd7b8045cd7e11aeda4cf36f8d1eab Mon Sep 17 00:00:00 2001 From: gibsondan Date: Tue, 31 Dec 2024 09:56:52 -0600 Subject: [PATCH] New BFS method for use in asset backfills (#26751) Summary: Pulling out the new BFS utility added in https://github.com/dagster-io/dagster/pull/25997 to its own function with its own tests to keep the PR size manageable. Test Plan: New tests NOCHANGELOG ## Summary & Motivation ## How I Tested These Changes ## Changelog > Insert changelog entry or delete this section. --- .../asset_graph_view/asset_graph_view.py | 30 ++ .../dagster/_core/asset_graph_view/bfs.py | 187 ++++++++ .../_core/definitions/asset_graph_subset.py | 4 + .../asset_graph_view_tests/test_bfs.py | 431 ++++++++++++++++++ 4 files changed, 652 insertions(+) create mode 100644 python_modules/dagster/dagster/_core/asset_graph_view/bfs.py create mode 100644 python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_bfs.py diff --git a/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py b/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py index 4c2c6917a0d2c..fe537d4140fc1 100644 --- a/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py +++ b/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py @@ -6,6 +6,7 @@ Awaitable, Callable, Dict, + Iterable, Literal, NamedTuple, Optional, @@ -17,6 +18,7 @@ from dagster import _check as check from dagster._core.asset_graph_view.entity_subset import EntitySubset, _ValidatedEntitySubsetValue from dagster._core.asset_graph_view.serializable_entity_subset import SerializableEntitySubset +from dagster._core.definitions.asset_graph_subset import AssetGraphSubset from dagster._core.definitions.asset_key import AssetCheckKey, AssetKey, EntityKey, T_EntityKey from dagster._core.definitions.events import AssetKeyPartitionKey from dagster._core.definitions.multi_dimensional_partitions import ( @@ -193,6 +195,34 @@ def get_empty_subset(self, *, key: T_EntityKey) -> EntitySubset[T_EntityKey]: value = partitions_def.empty_subset() if partitions_def else False return EntitySubset(self, key=key, value=_ValidatedEntitySubsetValue(value)) + def get_entity_subset_from_asset_graph_subset( + self, asset_graph_subset: AssetGraphSubset, key: AssetKey + ) -> EntitySubset[AssetKey]: + check.invariant( + self.asset_graph.has(key), f"Asset graph does not contain {key.to_user_string()}" + ) + + serializable_subset = asset_graph_subset.get_asset_subset(key, self.asset_graph) + check.invariant( + serializable_subset.is_compatible_with_partitions_def( + self._get_partitions_def(key), + ), + f"Partitions definition for {key.to_user_string()} is not compatible with the passed in AssetGraphSubset", + ) + + return EntitySubset( + self, key=key, value=_ValidatedEntitySubsetValue(serializable_subset.value) + ) + + def iterate_asset_subsets( + self, asset_graph_subset: AssetGraphSubset + ) -> Iterable[EntitySubset[AssetKey]]: + """Returns an Iterable of EntitySubsets representing the subset of each asset that this + AssetGraphSubset contains. + """ + for asset_key in asset_graph_subset.asset_keys: + yield self.get_entity_subset_from_asset_graph_subset(asset_graph_subset, asset_key) + def get_subset_from_serializable_subset( self, serializable_subset: SerializableEntitySubset[T_EntityKey] ) -> Optional[EntitySubset[T_EntityKey]]: diff --git a/python_modules/dagster/dagster/_core/asset_graph_view/bfs.py b/python_modules/dagster/dagster/_core/asset_graph_view/bfs.py new file mode 100644 index 0000000000000..2333bc23f9aa4 --- /dev/null +++ b/python_modules/dagster/dagster/_core/asset_graph_view/bfs.py @@ -0,0 +1,187 @@ +from functools import total_ordering +from heapq import heapify, heappop, heappush +from typing import TYPE_CHECKING, Callable, Iterable, NamedTuple, Sequence, Tuple + +import dagster._check as check +from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView +from dagster._core.asset_graph_view.serializable_entity_subset import SerializableEntitySubset +from dagster._core.definitions.asset_graph_subset import AssetGraphSubset + +if TYPE_CHECKING: + from dagster._core.asset_graph_view.entity_subset import EntitySubset + + +class AssetGraphViewBfsFilterConditionResult(NamedTuple): + passed_asset_graph_subset: AssetGraphSubset + excluded_asset_graph_subsets_and_reasons: Sequence[Tuple[AssetGraphSubset, str]] + + +def bfs_filter_asset_graph_view( + asset_graph_view: AssetGraphView, + condition_fn: Callable[ + ["AssetGraphSubset", "AssetGraphSubset"], + AssetGraphViewBfsFilterConditionResult, + ], + initial_asset_graph_subset: "AssetGraphSubset", + include_full_execution_set: bool, +) -> Tuple[AssetGraphSubset, Sequence[Tuple[AssetGraphSubset, str]]]: + """Returns the subset of the graph that satisfy supplied criteria. + + - Are >= initial_asset_graph_subset + - Match the condition_fn + - Any of their ancestors >= initial_asset_graph_subset match the condition_fn + + Also returns a list of tuples, where each tuple is an asset subset that did not + satisfy the condition and the reason they were filtered out. + + The condition_fn takes in: + - a subset of the asset graph to evaluate the condition for. If include_full_execution_set=True, + the asset keys are all part of the same execution set (i.e. non-subsettable multi-asset). If + include_full_execution_set=False, only a single asset key will be in the subset. + + - An AssetGraphSubset for the portion of the graph that has so far been visited and passed + the condition. + + The condition_fn should return a object with an AssetGraphSubset indicating the portion + of the subset that passes the condition, and a list of (AssetGraphSubset, str) + tuples with more information about why certain subsets were excluded. + + Visits parents before children. + """ + initial_subsets = list(asset_graph_view.iterate_asset_subsets(initial_asset_graph_subset)) + + # invariant: we never consider an asset partition before considering its ancestors + queue = ToposortedPriorityQueue( + asset_graph_view, initial_subsets, include_full_execution_set=include_full_execution_set + ) + + visited_graph_subset = initial_asset_graph_subset + + result: AssetGraphSubset = AssetGraphSubset.empty() + failed_reasons: Sequence[Tuple[AssetGraphSubset, str]] = [] + + asset_graph = asset_graph_view.asset_graph + + while len(queue) > 0: + candidate_subset = queue.dequeue() + condition_result = condition_fn(candidate_subset, result) + + subset_that_meets_condition = condition_result.passed_asset_graph_subset + failed_reasons.extend(condition_result.excluded_asset_graph_subsets_and_reasons) + + result = result | subset_that_meets_condition + + for matching_entity_subset in asset_graph_view.iterate_asset_subsets( + subset_that_meets_condition + ): + # Add any child subsets that have not yet been visited to the queue + for child_key in asset_graph.get(matching_entity_subset.key).child_keys: + child_subset = asset_graph_view.compute_child_subset( + child_key, matching_entity_subset + ) + unvisited_child_subset = child_subset.compute_difference( + asset_graph_view.get_entity_subset_from_asset_graph_subset( + visited_graph_subset, child_key + ) + ) + if not unvisited_child_subset.is_empty: + queue.enqueue(unvisited_child_subset) + visited_graph_subset = ( + visited_graph_subset + | AssetGraphSubset.from_entity_subsets([unvisited_child_subset]) + ) + + return result, failed_reasons + + +class ToposortedPriorityQueue: + """Queue that returns parents before their children.""" + + @total_ordering + class QueueItem(NamedTuple): + level: int + sort_key: str + asset_graph_subset: AssetGraphSubset + + def __eq__(self, other: object) -> bool: + if isinstance(other, ToposortedPriorityQueue.QueueItem): + return self.level == other.level and self.sort_key == other.sort_key + return False + + def __lt__(self, other: object) -> bool: + if isinstance(other, ToposortedPriorityQueue.QueueItem): + return self.level < other.level or ( + self.level == other.level and self.sort_key < other.sort_key + ) + raise TypeError() + + def __init__( + self, + asset_graph_view: AssetGraphView, + items: Iterable["EntitySubset"], + include_full_execution_set: bool, + ): + self._asset_graph_view = asset_graph_view + self._include_full_execution_set = include_full_execution_set + + self._toposort_level_by_asset_key = { + asset_key: i + for i, asset_keys in enumerate( + asset_graph_view.asset_graph.toposorted_asset_keys_by_level + ) + for asset_key in asset_keys + } + self._heap = [self._queue_item(entity_subset) for entity_subset in items] + heapify(self._heap) + + def enqueue(self, entity_subset: "EntitySubset") -> None: + heappush(self._heap, self._queue_item(entity_subset)) + + def dequeue(self) -> AssetGraphSubset: + # For multi-assets, will include all required multi-asset keys if + # include_full_execution_set is set to True, or just the passed in + # asset key if it was not. If there are multiple assets in the subset + # the subset will have the same partitions included for each asset. + heap_value = heappop(self._heap) + return heap_value.asset_graph_subset + + def _queue_item(self, entity_subset: "EntitySubset") -> "ToposortedPriorityQueue.QueueItem": + asset_key = entity_subset.key + + if self._include_full_execution_set: + execution_set_keys = self._asset_graph_view.asset_graph.get( + asset_key + ).execution_set_asset_keys + else: + execution_set_keys = {asset_key} + + level = max( + self._toposort_level_by_asset_key[asset_key] for asset_key in execution_set_keys + ) + + serializable_entity_subset = entity_subset.convert_to_serializable_subset() + + serializable_entity_subsets = [ + SerializableEntitySubset(key=asset_key, value=serializable_entity_subset.value) + for asset_key in execution_set_keys + ] + + entity_subsets = [ + check.not_none( + self._asset_graph_view.get_subset_from_serializable_subset( + serializable_entity_subset + ) + ) + for serializable_entity_subset in serializable_entity_subsets + ] + + asset_graph_subset = AssetGraphSubset.from_entity_subsets(entity_subsets) + + return ToposortedPriorityQueue.QueueItem( + level, + asset_key.to_string(), + asset_graph_subset=asset_graph_subset, + ) + + def __len__(self) -> int: + return len(self._heap) diff --git a/python_modules/dagster/dagster/_core/definitions/asset_graph_subset.py b/python_modules/dagster/dagster/_core/definitions/asset_graph_subset.py index dd184e4c42b0b..c746ddd45428f 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_graph_subset.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_graph_subset.py @@ -268,6 +268,10 @@ def from_asset_partition_set( non_partitioned_asset_keys=non_partitioned_asset_keys, ) + @classmethod + def empty(cls) -> "AssetGraphSubset": + return AssetGraphSubset({}, set()) + @classmethod def from_entity_subsets( cls, entity_subsets: Iterable[EntitySubset[AssetKey]] diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_bfs.py b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_bfs.py new file mode 100644 index 0000000000000..91bf7b5239ea7 --- /dev/null +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_bfs.py @@ -0,0 +1,431 @@ +from typing import cast + +from dagster import ( + AssetDep, + AssetKey, + AssetSpec, + DailyPartitionsDefinition, + Definitions, + HourlyPartitionsDefinition, + TimeWindow, + TimeWindowPartitionMapping, + TimeWindowPartitionsDefinition, + asset, + multi_asset, +) +from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView +from dagster._core.asset_graph_view.bfs import ( + AssetGraphViewBfsFilterConditionResult, + bfs_filter_asset_graph_view, +) +from dagster._core.definitions.asset_graph_subset import AssetGraphSubset +from dagster._core.definitions.events import AssetKeyPartitionKey +from dagster._core.definitions.partition_mapping import IdentityPartitionMapping +from dagster._time import create_datetime + + +def _get_subset_with_keys(graph_view, keys): + return AssetGraphSubset.from_entity_subsets( + [graph_view.get_full_subset(key=key) for key in keys] + ) + + +def test_bfs_filter_empty_graph(): + """Test BFS filter on empty graph returns empty result.""" + graph_view = AssetGraphView.for_test(Definitions()) + initial_subset = AssetGraphSubset.empty() + + def condition_fn(subset, _visited): + return AssetGraphViewBfsFilterConditionResult(subset, []) + + result, failed = bfs_filter_asset_graph_view( + graph_view, condition_fn, initial_subset, include_full_execution_set=False + ) + + assert result == AssetGraphSubset.empty() + assert failed == [] + + +def test_bfs_filter_dependency_chain(): + """Test BFS filter on linear dependency chain.""" + + @asset + def asset1(): + return 1 + + @asset + def asset2(asset1): + return asset1 + 1 + + @asset + def asset3(asset2): + return asset2 + 1 + + @asset + def asset4(asset3): + return asset3 + 1 + + graph_view = AssetGraphView.for_test(Definitions(assets=[asset1, asset2, asset3, asset4])) + initial_subset = AssetGraphSubset.from_entity_subsets( + [graph_view.get_full_subset(key=asset1.key)] + ) + + def condition_fn(subset, visited): + # Only allow asset1 and asset2 + if AssetKey("asset3") in subset.asset_keys: + assert visited == AssetGraphSubset.from_entity_subsets( + [ + graph_view.get_full_subset(key=asset1.key), + graph_view.get_full_subset(key=asset2.key), + ] + ) + + return AssetGraphViewBfsFilterConditionResult( + AssetGraphSubset.empty(), + [(subset, "asset3 not allowed")], + ) + return AssetGraphViewBfsFilterConditionResult(subset, []) + + result, failed = bfs_filter_asset_graph_view( + graph_view, condition_fn, initial_subset, include_full_execution_set=False + ) + + assert result == AssetGraphSubset.from_entity_subsets( + [graph_view.get_full_subset(key=asset1.key), graph_view.get_full_subset(key=asset2.key)] + ) + assert len(failed) == 1 + assert failed[0][0] == AssetGraphSubset.from_entity_subsets( + [graph_view.get_full_subset(key=asset3.key)] + ) + assert failed[0][1] == "asset3 not allowed" + + +def test_bfs_filter_multi_asset(): + """Test BFS filter with multi-asset.""" + + @asset + def a(): + return 1 + + @asset + def b(): + return 2 + + @multi_asset( + specs=[AssetSpec("c", deps=["a"]), AssetSpec("d", deps=["b"])] + ) # d is level 1, e is level 3, gets assigned 3 + def my_multi_asset(): + pass + + @asset + def e(d): + pass + + graph_view = AssetGraphView.for_test(Definitions(assets=[a, b, my_multi_asset, e])) + initial_subset = _get_subset_with_keys(graph_view, [a.key]) + + def condition_fn(subset, _visited): + return AssetGraphViewBfsFilterConditionResult(subset, []) + + result_without_full_execution_set, _failed = bfs_filter_asset_graph_view( + graph_view, condition_fn, initial_subset, include_full_execution_set=False + ) + + assert result_without_full_execution_set == _get_subset_with_keys( + graph_view, [a.key, AssetKey(["c"])] + ) + + result_with_full_execution_set, _failed = bfs_filter_asset_graph_view( + graph_view, condition_fn, initial_subset, include_full_execution_set=True + ) + + assert result_with_full_execution_set == _get_subset_with_keys( + graph_view, [a.key, AssetKey(["c"]), AssetKey(["d"]), e.key] + ) + + +def test_bfs_filter_diamond(): + """Test BFS filter with a diamond-shaped graph to ensure bottom node is visited once.""" + + @asset + def top(): + return 1 + + @asset + def left(top): + return top + 1 + + @asset + def right(top): + return top + 2 + + @asset + def bottom(left, right): + return left + right + + graph_view = AssetGraphView.for_test(Definitions(assets=[top, left, right, bottom])) + initial_subset = _get_subset_with_keys(graph_view, [AssetKey("top")]) + + visit_count = {} + + def condition_fn(subset, _visited): + for key in subset.asset_keys: + visit_count[key] = visit_count.get(key, 0) + 1 + return AssetGraphViewBfsFilterConditionResult(subset, []) + + result, failed = bfs_filter_asset_graph_view( + graph_view, condition_fn, initial_subset, include_full_execution_set=True + ) + + # Each node should be visited exactly once + assert visit_count == { + AssetKey("bottom"): 1, + AssetKey("left"): 1, + AssetKey("right"): 1, + AssetKey("top"): 1, + } + assert result == _get_subset_with_keys( + graph_view, [AssetKey("top"), AssetKey("left"), AssetKey("right"), AssetKey("bottom")] + ) + assert failed == [] + + +from dagster import AssetIn, StaticPartitionsDefinition + + +def test_bfs_filter_with_partitions(): + """Test BFS filter with partitioned assets where condition filters some partitions.""" + + @asset(partitions_def=StaticPartitionsDefinition(["a", "b", "c"])) + def upstream(): + return 1 + + @asset( + partitions_def=StaticPartitionsDefinition(["a", "b", "c"]), + ins={"upstream": AssetIn(partition_mapping=IdentityPartitionMapping())}, + ) + def downstream(upstream): + return upstream + 1 + + graph_view = AssetGraphView.for_test(Definitions(assets=[upstream, downstream])) + initial_subset = AssetGraphSubset.from_asset_partition_set( + { + AssetKeyPartitionKey(AssetKey(["upstream"]), "a"), + AssetKeyPartitionKey(AssetKey(["upstream"]), "b"), + }, + graph_view.asset_graph, + ) + + def condition_fn(subset, _visited): + # Only allow partition "a" for upstream and corresponding mapped partition "x" for downstream + included = set() + excluded = set() + + for entity_subset in graph_view.iterate_asset_subsets(subset): + for asset_partition in entity_subset.expensively_compute_asset_partitions(): + if asset_partition.partition_key == "b": + excluded.add(asset_partition) + else: + included.add(asset_partition) + + return AssetGraphViewBfsFilterConditionResult( + AssetGraphSubset.from_asset_partition_set(included, graph_view.asset_graph), + ( + [ + ( + AssetGraphSubset.from_asset_partition_set(excluded, graph_view.asset_graph), + "b is not welcome here", + ) + ] + if excluded + else [] + ), + ) + + result, failed = bfs_filter_asset_graph_view( + graph_view, condition_fn, initial_subset, include_full_execution_set=True + ) + + # Should only include partition "a" for upstream and "x" for downstream + assert result == AssetGraphSubset.from_asset_partition_set( + { + AssetKeyPartitionKey(AssetKey(["upstream"]), "a"), + AssetKeyPartitionKey(AssetKey(["downstream"]), "a"), + }, + graph_view.asset_graph, + ) + + assert failed == [ + ( + AssetGraphSubset.from_asset_partition_set( + { + AssetKeyPartitionKey(AssetKey(["upstream"]), "b"), + }, + graph_view.asset_graph, + ), + "b is not welcome here", + ), + ] + + +def test_bfs_filter_time_window_partitions(): + # Create assets with daily partitions + daily_partitions = DailyPartitionsDefinition(start_date="2023-01-01") + hourly_partitions = HourlyPartitionsDefinition(start_date="2023-01-01-00:00") + + @asset(partitions_def=daily_partitions) + def upstream(): + pass + + @asset(partitions_def=hourly_partitions) + def downstream(upstream) -> None: + pass + + graph_view = AssetGraphView.for_test(Definitions(assets=[upstream, downstream])) + + # Initial subset with multiple days + initial_subset = AssetGraphSubset.from_asset_partition_set( + { + AssetKeyPartitionKey(AssetKey(["upstream"]), "2023-01-01"), + AssetKeyPartitionKey(AssetKey(["upstream"]), "2023-01-02"), + AssetKeyPartitionKey(AssetKey(["upstream"]), "2023-01-03"), + }, + graph_view.asset_graph, + ) + + def condition_fn(subset, visited): + # Filter out weekends + included = set() + excluded = set() + + for entity_subset in graph_view.iterate_asset_subsets(subset): + for asset_partition in entity_subset.expensively_compute_asset_partitions(): + partition_date = ( + cast( + TimeWindowPartitionsDefinition, + graph_view.asset_graph.get(entity_subset.key).partitions_def, + ) + .time_window_for_partition_key(asset_partition.partition_key) + .start + ) + if partition_date.weekday() >= 5: # Saturday = 5, Sunday = 6 + excluded.add(asset_partition) + else: + included.add(asset_partition) + + return AssetGraphViewBfsFilterConditionResult( + AssetGraphSubset.from_asset_partition_set(included, graph_view.asset_graph), + ( + [ + ( + AssetGraphSubset.from_asset_partition_set(excluded, graph_view.asset_graph), + "Weekend partitions are excluded", + ) + ] + if excluded + else [] + ), + ) + + result, failed = bfs_filter_asset_graph_view( + graph_view, condition_fn, initial_subset, include_full_execution_set=True + ) + + # Jan 1, 2023 was a Sunday + assert result == AssetGraphSubset( + partitions_subsets_by_asset_key={ + AssetKey(["upstream"]): daily_partitions.get_partition_subset_in_time_window( + TimeWindow(start=create_datetime(2023, 1, 2), end=create_datetime(2023, 1, 4)) + ), + AssetKey(["downstream"]): hourly_partitions.get_partition_subset_in_time_window( + TimeWindow(start=create_datetime(2023, 1, 2), end=create_datetime(2023, 1, 4)) + ), + } + ) + + assert failed == [ + ( + AssetGraphSubset.from_asset_partition_set( + { + AssetKeyPartitionKey(AssetKey(["upstream"]), "2023-01-01"), + }, + graph_view.asset_graph, + ), + "Weekend partitions are excluded", + ), + ] + + +def test_bfs_filter_self_dependent_asset(): + daily_partitions_def = DailyPartitionsDefinition(start_date="2023-01-01") + + @asset( + partitions_def=daily_partitions_def, + deps=[ + AssetDep( + "self_dependent", + partition_mapping=TimeWindowPartitionMapping(start_offset=-1, end_offset=-1), + ), + ], + ) + def self_dependent(context) -> None: + pass + + graph_view = AssetGraphView.for_test(Definitions([self_dependent])) + + initial_subset = AssetGraphSubset.from_asset_partition_set( + { + AssetKeyPartitionKey(AssetKey(["self_dependent"]), "2023-01-01"), + AssetKeyPartitionKey(AssetKey(["self_dependent"]), "2023-01-02"), + }, + graph_view.asset_graph, + ) + + def condition_fn(subset, _visited): + included = set() + excluded = set() + + for asset_partition in subset.iterate_asset_partitions(): + partition_key = asset_partition.partition_key + if partition_key == "2023-01-05": + excluded.add(asset_partition) + else: + included.add(asset_partition) + + return AssetGraphViewBfsFilterConditionResult( + AssetGraphSubset.from_asset_partition_set(included, graph_view.asset_graph), + ( + [ + ( + AssetGraphSubset.from_asset_partition_set(excluded, graph_view.asset_graph), + "2023-01-05 excluded", + ) + ] + if excluded + else [] + ), + ) + + result, failed = bfs_filter_asset_graph_view( + graph_view, condition_fn, initial_subset, include_full_execution_set=True + ) + + assert result == AssetGraphSubset( + partitions_subsets_by_asset_key={ + AssetKey("self_dependent"): daily_partitions_def.get_partition_subset_in_time_window( + TimeWindow(start=create_datetime(2023, 1, 1), end=create_datetime(2023, 1, 5)) + ), + } + ) + + assert failed == [ + ( + AssetGraphSubset.from_asset_partition_set( + { + AssetKeyPartitionKey(AssetKey("self_dependent"), "2023-01-05"), + }, + graph_view.asset_graph, + ), + "2023-01-05 excluded", + ), + ]