Skip to content

Commit

Permalink
Make compute asset subset functions async
Browse files Browse the repository at this point in the history
  • Loading branch information
briantu committed Oct 15, 2024
1 parent 2257294 commit b0136c3
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import (
TYPE_CHECKING,
AbstractSet,
Awaitable,
Callable,
Dict,
Literal,
Expand Down Expand Up @@ -396,18 +397,18 @@ def _compute_execution_failed_check_subset(
def _compute_missing_check_subset(self, key: AssetCheckKey) -> EntitySubset[AssetCheckKey]:
return self.compute_subset_with_status(key, None)

def _compute_run_in_progress_asset_subset(self, key: AssetKey) -> EntitySubset[AssetKey]:
async def _compute_run_in_progress_asset_subset(self, key: AssetKey) -> EntitySubset[AssetKey]:
from dagster._core.storage.partition_status_cache import AssetStatusCacheValue

partitions_def = self._get_partitions_def(key)
if partitions_def:
cache_value = AssetStatusCacheValue.blocking_get(self, (key, partitions_def))
if cache_value is None:
value = partitions_def.empty_subset()
else:
value = cache_value.deserialize_in_progress_partition_subsets(partitions_def)
else:
value = self._queryer.get_in_progress_asset_subset(asset_key=key).value
cache_value = await AssetStatusCacheValue.gen(self, (key, partitions_def))
return (
cache_value.get_in_progress_subset(self, key, partitions_def)
if cache_value
else self.get_empty_subset(key=key)
)
value = self._queryer.get_in_progress_asset_subset(asset_key=key).value
return EntitySubset(self, key=key, value=_ValidatedEntitySubsetValue(value))

def _compute_backfill_in_progress_asset_subset(self, key: AssetKey) -> EntitySubset[AssetKey]:
Expand All @@ -418,21 +419,21 @@ def _compute_backfill_in_progress_asset_subset(self, key: AssetKey) -> EntitySub
)
return EntitySubset(self, key=key, value=_ValidatedEntitySubsetValue(value))

def _compute_execution_failed_asset_subset(self, key: AssetKey) -> EntitySubset[AssetKey]:
async def _compute_execution_failed_asset_subset(self, key: AssetKey) -> EntitySubset[AssetKey]:
from dagster._core.storage.partition_status_cache import AssetStatusCacheValue

partitions_def = self._get_partitions_def(key)
if partitions_def:
cache_value = AssetStatusCacheValue.blocking_get(self, (key, partitions_def))
if cache_value is None:
value = partitions_def.empty_subset()
else:
value = cache_value.deserialize_failed_partition_subsets(partitions_def)
else:
value = self._queryer.get_failed_asset_subset(asset_key=key).value
cache_value = await AssetStatusCacheValue.gen(self, (key, partitions_def))
return (
cache_value.get_failed_subset(self, key, partitions_def)
if cache_value
else self.get_empty_subset(key=key)
)
value = self._queryer.get_failed_asset_subset(asset_key=key).value
return EntitySubset(self, key=key, value=_ValidatedEntitySubsetValue(value))

def _compute_missing_asset_subset(
async def _compute_missing_asset_subset(
self, key: AssetKey, from_subset: EntitySubset
) -> EntitySubset[AssetKey]:
"""Returns a subset which is the subset of the input subset that has never been materialized
Expand All @@ -448,7 +449,7 @@ def _compute_missing_asset_subset(
# cheap call which takes advantage of the partition status cache
partitions_def = self._get_partitions_def(key)
if partitions_def:
cache_value = AssetStatusCacheValue.blocking_get(self, (key, partitions_def))
cache_value = await AssetStatusCacheValue.gen(self, (key, partitions_def))
return (
cache_value.get_materialized_subset(self, key, partitions_def)
if cache_value
Expand Down Expand Up @@ -489,16 +490,18 @@ def compute_backfill_in_progress_subset(self, *, key: EntityKey) -> EntitySubset
)

@cached_method
def compute_execution_failed_subset(self, *, key: EntityKey) -> EntitySubset:
return _dispatch(
async def compute_execution_failed_subset(self, *, key: EntityKey) -> EntitySubset:
return await _dispatch(
key=key,
check_method=self._compute_execution_failed_check_subset,
asset_method=self._compute_execution_failed_asset_subset,
)

@cached_method
def compute_missing_subset(self, *, key: EntityKey, from_subset: EntitySubset) -> EntitySubset:
return _dispatch(
async def compute_missing_subset(
self, *, key: EntityKey, from_subset: EntitySubset
) -> EntitySubset:
return await _dispatch(
key=key,
check_method=self._compute_missing_check_subset,
asset_method=functools.partial(
Expand Down Expand Up @@ -621,14 +624,14 @@ def _build_multi_partition_subset(
O_Dispatch = TypeVar("O_Dispatch")


def _dispatch(
async def _dispatch(
*,
key: EntityKey,
check_method: Callable[[AssetCheckKey], O_Dispatch],
asset_method: Callable[[AssetKey], O_Dispatch],
asset_method: Callable[[AssetKey], Awaitable[O_Dispatch]],
) -> O_Dispatch:
"""Applies a method for either a check or an asset."""
if isinstance(key, AssetCheckKey):
return check_method(key)
else:
return asset_method(key)
return await asset_method(key)
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ class MissingAutomationCondition(SubsetAutomationCondition):
def name(self) -> str:
return "missing"

def compute_subset(self, context: AutomationContext) -> EntitySubset:
return context.asset_graph_view.compute_missing_subset(
async def compute_subset(self, context: AutomationContext) -> EntitySubset:
return await context.asset_graph_view.compute_missing_subset(
key=context.key, from_subset=context.candidate_subset
)

Expand Down Expand Up @@ -95,8 +95,8 @@ class ExecutionFailedAutomationCondition(SubsetAutomationCondition):
def name(self) -> str:
return "execution_failed"

def compute_subset(self, context: AutomationContext) -> EntitySubset:
return context.asset_graph_view.compute_execution_failed_subset(key=context.key)
async def compute_subset(self, context: AutomationContext) -> EntitySubset:
return await context.asset_graph_view.compute_execution_failed_subset(key=context.key)


@whitelist_for_serdes
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from abc import abstractmethod

from dagster._core.asset_graph_view.entity_subset import EntitySubset
Expand All @@ -23,10 +24,14 @@ def compute_subset(
self, context: AutomationContext[T_EntityKey]
) -> EntitySubset[T_EntityKey]: ...

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
# don't compute anything if there are no candidates
if context.candidate_subset.is_empty:
true_subset = context.get_empty_subset()
elif inspect.iscoroutinefunction(self.compute_subset):
true_subset = await self.compute_subset(context)
else:
true_subset = self.compute_subset(context)

Expand Down

0 comments on commit b0136c3

Please sign in to comment.