Run the task with the configured dag bundle

dstandish committed Dec 6, 2024
1 parent 1b4922d commit 5fd9bfe
Showing 5 changed files with 100 additions and 48 deletions.
16 changes: 15 additions & 1 deletion airflow/executors/workloads.py
@@ -39,6 +39,13 @@ class BaseActivity(BaseModel):
"""The identity token for this workload"""


class DagBundle(BaseModel):
name: str
classpath: str
kwargs: dict
version: str


class TaskInstance(BaseModel):
"""Schema for TaskInstance with minimal required fields needed for Executors and Task SDK."""

@@ -49,6 +56,7 @@ class TaskInstance(BaseModel):
run_id: str
try_number: int
map_index: int | None = None
bundle: DagBundle | None = None

# TODO: Task-SDK: Can we replace TaskInstanceKey with just the uuid across the codebase?
@property
@@ -84,7 +92,13 @@ def make(cls, ti: TIModel, dag_path: Path | None = None) -> ExecuteTask:
from airflow.utils.helpers import log_filename_template_renderer

ser_ti = TaskInstance.model_validate(ti, from_attributes=True)

bundle = ti.dag_model.dag_bundle
ser_ti.bundle = DagBundle.model_construct(
name=bundle.name,
classpath=bundle.classpath,
kwargs=bundle.kwargs,
version=bundle.version,
)
dag_path = dag_path or Path(ti.dag_run.dag_model.relative_fileloc)

if dag_path and not dag_path.is_absolute():
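With this change, the executor workload carries everything the worker needs to locate the DAG source: the bundle's name, the class that knows how to materialize it, the keyword arguments to construct that class, and the version to check out. A rough sketch of what the serialized bundle could look like on the wire (values are hypothetical; only the GitDagBundle classpath is taken from elsewhere in this commit):

    from airflow.executors.workloads import DagBundle

    bundle = DagBundle(
        name="example-bundle",
        classpath="airflow.dag_processing.bundles.git.GitDagBundle",
        kwargs={"repo_url": "git@example.com:acme/dags.git", "tracking_ref": "main"},
        version="abc1234",
    )
    # Pydantic serialization, as for the rest of the workload payload:
    print(bundle.model_dump_json())
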
111 changes: 66 additions & 45 deletions airflow/models/dag.py
@@ -76,6 +76,7 @@
UnknownExecutorException,
)
from airflow.executors.executor_loader import ExecutorLoader
from airflow.executors.workloads import DagBundle
from airflow.models.asset import (
AssetDagRunQueue,
AssetModel,
@@ -241,51 +242,6 @@ def _triggerer_is_healthy():
return job and job.is_alive()


@provide_session
def _create_orm_dagrun(
dag,
dag_id,
run_id,
logical_date,
start_date,
external_trigger,
conf,
state,
run_type,
dag_version,
creating_job_id,
data_interval,
backfill_id,
session,
triggered_by,
):
run = DagRun(
dag_id=dag_id,
run_id=run_id,
logical_date=logical_date,
start_date=start_date,
external_trigger=external_trigger,
conf=conf,
state=state,
run_type=run_type,
dag_version=dag_version,
creating_job_id=creating_job_id,
data_interval=data_interval,
triggered_by=triggered_by,
backfill_id=backfill_id,
)
# Load defaults into the following two fields to ensure result can be serialized detached
run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id))))
run.consumed_asset_events = []
session.add(run)
session.flush()
run.dag = dag
# create the associated task instances
# state is None at the moment of creation
run.verify_integrity(session=session)
return run


if TYPE_CHECKING:
dag = task_sdk_dag_decorator
else:
@@ -1957,6 +1913,55 @@ def get_num_task_instances(dag_id, run_id=None, task_ids=None, states=None, sess
return session.scalar(qry)


@provide_session
def _create_orm_dagrun(
dag: DAG,
dag_id,
run_id,
logical_date,
start_date,
external_trigger,
conf,
state,
run_type,
dag_version,
creating_job_id,
data_interval,
backfill_id,
session,
triggered_by,
):
bundle_version = session.scalar(
select(DagModel.latest_bundle_version).where(DagModel.dag_id == dag.dag_id)
)
run = DagRun(
dag_id=dag_id,
run_id=run_id,
logical_date=logical_date,
start_date=start_date,
external_trigger=external_trigger,
conf=conf,
state=state,
run_type=run_type,
dag_version=dag_version,
creating_job_id=creating_job_id,
data_interval=data_interval,
triggered_by=triggered_by,
backfill_id=backfill_id,
bundle_version=bundle_version,
)
# Load defaults into the following two fields to ensure result can be serialized detached
run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id))))
run.consumed_asset_events = []
session.add(run)
session.flush()
run.dag = dag
# create the associated task instances
# state is None at the moment of creation
run.verify_integrity(session=session)
return run


class DagTag(Base):
"""A tag name per dag, to allow quick filtering in the DAG view."""

@@ -2002,6 +2007,18 @@ def get_all(cls, session) -> dict[str, dict[str, str]]:
return dag_links


bundle = DagBundle.model_construct(
name="my-bundle",
classpath="airflow.dag_processing.bundles.git.GitDagBundle",
kwargs={
"repo_url": "[email protected]:dstandish/my-dags.git",
"tracking_ref": "main",
"subdir": "dags",
},
version="dd4399e",
)


class DagModel(Base):
"""Table containing DAG properties."""

@@ -2093,6 +2110,10 @@ class DagModel(Base):
"DagVersion", back_populates="dag_model", cascade="all, delete, delete-orphan"
)

# todo: uncomment this when parsing side is "there"
# dag_bundle = relationship("DagBundle")
dag_bundle = bundle

def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.max_active_tasks is None:
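The relocated _create_orm_dagrun now pins each new run to the DAG's latest_bundle_version (presumably recorded when the DAG was last parsed from its bundle) by copying it onto the DagRun at creation time. A minimal standalone sketch of that lookup, with session handling assumed and the column being the one referenced in the diff above:

    from sqlalchemy import select
    from airflow.models.dag import DagModel

    def resolve_bundle_version(session, dag_id: str) -> str | None:
        # Mirrors the query in _create_orm_dagrun: the latest bundle version
        # recorded for this DAG, or None if it has never been set.
        return session.scalar(
            select(DagModel.latest_bundle_version).where(DagModel.dag_id == dag_id)
        )
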
4 changes: 3 additions & 1 deletion airflow/models/dagrun.py
@@ -164,6 +164,7 @@ class DagRun(Base, LoggingMixin):
"""
dag_version_id = Column(UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"))
dag_version = relationship("DagVersion", back_populates="dag_runs")
bundle_version = Column(StringID())

# Remove this `if` after upgrading Sphinx-AutoAPI
if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
@@ -236,13 +237,14 @@ def __init__(
triggered_by: DagRunTriggeredByType | None = None,
backfill_id: int | None = None,
dag_version: DagVersion | None = None,
bundle_version: str | None = None,
):
if data_interval is None:
# Legacy: Only happens for runs created prior to Airflow 2.2.
self.data_interval_start = self.data_interval_end = None
else:
self.data_interval_start, self.data_interval_end = data_interval

self.bundle_version = bundle_version
self.dag_id = dag_id
self.run_id = run_id
self.logical_date = logical_date
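On the DagRun side, bundle_version is a plain nullable string column plus a constructor argument, so runs created before this change simply read back None. A small sketch of the read path this enables, answering "which bundle version produced this run?" (names are those added in this diff):

    from sqlalchemy import select
    from airflow.models.dagrun import DagRun

    def bundle_version_for_run(session, dag_id: str, run_id: str) -> str | None:
        return session.scalar(
            select(DagRun.bundle_version).where(
                DagRun.dag_id == dag_id,
                DagRun.run_id == run_id,
            )
        )
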
8 changes: 8 additions & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -146,6 +146,13 @@ class XComResponse(BaseModel):
value: Annotated[Any, Field(title="Value")]


class DagBundle(BaseModel):
name: str
classpath: str
kwargs: dict
version: str


class TaskInstance(BaseModel):
"""
Schema for TaskInstance model with minimal required fields needed for Runtime.
@@ -157,6 +164,7 @@ class TaskInstance(BaseModel):
run_id: Annotated[str, Field(title="Run Id")]
try_number: Annotated[int, Field(title="Try Number")]
map_index: Annotated[int | None, Field(title="Map Index")] = None
bundle: DagBundle | None = None


class HTTPValidationError(BaseModel):
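The task SDK gets its own generated copy of the model, field for field, so the worker side can validate the payload without importing the executor module. A rough round-trip sketch, assuming both sides exchange plain JSON (field values hypothetical):

    from airflow.executors import workloads
    from airflow.sdk.api.datamodels import _generated

    server_side = workloads.DagBundle(
        name="example-bundle",
        classpath="airflow.dag_processing.bundles.git.GitDagBundle",
        kwargs={"repo_url": "git@example.com:acme/dags.git", "tracking_ref": "main"},
        version="abc1234",
    )
    # The SDK-side copy must stay field-compatible with the executor-side one.
    worker_side = _generated.DagBundle.model_validate_json(server_side.model_dump_json())
    assert worker_side.version == "abc1234"
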
9 changes: 8 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -32,6 +32,7 @@
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, TaskState, ToSupervisor, ToTask
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
@@ -49,8 +50,14 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:

from airflow.models.dagbag import DagBag

bundle_info = what.ti.bundle
if TYPE_CHECKING:
assert bundle_info
bundle_cls = import_string(bundle_info.classpath)
bundle_instance = bundle_cls(name=bundle_info.name, version=bundle_info.version, **bundle_info.kwargs)

bag = DagBag(
dag_folder=what.file,
dag_folder=bundle_instance.path,
include_examples=False,
safe_mode=False,
load_op_links=False,
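The worker now reconstructs the bundle before building the DagBag: it imports the class named by classpath, instantiates it with the name, version, and serialized kwargs from the workload, and points DagBag at bundle_instance.path instead of the raw file path it used to receive. The bundle base class is not part of this diff, so the following is only a hypothetical minimal bundle showing the interface parse() relies on here: a keyword constructor plus a path property.

    from pathlib import Path

    class LocalFolderDagBundle:
        """Hypothetical bundle that serves DAGs straight from a local folder."""

        def __init__(self, *, name: str, version: str | None = None, folder: str):
            self.name = name
            self.version = version
            self._folder = Path(folder)

        @property
        def path(self) -> Path:
            # parse() hands this to DagBag(dag_folder=...).
            return self._folder

Such a bundle would be addressed in the workload as classpath="my_project.bundles.LocalFolderDagBundle" with kwargs={"folder": "/opt/airflow/dags"}, both of which are made-up names for illustration.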
