Commit 8774f28

ashb, jedcunningham, and dimberman authored
Swap Dag Parsing to use the TaskSDK machinery. (#44972)
As part of Airflow 3, DAG definition files will have to use the Task SDK for all their classes, and anything involving running user code will need to be de-coupled from the database in the user-code process. This change moves all of the "serialization" work up to the DagFileProcessorManager, using the new function introduced in #44898 and the "subprocess" machinery introduced in #44874.

**Important Note**: this change does not remove the ability for DAG processors to access the DB for Variables etc. That will come in a future change.

Some key parts of this change:

- It builds upon the WatchedSubprocess from the TaskSDK. Right now this puts a nasty/unwanted dependency between the Dag Parsing code and the TaskSDK. This will be addressed before release (we have talked about introducing a new "apache-airflow-base-executor" dist where this subprocess+supervisor could live, as the "execution_time" folder in the Task SDK is more a feature of the executor than of the TaskSDK itself).
- A number of classes that we need to send between processes have been converted to Pydantic for ease of serialization.
- In order to not have to serialize everything in the subprocess and deserialize everything in the parent Manager process, we have created a `LazyDeserializedDAG` class that provides lazy access to many of the properties needed to create or update the DAG-related DB objects, without needing to fully deserialize the entire DAG structure (a minimal sketch of the idea follows the change stats below).
- Classes switched to being attrs-based for less boilerplate in constructors.
- Internal timers converted to `time.monotonic` where possible, and `time.time` where not; we only need the difference in seconds between two points, not datetime objects.
- With the earlier removal of "sync mode" for SQLite in #44839, the separate TERMINATE and END messages over the control socket are no longer needed.

---------

Co-authored-by: Jed Cunningham <[email protected]>
Co-authored-by: Daniel Imberman <[email protected]>
1 parent cb40ffb commit 8774f28

25 files changed

Lines changed: 1131 additions & 2086 deletions
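To make the `LazyDeserializedDAG` idea concrete, here is a minimal sketch of the lazy-access pattern it relies on. This is illustrative only, not the class added by this commit: the class name, field names, and payload layout here are assumptions. The principle is the same, though: hold the serialized payload produced in the parsing subprocess, and expose only the scalar fields the DB sync needs without rebuilding full DAG or task objects.

```python
from __future__ import annotations

from typing import Any


class LazyDAGSketch:
    """Illustrative stand-in for LazyDeserializedDAG; names and layout are assumed."""

    def __init__(self, payload: dict[str, Any]) -> None:
        # Raw dict produced by DAG serialization inside the parsing subprocess.
        self._payload = payload

    @property
    def dag_id(self) -> str:
        return self._payload["dag_id"]

    @property
    def max_active_tasks(self) -> int | None:
        # None when the user set no explicit value; the DB sync in
        # collection.py then falls back to the config default.
        return self._payload.get("max_active_tasks")


dag = LazyDAGSketch({"dag_id": "example", "fileloc": "/dags/example.py"})
print(dag.dag_id, dag.max_active_tasks)  # -> example None
```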

airflow/callbacks/callback_requests.py

Lines changed: 25 additions & 80 deletions
```diff
@@ -16,49 +16,38 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from typing import TYPE_CHECKING
 
+from pydantic import BaseModel
+
+from airflow.api_fastapi.execution_api.datamodels import taskinstance as ti_datamodel  # noqa: TC001
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
-    from airflow.models.taskinstance import SimpleTaskInstance
+    from airflow.typing_compat import Self
 
 
-class CallbackRequest:
+class CallbackRequest(BaseModel):
     """
     Base Class with information about the callback to be executed.
 
-    :param full_filepath: File Path to use to run the callback
     :param msg: Additional Message that can be used for logging
     :param processor_subdir: Directory used by Dag Processor when parsed the dag.
     """
 
-    def __init__(
-        self,
-        full_filepath: str,
-        processor_subdir: str | None = None,
-        msg: str | None = None,
-    ):
-        self.full_filepath = full_filepath
-        self.processor_subdir = processor_subdir
-        self.msg = msg
-
-    def __eq__(self, other):
-        if isinstance(other, self.__class__):
-            return self.__dict__ == other.__dict__
-        return NotImplemented
-
-    def __repr__(self):
-        return str(self.__dict__)
-
-    def to_json(self) -> str:
-        return json.dumps(self.__dict__)
+    full_filepath: str
+    """File Path to use to run the callback"""
+    processor_subdir: str | None = None
+    """Directory used by Dag Processor when parsed the dag"""
+    msg: str | None = None
+    """Additional Message that can be used for logging to determine failure/zombie"""
 
     @classmethod
-    def from_json(cls, json_str: str):
-        json_object = json.loads(json_str)
-        return cls(**json_object)
+    def from_json(cls, data: str | bytes | bytearray) -> Self:
+        return cls.model_validate_json(data)
+
+    def to_json(self, **kwargs) -> str:
+        return self.model_dump_json(**kwargs)
 
 
 class TaskCallbackRequest(CallbackRequest):
@@ -67,25 +56,12 @@ class TaskCallbackRequest(CallbackRequest):
     """
     A Class with information about the success/failure TI callback to be executed. Currently, only failure
     callbacks (when tasks are externally killed) and Zombies are run via DagFileProcessorProcess.
-
-    :param full_filepath: File Path to use to run the callback
-    :param simple_task_instance: Simplified Task Instance representation
-    :param msg: Additional Message that can be used for logging to determine failure/zombie
-    :param processor_subdir: Directory used by Dag Processor when parsed the dag.
-    :param task_callback_type: e.g. whether on success, on failure, on retry.
     """
 
-    def __init__(
-        self,
-        full_filepath: str,
-        simple_task_instance: SimpleTaskInstance,
-        processor_subdir: str | None = None,
-        msg: str | None = None,
-        task_callback_type: TaskInstanceState | None = None,
-    ):
-        super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
-        self.simple_task_instance = simple_task_instance
-        self.task_callback_type = task_callback_type
+    ti: ti_datamodel.TaskInstance
+    """Simplified Task Instance representation"""
+    task_callback_type: TaskInstanceState | None = None
+    """Whether on success, on failure, on retry"""
 
     @property
     def is_failure_callback(self) -> bool:
@@ -98,42 +74,11 @@ def is_failure_callback(self) -> bool:
             TaskInstanceState.UPSTREAM_FAILED,
         }
 
-    def to_json(self) -> str:
-        from airflow.serialization.serialized_objects import BaseSerialization
-
-        val = BaseSerialization.serialize(self.__dict__, strict=True)
-        return json.dumps(val)
-
-    @classmethod
-    def from_json(cls, json_str: str):
-        from airflow.serialization.serialized_objects import BaseSerialization
-
-        val = json.loads(json_str)
-        return cls(**BaseSerialization.deserialize(val))
-
 
 class DagCallbackRequest(CallbackRequest):
-    """
-    A Class with information about the success/failure DAG callback to be executed.
-
-    :param full_filepath: File Path to use to run the callback
-    :param dag_id: DAG ID
-    :param run_id: Run ID for the DagRun
-    :param processor_subdir: Directory used by Dag Processor when parsed the dag.
-    :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
-    :param msg: Additional Message that can be used for logging
-    """
+    """A Class with information about the success/failure DAG callback to be executed."""
 
-    def __init__(
-        self,
-        full_filepath: str,
-        dag_id: str,
-        run_id: str,
-        processor_subdir: str | None,
-        is_failure_callback: bool | None = True,
-        msg: str | None = None,
-    ):
-        super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
-        self.dag_id = dag_id
-        self.run_id = run_id
-        self.is_failure_callback = is_failure_callback
+    dag_id: str
+    run_id: str
+    is_failure_callback: bool | None = True
+    """Flag to determine whether it is a Failure Callback or Success Callback"""
```

airflow/cli/commands/local_commands/dag_processor_command.py

Lines changed: 1 addition & 7 deletions
```diff
@@ -19,7 +19,6 @@
 from __future__ import annotations
 
 import logging
-from datetime import timedelta
 from typing import Any
 
 from airflow.cli.commands.local_commands.daemon_utils import run_command_with_daemon_option
@@ -36,11 +35,10 @@
 def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner:
     """Create DagFileProcessorProcess instance."""
     processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout")
-    processor_timeout = timedelta(seconds=processor_timeout_seconds)
     return DagProcessorJobRunner(
         job=Job(),
         processor=DagFileProcessorManager(
-            processor_timeout=processor_timeout,
+            processor_timeout=processor_timeout_seconds,
             dag_directory=args.subdir,
             max_runs=args.num_runs,
         ),
@@ -54,10 +52,6 @@ def dag_processor(args):
     if not conf.getboolean("scheduler", "standalone_dag_processor"):
         raise SystemExit("The option [scheduler/standalone_dag_processor] must be True.")
 
-    sql_conn: str = conf.get("database", "sql_alchemy_conn").lower()
-    if sql_conn.startswith("sqlite"):
-        raise SystemExit("Standalone DagProcessor is not supported when using sqlite.")
-
     job_runner = _create_dag_processor_job_runner(args)
 
     reload_configuration_for_dag_processing()
```
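The manager now takes the timeout as plain seconds because, per the commit message, internal timers moved to `time.monotonic`. A minimal standalone sketch of that timing style (not Airflow code; the variable names are illustrative):

```python
import time

processor_timeout = 50.0  # float seconds, e.g. from [core] dag_file_processor_timeout

start = time.monotonic()
time.sleep(0.1)  # stand-in for parsing work
if time.monotonic() - start > processor_timeout:
    print("parse timed out")  # the manager would kill the parsing subprocess here
```

Comparing two `time.monotonic()` readings gives an elapsed-seconds float that is immune to wall-clock adjustments, which is why no `timedelta` objects are needed any more.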

airflow/dag_processing/collection.py

Lines changed: 49 additions & 35 deletions
```diff
@@ -27,7 +27,6 @@
 
 from __future__ import annotations
 
-import itertools
 import logging
 import traceback
 from typing import TYPE_CHECKING, NamedTuple
@@ -64,12 +63,13 @@
     from sqlalchemy.sql import Select
 
     from airflow.models.dagwarning import DagWarning
+    from airflow.serialization.serialized_objects import MaybeSerializedDAG
     from airflow.typing_compat import Self
 
 log = logging.getLogger(__name__)
 
 
-def _create_orm_dags(dags: Iterable[DAG], *, session: Session) -> Iterator[DagModel]:
+def _create_orm_dags(dags: Iterable[MaybeSerializedDAG], *, session: Session) -> Iterator[DagModel]:
     for dag in dags:
         orm_dag = DagModel(dag_id=dag.dag_id)
         if dag.is_paused_upon_creation is not None:
@@ -124,7 +124,7 @@ class _RunInfo(NamedTuple):
     num_active_runs: dict[str, int]
 
     @classmethod
-    def calculate(cls, dags: dict[str, DAG], *, session: Session) -> Self:
+    def calculate(cls, dags: dict[str, MaybeSerializedDAG], *, session: Session) -> Self:
         """
         Query the the run counts from the db.
@@ -169,7 +169,7 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se
     )
 
 
-def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir: str | None):
+def _serialize_dag_capturing_errors(dag: MaybeSerializedDAG, session: Session, processor_subdir: str | None):
     """
     Try to serialize the dag to the DB, but make a note of any errors.
@@ -192,7 +192,7 @@ def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir
             _sync_dag_perms(dag, session=session)
         else:
             # Check and update DagCode
-            DagCode.update_source_code(dag)
+            DagCode.update_source_code(dag.dag_id, dag.fileloc)
         return []
     except OperationalError:
         raise
@@ -202,7 +202,7 @@
         return [(dag.fileloc, traceback.format_exc(limit=-dagbag_import_error_traceback_depth))]
 
 
-def _sync_dag_perms(dag: DAG, session: Session):
+def _sync_dag_perms(dag: MaybeSerializedDAG, session: Session):
     """Sync DAG specific permissions."""
     dag_id = dag.dag_id
 
@@ -270,7 +270,7 @@ def _update_import_errors(
 
 
 def update_dag_parsing_results_in_db(
-    dags: Collection[DAG],
+    dags: Collection[MaybeSerializedDAG],
     import_errors: dict[str, str],
     processor_subdir: str | None,
     warnings: set[DagWarning],
@@ -347,7 +347,7 @@
 class DagModelOperation(NamedTuple):
     """Collect DAG objects and perform database operations for them."""
 
-    dags: dict[str, DAG]
+    dags: dict[str, MaybeSerializedDAG]
 
     def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
         """Find existing DagModel objects from DAG objects."""
@@ -380,6 +380,8 @@ def update_dags(
         processor_subdir: str | None = None,
         session: Session,
     ) -> None:
+        from airflow.configuration import conf
+
         # we exclude backfill from active run counts since their concurrency is separate
         run_info = _RunInfo.calculate(
             dags=self.dags,
@@ -393,19 +395,41 @@
             dm.is_active = True
             dm.has_import_errors = False
             dm.last_parsed_time = utcnow()
-            dm.default_view = dag.default_view
+            dm.default_view = dag.default_view or conf.get("webserver", "dag_default_view").lower()
             if hasattr(dag, "_dag_display_property_value"):
                 dm._dag_display_property_value = dag._dag_display_property_value
             elif dag.dag_display_name != dag.dag_id:
                 dm._dag_display_property_value = dag.dag_display_name
             dm.description = dag.description
-            dm.max_active_tasks = dag.max_active_tasks
-            dm.max_active_runs = dag.max_active_runs
-            dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
-            dm.has_task_concurrency_limits = any(
-                t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
-                for t in dag.tasks
-            )
+
+            # These "is not None" checks are because, with a LazySerializedDag object where the user
+            # hasn't specified an explicit value, we don't get the default values from the config in
+            # the lazily serialized version.
+            if dag.max_active_tasks is not None:
+                dm.max_active_tasks = dag.max_active_tasks
+            elif dag.max_active_tasks is None and dm.max_active_tasks is None:
+                dm.max_active_tasks = conf.getint("core", "max_active_tasks_per_dag")
+
+            if dag.max_active_runs is not None:
+                dm.max_active_runs = dag.max_active_runs
+            elif dag.max_active_runs is None and dm.max_active_runs is None:
+                dm.max_active_runs = conf.getint("core", "max_active_runs_per_dag")
+
+            if dag.max_consecutive_failed_dag_runs is not None:
+                dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
+            elif dag.max_consecutive_failed_dag_runs is None and dm.max_consecutive_failed_dag_runs is None:
+                dm.max_consecutive_failed_dag_runs = conf.getint(
+                    "core", "max_consecutive_failed_dag_runs_per_dag"
+                )
+
+            if hasattr(dag, "has_task_concurrency_limits"):
+                dm.has_task_concurrency_limits = dag.has_task_concurrency_limits
+            else:
+                dm.has_task_concurrency_limits = any(
+                    t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
+                    for t in dag.tasks
+                )
             dm.timetable_summary = dag.timetable.summary
             dm.timetable_description = dag.timetable.description
             dm.asset_expression = dag.timetable.asset_condition.as_expression()
```
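The `is not None` ladder in the hunk above reduces to one rule: an explicit value on the DAG wins, an existing DB value is kept, and only then does the config default apply. A hedged sketch of that resolution as a pure function (the helper name is hypothetical, not in the commit):

```python
def resolve_limit(dag_value: int | None, db_value: int | None, default: int) -> int:
    """Hypothetical helper mirroring the fallback ladder in update_dags()."""
    if dag_value is not None:
        return dag_value  # user set an explicit value on the DAG
    if db_value is not None:
        return db_value  # keep what the DagModel row already has
    return default  # e.g. conf.getint("core", "max_active_tasks_per_dag")
```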
```diff
@@ -419,7 +443,7 @@
             if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
                 dm.next_dagrun_create_after = None
             else:
-                dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)
+                dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)  # type: ignore[arg-type]
 
             if not dag.timetable.asset_condition:
                 dm.schedule_asset_references = []
@@ -436,24 +460,20 @@
             dm.dag_owner_links = []
 
 
-def _find_all_assets(dags: Iterable[DAG]) -> Iterator[Asset]:
+def _find_all_assets(dags: Iterable[MaybeSerializedDAG]) -> Iterator[Asset]:
     for dag in dags:
         for _, asset in dag.timetable.asset_condition.iter_assets():
             yield asset
-        for task in dag.task_dict.values():
-            for obj in itertools.chain(task.inlets, task.outlets):
-                if isinstance(obj, Asset):
-                    yield obj
+        for _, alias in dag.get_task_assets(of_type=Asset):
+            yield alias
 
 
-def _find_all_asset_aliases(dags: Iterable[DAG]) -> Iterator[AssetAlias]:
+def _find_all_asset_aliases(dags: Iterable[MaybeSerializedDAG]) -> Iterator[AssetAlias]:
     for dag in dags:
         for _, alias in dag.timetable.asset_condition.iter_asset_aliases():
             yield alias
-        for task in dag.task_dict.values():
-            for obj in itertools.chain(task.inlets, task.outlets):
-                if isinstance(obj, AssetAlias):
-                    yield obj
+        for _, alias in dag.get_task_assets(of_type=AssetAlias):
+            yield alias
 
 
 def _find_active_assets(name_uri_assets, session: Session):
@@ -500,7 +520,7 @@ class AssetModelOperation(NamedTuple):
    asset_aliases: dict[str, AssetAlias]
 
     @classmethod
-    def collect(cls, dags: dict[str, DAG]) -> Self:
+    def collect(cls, dags: dict[str, MaybeSerializedDAG]) -> Self:
         coll = cls(
             schedule_asset_references={
                 dag_id: [asset for _, asset in dag.timetable.asset_condition.iter_assets()]
@@ -511,13 +531,7 @@ def collect(cls, dags: dict[str, DAG]) -> Self:
                 for dag_id, dag in dags.items()
             },
             outlet_references={
-                dag_id: [
-                    (task_id, outlet)
-                    for task_id, task in dag.task_dict.items()
-                    for outlet in task.outlets
-                    if isinstance(outlet, Asset)
-                ]
-                for dag_id, dag in dags.items()
+                dag_id: list(dag.get_task_assets(inlets=False, outlets=True)) for dag_id, dag in dags.items()
             },
             assets={(asset.name, asset.uri): asset for asset in _find_all_assets(dags.values())},
             asset_aliases={alias.name: alias for alias in _find_all_asset_aliases(dags.values())},
```