Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions airflow-core/src/airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from airflow.executors.workloads.types import state_class_for_key
from airflow.models import Log
from airflow.models.callback import CallbackKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.observability.metrics import stats_utils
from airflow.utils.log.logging_mixin import LoggingMixin

Expand Down Expand Up @@ -78,7 +79,6 @@ def get_execution_api_server_url(conf_source: AirflowConfigParser | ExecutorConf
from airflow.executors.workloads import ExecutorWorkload
from airflow.executors.workloads.types import WorkloadKey, WorkloadState
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey

# Event_buffer dict value type
# Tuple of: state, info
Expand Down Expand Up @@ -217,7 +217,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
self.parallelism: int = parallelism
self.team_name: str | None = team_name
self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {}
self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {}
self.queued_callbacks: dict[CallbackKey, workloads.ExecuteCallback] = {}
self.running: set[WorkloadKey] = set()
self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {}
self._task_event_logs: deque[Log] = deque()
Expand Down Expand Up @@ -265,7 +265,7 @@ def queue_workload(self, workload: ExecutorWorkload, session: Session) -> None:
f"Set supports_callbacks = True and implement callback handling in _process_workloads(). "
f"See LocalExecutor or CeleryExecutor for reference implementation."
)
self.queued_callbacks[workload.callback.id] = workload
self.queued_callbacks[workload.key] = workload
else:
raise ValueError(
f"Un-handled workload type {type(workload).__name__!r} in {type(self).__name__}. "
Expand Down Expand Up @@ -497,7 +497,7 @@ def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueTy

In case dag_ids is specified it will only return and flush events
for the given dag_ids. Otherwise, it returns and flushes all events.
Note: Callback events (with string keys) are always included regardless of dag_ids filter.
Note: Callback events (with CallbackKey keys) are always included regardless of dag_ids filter.

:param dag_ids: the dag_ids to return events for; returns all if given ``None``.
:return: a dict of events
Expand All @@ -508,7 +508,9 @@ def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueTy
self.event_buffer = {}
else:
for key in list(self.event_buffer.keys()):
if isinstance(key, CallbackKey) or key.dag_id in dag_ids:
if isinstance(key, CallbackKey) or (
isinstance(key, TaskInstanceKey) and key.dag_id in dag_ids
):
cleared_events[key] = self.event_buffer.pop(key)

return cleared_events
Expand Down
6 changes: 4 additions & 2 deletions airflow-core/src/airflow/executors/workloads/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ def validate_id(cls, v):

@property
def key(self) -> CallbackKey:
    """Return the callback ID wrapped in a ``CallbackKey`` instance."""
    # Imported at call time rather than module level to avoid a circular import.
    from airflow.models.callback import CallbackKey

    return CallbackKey(id=self.id)


class ExecuteCallback(BaseDagBundleWorkload):
Expand Down
8 changes: 4 additions & 4 deletions airflow-core/src/airflow/executors/workloads/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@

from typing import TYPE_CHECKING, TypeAlias

from airflow.models.callback import ExecutorCallback
from airflow.models.callback import CallbackKey, ExecutorCallback
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.state import CallbackState, TaskInstanceState

if TYPE_CHECKING:
from airflow.models.callback import CallbackKey

# Type aliases for workload keys and states (used by executor layer)
WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey
WorkloadState: TypeAlias = TaskInstanceState | CallbackState
Expand All @@ -43,4 +41,6 @@
def state_class_for_key(key: WorkloadKey) -> type[TaskInstanceState] | type[CallbackState]:
    """
    Return the state class that corresponds to the type of a workload key.

    :param key: a ``TaskInstanceKey`` or a ``CallbackKey``.
    :return: ``TaskInstanceState`` for task-instance keys, ``CallbackState`` for callback keys.
    :raises TypeError: if ``key`` is neither of the two key types.
    """
    if isinstance(key, TaskInstanceKey):
        return TaskInstanceState
    # Check the callback type explicitly instead of treating "anything else" as a callback,
    # so an unexpected key type fails loudly rather than being mapped to CallbackState.
    if isinstance(key, CallbackKey):
        return CallbackState
    raise TypeError(f"Unknown workload key type: {type(key)!r}")
11 changes: 6 additions & 5 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
)
from airflow.models.asset_state import AssetStateModel
from airflow.models.backfill import Backfill, BackfillDagRun
from airflow.models.callback import Callback, CallbackType, ExecutorCallback
from airflow.models.callback import Callback, CallbackKey, CallbackType, ExecutorCallback
from airflow.models.dag import DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dagbag import DBDagBag
Expand Down Expand Up @@ -1232,7 +1232,7 @@ def process_executor_events(
ti_primary_key_to_try_number_map: dict[tuple[str, str, str, int], int] = {}
event_buffer = executor.get_event_buffer()
tis_with_right_state: list[TaskInstanceKey] = []
callback_keys_with_events: list[str] = []
callback_keys_with_events: list[CallbackKey] = []

# Report execution - handle both task and callback events
for key, (state, _) in event_buffer.items():
Expand All @@ -1257,16 +1257,17 @@ def process_executor_events(
TaskInstanceState.RESTARTING,
):
tis_with_right_state.append(key)
else:
# Callback event (key is string UUID)
elif isinstance(key, CallbackKey):
cls.logger().info("Received executor event with state %s for callback %s", state, key)
if state in (CallbackState.RUNNING, CallbackState.FAILED, CallbackState.SUCCESS):
callback_keys_with_events.append(key)
else:
cls.logger().error("Unknown workload key type in event buffer: %r", key)

# Handle callback state events
for callback_id in callback_keys_with_events:
state, info = event_buffer.pop(callback_id)
callback = session.get(Callback, callback_id)
callback = session.get(Callback, str(callback_id))
if not callback:
# This should not normally happen - we just received an event for this callback.
# Only possible if callback was deleted mid-execution (e.g., cascade delete from DagRun deletion).
Expand Down
12 changes: 11 additions & 1 deletion airflow-core/src/airflow/models/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from importlib import import_module
Expand All @@ -38,7 +39,16 @@
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
from airflow.utils.state import CallbackState

CallbackKey = str # Callback keys are str(UUID)

@dataclass(frozen=True, slots=True)
class CallbackKey:
"""Distinct key type for callbacks, preventing any bare string from passing isinstance checks."""

id: str

def __str__(self) -> str:
return self.id


if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down
25 changes: 21 additions & 4 deletions airflow-core/tests/unit/executors/test_base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from airflow.executors.local_executor import LocalExecutor
from airflow.executors.workloads.base import BundleInfo
from airflow.executors.workloads.callback import CallbackDTO
from airflow.models.callback import CallbackFetchMethod
from airflow.models.callback import CallbackFetchMethod, CallbackKey
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.sdk import BaseOperator
from airflow.sdk.execution_time.callback_supervisor import execute_callback
Expand Down Expand Up @@ -100,14 +100,31 @@ def test_get_event_buffer():
assert len(executor.event_buffer) == 0


def test_get_event_buffer_always_includes_callback_keys():
    """A dag_ids filter must not drop events stored under a CallbackKey."""
    executor = BaseExecutor()

    now = timezone.utcnow()
    ti_key = TaskInstanceKey("my_dag1", "my_task1", now, 1)
    callback_key = CallbackKey(id="00000000-0000-0000-0000-000000000042")

    executor.event_buffer[ti_key] = (State.SUCCESS, None)
    executor.event_buffer[callback_key] = (CallbackState.SUCCESS, None)

    # "other_dag" does not match the task-instance key, so only the callback event is expected.
    flushed = executor.get_event_buffer(("other_dag",))
    assert callback_key in flushed
    assert ti_key not in flushed


def test_log_task_event_branches_on_key_type():
executor = BaseExecutor()
ti_key = TaskInstanceKey("my_dag", "my_task", timezone.utcnow(), 1)

executor.log_task_event(event="task_event", extra="extra", ti_key=ti_key)
assert len(executor._task_event_logs) == 1

callback_key = str(UUID("00000000-0000-0000-0000-000000000001"))
callback_key = CallbackKey(id=str(UUID("00000000-0000-0000-0000-000000000001")))
executor.log_task_event(event="callback_event", extra="extra", ti_key=callback_key)
assert len(executor._task_event_logs) == 1

Expand All @@ -123,7 +140,7 @@ def test_log_task_event_branches_on_key_type():
)
def test_state_methods_pick_callback_state_for_callback_key(method_name, expected_state):
executor = BaseExecutor()
callback_key = str(UUID("00000000-0000-0000-0000-000000000002"))
callback_key = CallbackKey(id=str(UUID("00000000-0000-0000-0000-000000000002")))

getattr(executor, method_name)(callback_key)

Expand Down Expand Up @@ -627,7 +644,7 @@ def test_queue_workload_with_execute_callback(self, dag_maker, session):
executor.queue_workload(callback_workload, session)

assert len(executor.queued_callbacks) == 1
assert callback_data.id in executor.queued_callbacks
assert callback_workload.key in executor.queued_callbacks

@pytest.mark.db_test
def test_get_workloads_prioritizes_callbacks(self, dag_maker, session):
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/tests/unit/executors/test_local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def test_process_callback_workload_queue_management(self):
executor.start()

try:
executor.queued_callbacks[callback_data.id] = callback_workload
executor.queued_callbacks[callback_workload.key] = callback_workload
executor._process_workloads([callback_workload])
assert len(executor.queued_callbacks) == 0
# We can't easily verify worker execution without running the worker,
Expand Down
52 changes: 52 additions & 0 deletions airflow-core/tests/unit/executors/test_workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,21 @@
# under the License.
from __future__ import annotations

import dataclasses
from pathlib import PurePosixPath
from uuid import uuid4

import jwt
import pytest

from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.executors import workloads
from airflow.executors.workloads import TaskInstance, TaskInstanceDTO, base as workloads_base
from airflow.executors.workloads.base import BaseWorkloadSchema, BundleInfo
from airflow.executors.workloads.callback import CallbackDTO, CallbackFetchMethod
from airflow.executors.workloads.task import ExecuteTask
from airflow.executors.workloads.types import state_class_for_key
from airflow.models.callback import CallbackKey


def test_task_instance_alias_keeps_backwards_compat():
Expand Down Expand Up @@ -82,3 +87,50 @@ def test_generate_token_produces_workload_scope(monkeypatch):
def test_generate_token_without_generator():
"""generate_token should return empty string when no generator is provided."""
assert BaseWorkloadSchema.generate_token("ti-123", None) == ""


def test_callback_key_is_frozen_and_hashable():
    """Keys with equal ids are equal and hash alike; assigning to ``id`` is rejected."""
    raw_id = "some-uuid-value"
    key = CallbackKey(id=raw_id)
    twin = CallbackKey(id=raw_id)
    different = CallbackKey(id="other")

    assert hash(key) == hash(twin)
    assert key == twin
    assert key != different

    # The dataclass is frozen, so attribute assignment must raise.
    with pytest.raises(dataclasses.FrozenInstanceError):
        key.id = "mutated"  # type: ignore[misc]


def test_callback_key_str_returns_id():
    """Converting a CallbackKey with ``str()`` yields the id it was built from."""
    raw_id = "some-uuid-value"

    assert str(CallbackKey(id=raw_id)) == raw_id


def test_callback_key_is_not_a_string():
    """A CallbackKey instance must fail an ``isinstance(x, str)`` check."""
    callback_key = CallbackKey(id="some-uuid-value")

    assert not isinstance(callback_key, str)


def test_state_class_for_key_raises_on_unknown_type():
    """Passing a key of an unrecognized type to state_class_for_key raises TypeError."""
    not_a_key = "bare-string-is-not-a-key"

    with pytest.raises(TypeError, match="Unknown workload key type"):
        state_class_for_key(not_a_key)  # type: ignore[arg-type]


def test_callback_dto_key_returns_callback_key_instance():
    """``CallbackDTO.key`` yields a CallbackKey carrying the DTO's id, not a bare string."""
    raw_id = "some-uuid-value"
    dto = CallbackDTO(id=raw_id, fetch_method=CallbackFetchMethod.IMPORT_PATH, data={})

    result = dto.key

    assert isinstance(result, CallbackKey)
    assert result.id == raw_id
    assert str(result) == raw_id
Loading