Remove Pydantic models introduced for AIP-44
The Pydantic models were still used in a number of places, and we
were also using them for context passing for the PythonVirtualenvOperator
and ExternalPythonOperator - this PR removes all the models and
their usages.

Closes: #44436
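
For illustration, a minimal sketch of the pattern this commit deletes (illustrative names, not actual Airflow code): AIP-44 shipped Pydantic "twins" of the ORM models so objects could cross the internal-API boundary, which forced union type hints and runtime guards onto every consumer.

class TaskInstance: ...            # stand-in for the ORM model
class TaskInstancePydantic: ...    # stand-in for the AIP-44 twin being deleted

def before(ti: "TaskInstance | TaskInstancePydantic") -> None:
    if isinstance(ti, TaskInstancePydantic):  # guard pattern removed by this PR
        raise ValueError("not a TaskInstance")
    print("running", ti)

def after(ti: TaskInstance) -> None:          # plain ORM type, no guard needed
    print("running", ti)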

potiuk committed Dec 3, 2024
1 parent 40821bf commit bcbbb4f
Showing 22 changed files with 45 additions and 1,134 deletions.
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/dags.py
@@ -33,7 +33,7 @@

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.configuration import conf
-from airflow.serialization.pydantic.dag import DagTagPydantic
+from airflow.models import DagTag


class DAGResponse(BaseModel):
@@ -50,7 +50,7 @@ class DAGResponse(BaseModel):
description: str | None
timetable_summary: str | None
timetable_description: str | None
-tags: list[DagTagPydantic]
+tags: list[DagTag]
max_active_tasks: int
max_active_runs: int | None
max_consecutive_failed_dag_runs: int
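
The swap works because the API's response models can validate straight from ORM rows; a self-contained sketch of that mechanism (assuming, as with the project's api_fastapi BaseModel, that from_attributes is enabled -- the names below are illustrative, not Airflow's):

from pydantic import BaseModel, ConfigDict

class TagRow:  # stand-in for the SQLAlchemy DagTag row
    def __init__(self, name: str, dag_id: str):
        self.name = name
        self.dag_id = dag_id

class TagResponse(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    name: str
    dag_id: str

print(TagResponse.model_validate(TagRow("prod", "example_dag")).model_dump())
# {'name': 'prod', 'dag_id': 'example_dag'}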
@@ -33,7 +33,7 @@ class TIEnterRunningPayload(BaseModel):

state: Annotated[
Literal[TIState.RUNNING],
-# Specify a default in the schema, but not in code, so Pydantic marks it as required.
+# Specify a default in the schema, but not in code.
WithJsonSchema({"type": "string", "enum": [TIState.RUNNING], "default": TIState.RUNNING}),
]
hostname: str
20 changes: 2 additions & 18 deletions airflow/cli/commands/task_command.py
@@ -47,7 +47,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskReturnCode
-from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
@@ -74,7 +73,6 @@
from sqlalchemy.orm.session import Session

from airflow.models.operator import Operator
-from airflow.serialization.pydantic.dag_run import DagRunPydantic

log = logging.getLogger(__name__)

@@ -96,7 +94,7 @@ def _fetch_dag_run_from_run_id_or_logical_date_string(
dag_id: str,
value: str,
session: Session,
-) -> tuple[DagRun | DagRunPydantic, pendulum.DateTime | None]:
+) -> tuple[DagRun, pendulum.DateTime | None]:
"""
Try to find a DAG run with a given string value.
@@ -132,7 +130,7 @@ def _get_dag_run(
create_if_necessary: CreateIfNecessary,
logical_date_or_run_id: str | None = None,
session: Session | None = None,
-) -> tuple[DagRun | DagRunPydantic, bool]:
+) -> tuple[DagRun, bool]:
"""
Try to retrieve a DAG run from a string representing either a run ID or logical date.
@@ -259,8 +257,6 @@ def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None | TaskReturnCode:
- as raw task
- by executor
"""
-if TYPE_CHECKING:
-    assert not isinstance(ti, TaskInstancePydantic)  # Wait for AIP-44 implementation to complete
if args.local:
return _run_task_by_local_task_job(args, ti)
if args.raw:
@@ -497,9 +493,6 @@ def task_failed_deps(args) -> None:
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id)
-# tasks_failed-deps is executed with access to the database.
-if isinstance(ti, TaskInstancePydantic):
-    raise ValueError("not a TaskInstance")
dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
# TODO, Do we want to print or log this
@@ -524,9 +517,6 @@ def task_state(args) -> None:
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id)
-# task_state is executed with access to the database.
-if isinstance(ti, TaskInstancePydantic):
-    raise ValueError("not a TaskInstance")
print(ti.current_state())


@@ -654,9 +644,6 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
ti, dr_created = _get_ti(
task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="db"
)
-# task_test is executed with access to the database.
-if isinstance(ti, TaskInstancePydantic):
-    raise ValueError("not a TaskInstance")
try:
with redirect_stdout(RedactedIO()):
if args.dry_run:
@@ -705,9 +692,6 @@ def task_render(args, dag: DAG | None = None) -> None:
ti, _ = _get_ti(
task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="memory"
)
-# task_render is executed with access to the database.
-if isinstance(ti, TaskInstancePydantic):
-    raise ValueError("not a TaskInstance")
with create_session() as session, set_current_task_instance_session(session=session):
ti.render_templates()
for attr in task.template_fields:
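
With TaskInstancePydantic and DagRunPydantic gone, the CLI helpers above return plain ORM objects, so callers need no union widening or guards; a hedged sketch of consuming the narrowed tuple[DagRun, bool] return (the describe helper is illustrative, not from the diff):

from airflow.models.dagrun import DagRun

def describe(dag_run: DagRun, created: bool) -> str:
    # mirrors what callers of _get_dag_run can now assume about its result
    return f"{dag_run.run_id} (created={created})"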
8 changes: 4 additions & 4 deletions airflow/jobs/JOB_LIFECYCLE.md
@@ -100,7 +100,7 @@ sequenceDiagram
DB --> Internal API: Close Session
deactivate DB
-Internal API->>CLI component: JobPydantic object
+Internal API->>CLI component: Job object
CLI component->>JobRunner: Create Job Runner
JobRunner ->> CLI component: JobRunner object
@@ -109,7 +109,7 @@ sequenceDiagram
activate JobRunner
-JobRunner->>Internal API: prepare_for_execution [JobPydantic]
+JobRunner->>Internal API: prepare_for_execution [Job]
Internal API-->>DB: Create Session
activate DB
@@ -131,7 +131,7 @@ sequenceDiagram
deactivate DB
Internal API ->> JobRunner: returned data
and
-JobRunner->>Internal API: perform_heartbeat <br> [Job Pydantic]
+JobRunner->>Internal API: perform_heartbeat <br> [Job]
Internal API-->>DB: Create Session
activate DB
Internal API->>DB: perform_heartbeat [Job]
@@ -142,7 +142,7 @@ sequenceDiagram
deactivate DB
end
-JobRunner->>Internal API: complete_execution <br> [Job Pydantic]
+JobRunner->>Internal API: complete_execution <br> [Job]
Internal API-->>DB: Create Session
Internal API->>DB: complete_execution [Job]
activate DB
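
A hedged sketch of the loop the diagram describes, now passing the ORM Job end to end: perform_heartbeat and the prepare/complete calls are named in the diagram, but the wiring below is an assumption, not the actual runner code.

from airflow.jobs.job import Job, perform_heartbeat

def run_once(job: Job, runner) -> None:  # 'runner' is any BaseJobRunner
    job.prepare_for_execution()
    try:
        # heartbeat the plain Job row; no JobPydantic hop is involved anymore
        perform_heartbeat(
            job, heartbeat_callback=runner.heartbeat_callback, only_if_necessary=True
        )
    finally:
        job.complete_execution()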
3 changes: 1 addition & 2 deletions airflow/jobs/base_job_runner.py
@@ -26,7 +26,6 @@
from sqlalchemy.orm import Session

from airflow.jobs.job import Job
-from airflow.serialization.pydantic.job import JobPydantic


class BaseJobRunner:
@@ -64,7 +63,7 @@ def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:

@classmethod
@provide_session
-def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
+def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | None:
"""Return the most recent job of this type, if any, based on last heartbeat received."""
from airflow.jobs.job import most_recent_job

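
A possible use of the narrowed return type, with SchedulerJobRunner as an example subclass (the liveness check is an assumption about typical usage, not code from this PR):

from airflow.jobs.scheduler_job_runner import SchedulerJobRunner

job = SchedulerJobRunner.most_recent_job()
if job is not None and job.is_alive():  # no JobPydantic branch to consider
    print(f"scheduler heartbeat at {job.latest_heartbeat}")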
3 changes: 1 addition & 2 deletions airflow/models/param.py
@@ -31,7 +31,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
-from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.utils.context import Context

logger = logging.getLogger(__name__)
@@ -334,7 +333,7 @@ def deserialize(cls, data: dict, dags: dict) -> DagParam:
def process_params(
dag: DAG,
task: Operator,
-dag_run: DagRun | DagRunPydantic | None,
+dag_run: DagRun | None,
*,
suppress_exception: bool,
) -> dict[str, Any]:
5 changes: 2 additions & 3 deletions airflow/models/skipmixin.py
@@ -37,7 +37,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.node import DAGNode
-from airflow.serialization.pydantic.dag_run import DagRunPydantic

# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"
@@ -61,7 +60,7 @@ class SkipMixin(LoggingMixin):

@staticmethod
def _set_state_to_skipped(
-dag_run: DagRun | DagRunPydantic,
+dag_run: DagRun,
tasks: Sequence[str] | Sequence[tuple[str, int]],
session: Session,
) -> None:
@@ -95,7 +94,7 @@ def _set_state_to_skipped(
@provide_session
def skip(
self,
-dag_run: DagRun | DagRunPydantic,
+dag_run: DagRun,
tasks: Iterable[DAGNode],
map_index: int = -1,
session: Session = NEW_SESSION,
8 changes: 3 additions & 5 deletions airflow/models/taskinstance.py
@@ -162,8 +162,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
-from airflow.serialization.pydantic.asset import AssetEventPydantic
-from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Literal, TypeGuard
from airflow.utils.task_group import TaskGroup
@@ -983,7 +981,7 @@ def get_prev_end_date_success() -> pendulum.DateTime | None:
return None
return timezone.coerce_datetime(dagrun.end_date)

-def get_triggering_events() -> dict[str, list[AssetEvent | AssetEventPydantic]]:
+def get_triggering_events() -> dict[str, list[AssetEvent]]:
if TYPE_CHECKING:
assert session is not None

@@ -994,7 +992,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent | AssetEventPydantic]]:
if dag_run not in session:
dag_run = session.merge(dag_run, load=False)
asset_events = dag_run.consumed_asset_events
-triggering_events: dict[str, list[AssetEvent | AssetEventPydantic]] = defaultdict(list)
+triggering_events: dict[str, list[AssetEvent]] = defaultdict(list)
for event in asset_events:
if event.asset:
triggering_events[event.asset.uri].append(event)
@@ -1891,7 +1889,7 @@ def _command_as_list(
pool: str | None = None,
cfg_path: str | None = None,
) -> list[str]:
-dag: DAG | DagModel | DagModelPydantic | None
+dag: DAG | DagModel | None
# Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None:
if TYPE_CHECKING:
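
Downstream, get_triggering_events now yields plain AssetEvent rows; a hedged sketch of a task reading them from the context (the "triggering_asset_events" key is an assumption based on the asset naming used in this change):

from airflow.decorators import task

@task
def report(**context):
    events = context["triggering_asset_events"]  # dict[str, list[AssetEvent]]
    for uri, asset_events in events.items():
        # each value is a list of ORM rows, never AssetEventPydantic wrappers
        print(uri, len(asset_events))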
16 changes: 0 additions & 16 deletions airflow/serialization/pydantic/__init__.py

This file was deleted.

74 changes: 0 additions & 74 deletions airflow/serialization/pydantic/asset.py

This file was deleted.


