Skip to content

Commit

Permalink
New BFS method for use in asset backfills (#26751)
Browse files Browse the repository at this point in the history
Summary:
Pulling out the new BFS utility added in
#25997 to its own function
with its own tests to keep the PR size manageable.

Test Plan: New tests

NOCHANGELOG

## Summary & Motivation

## How I Tested These Changes

## Changelog

> Insert changelog entry or delete this section.
  • Loading branch information
gibsondan authored Dec 31, 2024
1 parent 32b7670 commit d54f10c
Show file tree
Hide file tree
Showing 4 changed files with 652 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Awaitable,
Callable,
Dict,
Iterable,
Literal,
NamedTuple,
Optional,
Expand All @@ -17,6 +18,7 @@
from dagster import _check as check
from dagster._core.asset_graph_view.entity_subset import EntitySubset, _ValidatedEntitySubsetValue
from dagster._core.asset_graph_view.serializable_entity_subset import SerializableEntitySubset
from dagster._core.definitions.asset_graph_subset import AssetGraphSubset
from dagster._core.definitions.asset_key import AssetCheckKey, AssetKey, EntityKey, T_EntityKey
from dagster._core.definitions.events import AssetKeyPartitionKey
from dagster._core.definitions.multi_dimensional_partitions import (
Expand Down Expand Up @@ -193,6 +195,34 @@ def get_empty_subset(self, *, key: T_EntityKey) -> EntitySubset[T_EntityKey]:
value = partitions_def.empty_subset() if partitions_def else False
return EntitySubset(self, key=key, value=_ValidatedEntitySubsetValue(value))

def get_entity_subset_from_asset_graph_subset(
    self, asset_graph_subset: AssetGraphSubset, key: AssetKey
) -> EntitySubset[AssetKey]:
    """Build the EntitySubset for ``key`` out of an AssetGraphSubset.

    Args:
        asset_graph_subset: The graph subset to read the per-asset subset from.
        key: The asset whose subset should be extracted.

    Returns:
        An EntitySubset bound to this view, wrapping the value that
        ``asset_graph_subset`` holds for ``key``.

    Raises:
        A check invariant failure if ``key`` is not in this view's asset graph, or if
        the extracted subset is not compatible with the partitions definition this
        view has for ``key``.
    """
    check.invariant(
        self.asset_graph.has(key), f"Asset graph does not contain {key.to_user_string()}"
    )

    subset_for_key = asset_graph_subset.get_asset_subset(key, self.asset_graph)
    partitions_def = self._get_partitions_def(key)

    # Only wrap the raw value as "validated" once compatibility has actually been checked.
    check.invariant(
        subset_for_key.is_compatible_with_partitions_def(partitions_def),
        f"Partitions definition for {key.to_user_string()} is not compatible with the passed in AssetGraphSubset",
    )

    validated_value = _ValidatedEntitySubsetValue(subset_for_key.value)
    return EntitySubset(self, key=key, value=validated_value)

def iterate_asset_subsets(
    self, asset_graph_subset: AssetGraphSubset
) -> Iterable[EntitySubset[AssetKey]]:
    """Lazily yield one EntitySubset per asset key present in ``asset_graph_subset``.

    Each yielded subset is produced by ``get_entity_subset_from_asset_graph_subset``, so
    the checks performed there apply to every key as it is reached.
    """
    yield from (
        self.get_entity_subset_from_asset_graph_subset(asset_graph_subset, key)
        for key in asset_graph_subset.asset_keys
    )

def get_subset_from_serializable_subset(
self, serializable_subset: SerializableEntitySubset[T_EntityKey]
) -> Optional[EntitySubset[T_EntityKey]]:
Expand Down
187 changes: 187 additions & 0 deletions python_modules/dagster/dagster/_core/asset_graph_view/bfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from functools import total_ordering
from heapq import heapify, heappop, heappush
from typing import TYPE_CHECKING, Callable, Iterable, NamedTuple, Sequence, Tuple

import dagster._check as check
from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView
from dagster._core.asset_graph_view.serializable_entity_subset import SerializableEntitySubset
from dagster._core.definitions.asset_graph_subset import AssetGraphSubset

if TYPE_CHECKING:
from dagster._core.asset_graph_view.entity_subset import EntitySubset


class AssetGraphViewBfsFilterConditionResult(NamedTuple):
    """What a BFS ``condition_fn`` reports back for one candidate subset."""

    # Portion of the candidate that satisfied the condition.
    passed_asset_graph_subset: AssetGraphSubset
    # (subset, reason) pairs describing the portions that were filtered out.
    excluded_asset_graph_subsets_and_reasons: Sequence[Tuple[AssetGraphSubset, str]]


def bfs_filter_asset_graph_view(
    asset_graph_view: AssetGraphView,
    condition_fn: Callable[
        ["AssetGraphSubset", "AssetGraphSubset"],
        AssetGraphViewBfsFilterConditionResult,
    ],
    initial_asset_graph_subset: "AssetGraphSubset",
    include_full_execution_set: bool,
) -> Tuple[AssetGraphSubset, Sequence[Tuple[AssetGraphSubset, str]]]:
    """Filter the graph downstream of ``initial_asset_graph_subset`` with ``condition_fn``.

    Starting from the initial subset, candidates are taken from a
    ToposortedPriorityQueue (parents before children) and handed to ``condition_fn``.
    Whatever portion of a candidate passes is added to the result, and the
    not-yet-visited children of that passing portion are queued for evaluation.
    Children are only ever queued from a passing parent subset, so a portion is
    evaluated only if it was in the initial subset or is reachable through one.

    Args:
        asset_graph_view: View used to resolve entity subsets and child subsets.
        condition_fn: Called with (candidate subset, subset that has passed so far).
            If ``include_full_execution_set`` is True the candidate holds every asset
            key of one execution set (e.g. a non-subsettable multi-asset); otherwise
            it holds a single asset key. It returns an
            AssetGraphViewBfsFilterConditionResult carrying the passing portion and
            (subset, reason) pairs for the excluded portions.
        initial_asset_graph_subset: Where the traversal starts.
        include_full_execution_set: Whether candidates are widened to the full
            execution set of the asset they were queued for.

    Returns:
        A pair of (subset that passed the condition, list of (subset, reason) tuples
        collected from every ``condition_fn`` call).
    """
    asset_graph = asset_graph_view.asset_graph

    # Invariant: an asset partition is never considered before its ancestors, because
    # the queue hands out lower toposort levels first.
    queue = ToposortedPriorityQueue(
        asset_graph_view,
        list(asset_graph_view.iterate_asset_subsets(initial_asset_graph_subset)),
        include_full_execution_set=include_full_execution_set,
    )

    # Everything that has ever been queued; used to avoid queueing the same portion twice.
    seen_subset = initial_asset_graph_subset
    accepted_subset: AssetGraphSubset = AssetGraphSubset.empty()
    exclusions: Sequence[Tuple[AssetGraphSubset, str]] = []

    while len(queue) > 0:
        candidate = queue.dequeue()
        outcome = condition_fn(candidate, accepted_subset)

        passed = outcome.passed_asset_graph_subset
        exclusions.extend(outcome.excluded_asset_graph_subsets_and_reasons)
        accepted_subset = accepted_subset | passed

        # Only the passing portion propagates to children.
        for parent_subset in asset_graph_view.iterate_asset_subsets(passed):
            for child_key in asset_graph.get(parent_subset.key).child_keys:
                child_subset = asset_graph_view.compute_child_subset(child_key, parent_subset)
                already_seen = asset_graph_view.get_entity_subset_from_asset_graph_subset(
                    seen_subset, child_key
                )
                new_child_subset = child_subset.compute_difference(already_seen)
                if new_child_subset.is_empty:
                    continue

                queue.enqueue(new_child_subset)
                seen_subset = seen_subset | AssetGraphSubset.from_entity_subsets(
                    [new_child_subset]
                )

    return accepted_subset, exclusions


class ToposortedPriorityQueue:
    """Heap-backed queue of asset graph subsets that returns parents before children.

    Items are ordered by the toposort level of their asset(s), then by a string form
    of the asset key they were queued for.
    """

    # NOTE(review): QueueItem inherits __le__/__gt__/__ge__ from tuple, so total_ordering
    # likely has nothing left to fill in; heapq only relies on __lt__ - confirm before
    # depending on the other comparison operators.
    @total_ordering
    class QueueItem(NamedTuple):
        """Heap entry; ``__eq__`` and ``__lt__`` look only at ``level`` and ``sort_key``."""

        level: int
        sort_key: str
        asset_graph_subset: AssetGraphSubset

        def __eq__(self, other: object) -> bool:
            if not isinstance(other, ToposortedPriorityQueue.QueueItem):
                return False
            return (self.level, self.sort_key) == (other.level, other.sort_key)

        def __lt__(self, other: object) -> bool:
            if not isinstance(other, ToposortedPriorityQueue.QueueItem):
                raise TypeError()
            # Lower level wins; ties are broken by sort_key.
            return (self.level, self.sort_key) < (other.level, other.sort_key)

    def __init__(
        self,
        asset_graph_view: AssetGraphView,
        items: Iterable["EntitySubset"],
        include_full_execution_set: bool,
    ):
        """Seed the queue with ``items``.

        Args:
            asset_graph_view: View whose asset graph supplies toposort levels and
                execution sets.
            items: Entity subsets to place on the queue initially.
            include_full_execution_set: If True, each queued item is widened to every
                asset key in the execution set of the asset it was queued for.
        """
        self._asset_graph_view = asset_graph_view
        self._include_full_execution_set = include_full_execution_set

        # Flatten the per-level key groups into a key -> level lookup.
        level_by_key = {}
        levels = asset_graph_view.asset_graph.toposorted_asset_keys_by_level
        for depth, keys_at_depth in enumerate(levels):
            for key in keys_at_depth:
                level_by_key[key] = depth
        self._toposort_level_by_asset_key = level_by_key

        self._heap = [self._queue_item(entity_subset) for entity_subset in items]
        heapify(self._heap)

    def __len__(self) -> int:
        return len(self._heap)

    def enqueue(self, entity_subset: "EntitySubset") -> None:
        """Push ``entity_subset`` onto the queue."""
        heappush(self._heap, self._queue_item(entity_subset))

    def dequeue(self) -> AssetGraphSubset:
        """Pop the lowest-ordered item and return its AssetGraphSubset.

        For multi-assets the returned subset includes all keys of the execution set
        when include_full_execution_set is True, or just the queued asset key when it
        is not. When several assets are present, each carries the same subset value.
        """
        return heappop(self._heap).asset_graph_subset

    def _queue_item(self, entity_subset: "EntitySubset") -> "ToposortedPriorityQueue.QueueItem":
        """Wrap ``entity_subset`` in a QueueItem, widening it to its execution set if configured."""
        queued_key = entity_subset.key

        member_keys = (
            self._asset_graph_view.asset_graph.get(queued_key).execution_set_asset_keys
            if self._include_full_execution_set
            else {queued_key}
        )

        # Use the deepest level in the set so the item is not handed out before the
        # level of any of its member assets.
        deepest_level = max(self._toposort_level_by_asset_key[member] for member in member_keys)

        # Replicate the queued subset's value onto every member key.
        shared_value = entity_subset.convert_to_serializable_subset().value
        per_key_serializable = [
            SerializableEntitySubset(key=member, value=shared_value) for member in member_keys
        ]
        per_key_subsets = [
            check.not_none(self._asset_graph_view.get_subset_from_serializable_subset(serialized))
            for serialized in per_key_serializable
        ]

        return ToposortedPriorityQueue.QueueItem(
            level=deepest_level,
            sort_key=queued_key.to_string(),
            asset_graph_subset=AssetGraphSubset.from_entity_subsets(per_key_subsets),
        )
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ def from_asset_partition_set(
non_partitioned_asset_keys=non_partitioned_asset_keys,
)

@classmethod
def empty(cls) -> "AssetGraphSubset":
    """Return an AssetGraphSubset built from an empty dict and an empty set."""
    no_partition_subsets = {}
    no_asset_keys = set()
    return AssetGraphSubset(no_partition_subsets, no_asset_keys)

@classmethod
def from_entity_subsets(
cls, entity_subsets: Iterable[EntitySubset[AssetKey]]
Expand Down
Loading

0 comments on commit d54f10c

Please sign in to comment.