Skip to content

Commit

Permalink
Make AutomationCondition evaluate optionally async (dagster-io#25238)
Browse files Browse the repository at this point in the history
## Summary & Motivation
We now want to make the `AutomationCondition.evaluate` function optionally async. It must remain possible to implement it as a regular (non-async) function, so that users who write their own automation conditions are not forced to make them async. To handle this, this change adds an async wrapper, `evaluate_async`, on `AutomationContext` that checks whether the condition's `evaluate` function is async and awaits it if so.

## How I Tested These Changes
Existing tests should pass
  • Loading branch information
briantu authored and Grzyblon committed Oct 26, 2024
1 parent f44ac2f commit 9e39cfa
Show file tree
Hide file tree
Showing 27 changed files with 374 additions and 310 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def _evaluate_entity_async(entity_key: EntityKey, offset: int):
)

try:
self.evaluate_entity(entity_key)
await self.evaluate_entity(entity_key)
except Exception as e:
raise Exception(
f"Error while evaluating conditions for {entity_key.to_user_string()}"
Expand Down Expand Up @@ -150,10 +150,9 @@ async def _evaluate_entity_async(entity_key: EntityKey, offset: int):
v for v in self.request_subsets_by_key.values() if not v.is_empty
]

def evaluate_entity(self, key: EntityKey) -> None:
async def evaluate_entity(self, key: EntityKey) -> None:
# evaluate the condition of this asset
context = AutomationContext.create(key=key, evaluator=self)
result = context.condition.evaluate(context)
result = await AutomationContext.create(key=key, evaluator=self).evaluate_async()

# update dictionaries to keep track of this result
self.current_results_by_key[key] = result
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import inspect
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Mapping, Optional, Type, TypeVar
Expand All @@ -9,6 +10,7 @@
from dagster._core.definitions.asset_key import AssetCheckKey, AssetKey, EntityKey, T_EntityKey
from dagster._core.definitions.declarative_automation.automation_condition import (
AutomationCondition,
AutomationResult,
)
from dagster._core.definitions.declarative_automation.legacy.legacy_context import (
LegacyRuleEvaluationContext,
Expand Down Expand Up @@ -109,6 +111,11 @@ def for_child_condition(
_root_log=self._root_log,
)

async def evaluate_async(self) -> "AutomationResult[T_EntityKey]":
    """Evaluate this context's condition, awaiting the result when needed.

    ``self.condition.evaluate`` may be written either as a regular method or
    as a coroutine function; this wrapper gives callers one awaitable entry
    point for both.

    Returns:
        The value produced by ``self.condition.evaluate(self)``, after
        awaiting it if it was awaitable.
    """
    # Inspect the returned value instead of the function object: a plain
    # callable that hands back an awaitable (e.g. a wrapper delegating to an
    # async implementation) must be awaited as well, otherwise the caller
    # would receive an un-awaited coroutine instead of a result.
    result = self.condition.evaluate(self)
    if inspect.isawaitable(result):
        return await result
    return result

@property
def log(self) -> logging.Logger:
"""The logger for the current condition evaluation."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@ def children(self) -> Sequence[AutomationCondition]:
def requires_cursor(self) -> bool:
return False

def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
child_result = self.operand.evaluate(
context.for_child_condition(
child_condition=self.operand,
child_index=0,
candidate_subset=context.candidate_subset,
)
)
async def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
child_result = await context.for_child_condition(
child_condition=self.operand,
child_index=0,
candidate_subset=context.candidate_subset,
).evaluate_async()

return AutomationResult(
context=context, true_subset=child_result.true_subset, child_results=[child_result]
)
Expand Down Expand Up @@ -82,7 +81,7 @@ def _get_validated_downstream_conditions(
if not condition.has_rule_condition
}

def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
async def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
ignored_conditions = self._get_ignored_conditions(context)
downstream_conditions = self._get_validated_downstream_conditions(
context.asset_graph.get_downstream_automation_conditions(asset_key=context.key)
Expand All @@ -95,15 +94,13 @@ def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[Ass
):
if downstream_condition in ignored_conditions:
continue
child_condition = DownstreamConditionWrapperCondition(
downstream_keys=list(sorted(asset_keys)), operand=downstream_condition
)
child_context = context.for_child_condition(
child_condition=child_condition,
child_result = await context.for_child_condition(
child_condition=DownstreamConditionWrapperCondition(
downstream_keys=list(sorted(asset_keys)), operand=downstream_condition
),
child_index=i,
candidate_subset=context.candidate_subset,
)
child_result = child_condition.evaluate(child_context)
).evaluate_async()

child_results.append(child_result)
true_subset = true_subset.compute_union(child_result.true_subset)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import List, Sequence

import dagster._check as check
Expand Down Expand Up @@ -35,14 +36,15 @@ def children(self) -> Sequence[AutomationCondition[T_EntityKey]]:
def requires_cursor(self) -> bool:
return False

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
child_results: List[AutomationResult] = []
true_subset = context.candidate_subset
for i, child in enumerate(self.children):
child_context = context.for_child_condition(
child_result = await context.for_child_condition(
child_condition=child, child_index=i, candidate_subset=true_subset
)
child_result = child.evaluate(child_context)
).evaluate_async()
child_results.append(child_result)
true_subset = true_subset.compute_intersection(child_result.true_subset)
return AutomationResult(context, true_subset, child_results=child_results)
Expand Down Expand Up @@ -83,15 +85,20 @@ def children(self) -> Sequence[AutomationCondition[T_EntityKey]]:
def requires_cursor(self) -> bool:
return False

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
child_results: List[AutomationResult] = []
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
true_subset = context.get_empty_subset()
for i, child in enumerate(self.children):
child_context = context.for_child_condition(

coroutines = [
context.for_child_condition(
child_condition=child, child_index=i, candidate_subset=context.candidate_subset
)
child_result = child.evaluate(child_context)
child_results.append(child_result)
).evaluate_async()
for i, child in enumerate(self.children)
]

child_results = await asyncio.gather(*coroutines)
for child_result in child_results:
true_subset = true_subset.compute_union(child_result.true_subset)

return AutomationResult(context, true_subset, child_results=child_results)
Expand All @@ -116,11 +123,12 @@ def name(self) -> str:
def children(self) -> Sequence[AutomationCondition[T_EntityKey]]:
return [self.operand]

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
child_context = context.for_child_condition(
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
child_result = await context.for_child_condition(
child_condition=self.operand, child_index=0, candidate_subset=context.candidate_subset
)
child_result = self.operand.evaluate(child_context)
).evaluate_async()
true_subset = context.candidate_subset.compute_difference(child_result.true_subset)

return AutomationResult(context, true_subset, child_results=[child_result])
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from abc import abstractmethod
from typing import AbstractSet

Expand Down Expand Up @@ -54,22 +55,22 @@ class AnyChecksCondition(ChecksAutomationCondition):
def base_name(self) -> str:
return "ANY_CHECKS_MATCH"

def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
check_results = []
async def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
true_subset = context.get_empty_subset()

for i, check_key in enumerate(
sorted(self._get_check_keys(context.key, context.asset_graph))
):
check_condition = EntityMatchesCondition(key=check_key, operand=self.operand)
check_result = check_condition.evaluate(
context.for_child_condition(
child_condition=check_condition,
child_index=i,
candidate_subset=context.candidate_subset,
)
coroutines = [
context.for_child_condition(
child_condition=EntityMatchesCondition(key=check_key, operand=self.operand),
child_index=i,
candidate_subset=context.candidate_subset,
).evaluate_async()
for i, check_key in enumerate(
sorted(self._get_check_keys(context.key, context.asset_graph))
)
check_results.append(check_result)
]

check_results = await asyncio.gather(*coroutines)
for check_result in check_results:
true_subset = true_subset.compute_union(check_result.true_subset)

true_subset = context.candidate_subset.compute_intersection(true_subset)
Expand All @@ -83,21 +84,18 @@ class AllChecksCondition(ChecksAutomationCondition):
def base_name(self) -> str:
return "ALL_CHECKS_MATCH"

def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
async def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
check_results = []
true_subset = context.candidate_subset

for i, check_key in enumerate(
sorted(self._get_check_keys(context.key, context.asset_graph))
):
check_condition = EntityMatchesCondition(key=check_key, operand=self.operand)
check_result = check_condition.evaluate(
context.for_child_condition(
child_condition=check_condition,
child_index=i,
candidate_subset=context.candidate_subset,
)
)
check_result = await context.for_child_condition(
child_condition=EntityMatchesCondition(key=check_key, operand=self.operand),
child_index=i,
candidate_subset=context.candidate_subset,
).evaluate_async()
check_results.append(check_result)
true_subset = true_subset.compute_intersection(check_result.true_subset)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class EntityMatchesCondition(
def name(self) -> str:
return self.key.to_user_string()

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
# if the key we're mapping to is a child of the key we're mapping from and is not
# self-dependent, use the downstream mapping function, otherwise use upstream
if (
Expand All @@ -48,7 +50,7 @@ def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[
child_condition=self.operand, child_index=0, candidate_subset=to_candidate_subset
)

to_result = self.operand.evaluate(to_context)
to_result = await to_context.evaluate_async()

true_subset = to_result.true_subset.compute_mapped_subset(
context.key, direction=directions[1]
Expand Down Expand Up @@ -126,19 +128,18 @@ class AnyDepsCondition(DepsAutomationCondition[T_EntityKey]):
def base_name(self) -> str:
return "ANY_DEPS_MATCH"

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
dep_results = []
true_subset = context.get_empty_subset()

for i, dep_key in enumerate(sorted(self._get_dep_keys(context.key, context.asset_graph))):
dep_condition = EntityMatchesCondition(key=dep_key, operand=self.operand)
dep_result = dep_condition.evaluate(
context.for_child_condition(
child_condition=dep_condition,
child_index=i,
candidate_subset=context.candidate_subset,
)
)
dep_result = await context.for_child_condition(
child_condition=EntityMatchesCondition(key=dep_key, operand=self.operand),
child_index=i,
candidate_subset=context.candidate_subset,
).evaluate_async()
dep_results.append(dep_result)
true_subset = true_subset.compute_union(dep_result.true_subset)

Expand All @@ -152,19 +153,18 @@ class AllDepsCondition(DepsAutomationCondition[T_EntityKey]):
def base_name(self) -> str:
return "ALL_DEPS_MATCH"

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
dep_results = []
true_subset = context.candidate_subset

for i, dep_key in enumerate(sorted(self._get_dep_keys(context.key, context.asset_graph))):
dep_condition = EntityMatchesCondition(key=dep_key, operand=self.operand)
dep_result = dep_condition.evaluate(
context.for_child_condition(
child_condition=dep_condition,
child_index=i,
candidate_subset=context.candidate_subset,
)
)
dep_result = await context.for_child_condition(
child_condition=EntityMatchesCondition(key=dep_key, operand=self.operand),
child_index=i,
candidate_subset=context.candidate_subset,
).evaluate_async()
dep_results.append(dep_result)
true_subset = true_subset.compute_intersection(dep_result.true_subset)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,14 @@ def _get_previous_child_true_subset(
return None
return context.asset_graph_view.get_subset_from_serializable_subset(true_subset)

def evaluate(self, context: AutomationContext) -> AutomationResult:
async def evaluate(self, context: AutomationContext) -> AutomationResult:
# evaluate child condition
child_context = context.for_child_condition(
child_result = await context.for_child_condition(
self.operand,
child_index=0,
# must evaluate child condition over the entire subset to avoid missing state transitions
candidate_subset=context.asset_graph_view.get_full_subset(key=context.key),
)
child_result = self.operand.evaluate(child_context)
).evaluate_async()

# get the set of asset partitions of the child which newly became true
newly_true_child_subset = child_result.true_subset.compute_difference(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Sequence

from dagster._core.definitions.asset_key import T_EntityKey
Expand Down Expand Up @@ -25,21 +26,23 @@ def name(self) -> str:
def children(self) -> Sequence[AutomationCondition[T_EntityKey]]:
return [self.trigger_condition, self.reset_condition]

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
# must evaluate child condition over the entire subset to avoid missing state transitions
child_candidate_subset = context.asset_graph_view.get_full_subset(key=context.key)

# compute result for trigger condition
trigger_context = context.for_child_condition(
self.trigger_condition, child_index=0, candidate_subset=child_candidate_subset
# compute result for trigger and reset conditions
trigger_result, reset_result = await asyncio.gather(
*[
context.for_child_condition(
self.trigger_condition, child_index=0, candidate_subset=child_candidate_subset
).evaluate_async(),
context.for_child_condition(
self.reset_condition, child_index=1, candidate_subset=child_candidate_subset
).evaluate_async(),
]
)
trigger_result = self.trigger_condition.evaluate(trigger_context)

# compute result for reset condition
reset_context = context.for_child_condition(
self.reset_condition, child_index=1, candidate_subset=child_candidate_subset
)
reset_result = self.reset_condition.evaluate(reset_context)

# take the previous subset that this was true for
true_subset = context.previous_true_subset or context.get_empty_subset()
Expand Down
Loading

0 comments on commit 9e39cfa

Please sign in to comment.