diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py index 244317b90ecc6..046cc18649648 100644 --- a/airflow/cli/cli_config.py +++ b/airflow/cli/cli_config.py @@ -559,6 +559,11 @@ def string_lower_type(val): action="store_true", ) ARG_RAW = Arg(("-r", "--raw"), argparse.SUPPRESS, "store_true") +ARG_CARRIER = Arg( + ("-c", "--carrier"), + help="Context Carrier, containing the injected context for the Otel task span", + type=str, +) ARG_IGNORE_ALL_DEPENDENCIES = Arg( ("-A", "--ignore-all-dependencies"), help="Ignores all non-critical dependencies, including ignore_ti_state and ignore_task_deps", @@ -1317,6 +1322,7 @@ class GroupCommand(NamedTuple): ARG_CFG_PATH, ARG_LOCAL, ARG_RAW, + ARG_CARRIER, ARG_IGNORE_ALL_DEPENDENCIES, ARG_IGNORE_DEPENDENCIES, ARG_DEPENDS_ON_PAST, diff --git a/airflow/cli/commands/remote_commands/task_command.py b/airflow/cli/commands/remote_commands/task_command.py index 51198af74961a..c5e5330e261a0 100644 --- a/airflow/cli/commands/remote_commands/task_command.py +++ b/airflow/cli/commands/remote_commands/task_command.py @@ -64,6 +64,7 @@ from airflow.utils.net import get_hostname from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import NEW_SESSION, create_session, provide_session +from airflow.utils.span_status import SpanStatus from airflow.utils.state import DagRunState from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -465,6 +466,16 @@ def task_run(args, dag: DAG | None = None) -> TaskReturnCode | None: log.info("Running %s on host %s", ti, hostname) + if args.carrier is not None: + log.info("Found args.carrier: %s. Setting the value on the ti instance.", args.carrier) + # The arg value is a dict string, and it needs to be converted back to a dict. + carrier_dict = json.loads(args.carrier) + with create_session() as session: + ti.context_carrier = carrier_dict + ti.span_status = SpanStatus.ACTIVE + session.merge(ti) + session.commit() + task_return_code = None try: if args.interactive: diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index a0f48c74b1356..8b712a5958b02 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -18,6 +18,7 @@ from __future__ import annotations +import json import logging import sys from collections import defaultdict, deque @@ -39,6 +40,7 @@ from airflow.traces.utils import gen_span_id_from_ti_key, gen_trace_id from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import TaskInstanceState +from airflow.utils.thread_safe_dict import ThreadSafeDict PARALLELISM: int = conf.getint("core", "PARALLELISM") @@ -117,6 +119,8 @@ class BaseExecutor(LoggingMixin): :param parallelism: how many jobs should run at one time. Set to ``0`` for infinity. """ + active_spans = ThreadSafeDict() + supports_ad_hoc_ti_run: bool = False supports_sentry: bool = False @@ -152,6 +156,10 @@ def __init__(self, parallelism: int = PARALLELISM, team_id: str | None = None): def __repr__(self): return f"{self.__class__.__name__}(parallelism={self.parallelism})" + @classmethod + def set_active_spans(cls, active_spans: ThreadSafeDict): + cls.active_spans = active_spans + def start(self): # pragma: no cover """Executors may need to get things started.""" @@ -339,6 +347,33 @@ def trigger_tasks(self, open_slots: int) -> None: for _ in range(min((open_slots, len(self.queued_tasks)))): key, (command, _, queue, ti) = sorted_queue.pop(0) + # If it's None, then the span for the current TaskInstanceKey hasn't been started. + if self.active_spans is not None and self.active_spans.get(key) is None: + from airflow.models.taskinstance import SimpleTaskInstance + + if isinstance(ti, SimpleTaskInstance): + parent_context = Trace.extract(ti.parent_context_carrier) + else: + parent_context = Trace.extract(ti.dag_run.context_carrier) + # Start a new span using the context from the parent. + # Attributes will be set once the task has finished so that all + # values will be available (end_time, duration, etc.). + span = Trace.start_child_span( + span_name=f"{ti.task_id}", + parent_context=parent_context, + component="task", + start_time=ti.queued_dttm, + start_as_current=False, + ) + self.active_spans.set(key, span) + # Inject the current context into the carrier. + carrier = Trace.inject() + # The carrier needs to be set on the ti, but it can't happen here because db calls are expensive. + # So set the carrier as an argument to the command. + # The command execution will set it on the ti, and it will be propagated to the task itself. + command.append("--carrier") + command.append(json.dumps(carrier)) + # If a task makes it here but is still understood by the executor # to be running, it generally means that the task has been killed # externally and not yet been marked as failed. diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index f8a2eb329537b..8e29672ddd2ed 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -76,8 +76,10 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries from airflow.utils.session import NEW_SESSION, create_session, provide_session +from airflow.utils.span_status import SpanStatus from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, with_row_locks from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState +from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import DagRunTriggeredByType, DagRunType if TYPE_CHECKING: @@ -164,6 +166,12 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): job_type = "SchedulerJob" + # For a dagrun span + # - key: dag_run.run_id | value: span + # For a ti span + # - key: ti.key | value: span + active_spans = ThreadSafeDict() + def __init__( self, job: Job, @@ -231,6 +239,8 @@ def register_signals(self) -> ExitStack: def _exit_gracefully(self, signum: int, frame: FrameType | None) -> None: """Clean up processor_agent to avoid leaving orphan processes.""" + self._end_active_spans() + if not _is_parent_process(): # Only the parent process should perform the cleanup. return @@ -817,8 +827,18 @@ def process_executor_events( ti.pid, ) - with Trace.start_span_from_taskinstance(ti=ti) as span: - cls._set_span_attrs__process_executor_events(span, state, ti) + active_ti_span = cls.active_spans.get(ti.key) + if active_ti_span is not None: + cls.set_ti_span_attrs(span=active_ti_span, state=state, ti=ti) + # End the span and remove it from the active_spans dict. + active_ti_span.end(end_time=datetime_to_nano(ti.end_date)) + cls.active_spans.delete(ti.key) + ti.span_status = SpanStatus.ENDED + else: + if ti.span_status == SpanStatus.ACTIVE: + # Another scheduler has started the span. + # Update the SpanStatus to let the process know that it must end it. + ti.span_status = SpanStatus.SHOULD_END # There are two scenarios why the same TI with the same try_number is queued # after executor is finished with it: @@ -877,36 +897,36 @@ def process_executor_events( return len(event_buffer) @classmethod - def _set_span_attrs__process_executor_events(cls, span, state, ti): + def set_ti_span_attrs(cls, span, state, ti): span.set_attributes( { - "category": "scheduler", - "task_id": ti.task_id, - "dag_id": ti.dag_id, - "state": ti.state, - "error": True if state == TaskInstanceState.FAILED else False, - "start_date": str(ti.start_date), - "end_date": str(ti.end_date), - "duration": ti.duration, - "executor_config": str(ti.executor_config), - "logical_date": str(ti.logical_date), - "hostname": ti.hostname, - "log_url": ti.log_url, - "operator": str(ti.operator), - "try_number": ti.try_number, - "executor_state": state, - "pool": ti.pool, - "queue": ti.queue, - "priority_weight": ti.priority_weight, - "queued_dttm": str(ti.queued_dttm), - "queued_by_job_id": ti.queued_by_job_id, - "pid": ti.pid, + "airflow.category": "scheduler", + "airflow.task.task_id": ti.task_id, + "airflow.task.dag_id": ti.dag_id, + "airflow.task.state": ti.state, + "airflow.task.error": True if state == TaskInstanceState.FAILED else False, + "airflow.task.start_date": str(ti.start_date), + "airflow.task.end_date": str(ti.end_date), + "airflow.task.duration": ti.duration, + "airflow.task.executor_config": str(ti.executor_config), + "airflow.task.logical_date": str(ti.logical_date), + "airflow.task.hostname": ti.hostname, + "airflow.task.log_url": ti.log_url, + "airflow.task.operator": str(ti.operator), + "airflow.task.try_number": ti.try_number, + "airflow.task.executor_state": state, + "airflow.task.pool": ti.pool, + "airflow.task.queue": ti.queue, + "airflow.task.priority_weight": ti.priority_weight, + "airflow.task.queued_dttm": str(ti.queued_dttm), + "airflow.task.queued_by_job_id": ti.queued_by_job_id, + "airflow.task.pid": ti.pid, } ) if span.is_recording(): - span.add_event(name="queued", timestamp=datetime_to_nano(ti.queued_dttm)) - span.add_event(name="started", timestamp=datetime_to_nano(ti.start_date)) - span.add_event(name="ended", timestamp=datetime_to_nano(ti.end_date)) + span.add_event(name="airflow.task.queued", timestamp=datetime_to_nano(ti.queued_dttm)) + span.add_event(name="airflow.task.started", timestamp=datetime_to_nano(ti.start_date)) + span.add_event(name="airflow.task.ended", timestamp=datetime_to_nano(ti.end_date)) def _execute(self) -> int | None: from airflow.dag_processing.manager import DagFileProcessorAgent @@ -948,6 +968,14 @@ def _execute(self) -> int | None: execute_start_time = timezone.utcnow() + # local import due to type_checking. + from airflow.executors.base_executor import BaseExecutor + + # Pass a reference to the dictionary. + # Any changes made by a dag_run instance, will be reflected to the dictionary of this class. + DagRun.set_active_spans(active_spans=self.active_spans) + BaseExecutor.set_active_spans(active_spans=self.active_spans) + self._run_scheduler_loop() if self.processor_agent: @@ -1009,6 +1037,157 @@ def _update_dag_run_state_for_paused_dags(self, session: Session = NEW_SESSION) except Exception as e: # should not fail the scheduler self.log.exception("Failed to update dag run state for paused dags due to %s", e) + @provide_session + def _end_active_spans(self, session: Session = NEW_SESSION): + # No need to do a commit for every update. The annotation will commit all of them once at the end. + for key, span in self.active_spans.get_all().items(): + from airflow.models.taskinstance import TaskInstanceKey + + if isinstance(key, TaskInstanceKey): # ti span. + # Can't compare the key directly because the try_number or the map_index might not be the same. + ti: TaskInstance = session.scalars( + select(TaskInstance).where( + TaskInstance.dag_id == key.dag_id, + TaskInstance.task_id == key.task_id, + TaskInstance.run_id == key.run_id, + ) + ).one() + if ti.state in State.finished: + self.set_ti_span_attrs(span=span, state=ti.state, ti=ti) + span.end(end_time=datetime_to_nano(ti.end_date)) + ti.span_status = SpanStatus.ENDED + else: + span.end() + ti.span_status = SpanStatus.NEEDS_CONTINUANCE + else: + dag_run: DagRun = session.scalars(select(DagRun).where(DagRun.run_id == key)).one() + if dag_run.state in State.finished_dr_states: + dagv = session.scalar(select(DagVersion).where(DagVersion.id == dag_run.dag_version_id)) + dag_run.set_dagrun_span_attrs(span=span, dagv=dagv) + + span.end(end_time=datetime_to_nano(dag_run.end_date)) + dag_run.span_status = SpanStatus.ENDED + else: + span.end() + dag_run.span_status = SpanStatus.NEEDS_CONTINUANCE + initial_dag_run_context = Trace.extract(dag_run.context_carrier) + with Trace.start_child_span( + span_name="current_scheduler_exited", parent_context=initial_dag_run_context + ) as s: + s.set_attribute("trace_status", "needs continuance") + + self.active_spans.clear() + self.active_spans.clear() + + def _end_spans_of_externally_ended_ops(self, session: Session): + # The scheduler that starts a dag_run or a task is also the one that starts the spans. + # Each scheduler should end the spans that it has started. + # + # Otel spans are implemented in a certain way so that the objects + # can't be shared between processes or get recreated. + # It is done so that the process that starts a span, is also the one that ends it. + # + # If another scheduler has finished processing a dag_run or a task and there is a reference + # on the active_spans dictionary, then the current scheduler started the span, + # and therefore must end it. + dag_runs_should_end: list[DagRun] = session.scalars( + select(DagRun).where(DagRun.span_status == SpanStatus.SHOULD_END) + ).all() + tis_should_end: list[TaskInstance] = session.scalars( + select(TaskInstance).where(TaskInstance.span_status == SpanStatus.SHOULD_END) + ).all() + + for dag_run in dag_runs_should_end: + active_dagrun_span = self.active_spans.get(dag_run.run_id) + if active_dagrun_span is not None: + if dag_run.state in State.finished_dr_states: + dagv = session.scalar(select(DagVersion).where(DagVersion.id == dag_run.dag_version_id)) + dag_run.set_dagrun_span_attrs(span=active_dagrun_span, dagv=dagv) + + active_dagrun_span.end(end_time=datetime_to_nano(dag_run.end_date)) + else: + active_dagrun_span.end() + self.active_spans.delete(dag_run.run_id) + dag_run.span_status = SpanStatus.ENDED + + for ti in tis_should_end: + active_ti_span = self.active_spans.get(ti.key) + if active_ti_span is not None: + if ti.state in State.finished: + self.set_ti_span_attrs(span=active_ti_span, state=ti.state, ti=ti) + active_ti_span.end(end_time=datetime_to_nano(ti.end_date)) + else: + active_ti_span.end() + self.active_spans.delete(ti.key) + ti.span_status = SpanStatus.ENDED + + def _recreate_unhealthy_scheduler_spans_if_needed(self, dag_run: DagRun, session: Session): + # There are two scenarios: + # 1. scheduler is unhealthy but managed to update span_status + # 2. scheduler is unhealthy and didn't manage to make any updates + # Check the span_status first, in case the 2nd db query can be avoided (scenario 1). + + # If the dag_run is scheduled by a different scheduler, and it's still running and the span is active, + # then check the Job table to determine if the initial scheduler is still healthy. + if ( + dag_run.scheduled_by_job_id != self.job.id + and dag_run.state in State.unfinished_dr_states + and dag_run.span_status == SpanStatus.ACTIVE + ): + initial_scheduler_id = dag_run.scheduled_by_job_id + job: Job = session.scalars( + select(Job).where( + Job.id == initial_scheduler_id, + Job.job_type == "SchedulerJob", + ) + ).one() + + if not job.is_alive(): + # Start a new span for the dag_run. + dr_span = Trace.start_root_span( + span_name=f"{dag_run.dag_id}_recreated", + component="dag", + start_time=dag_run.queued_at, + start_as_current=False, + ) + carrier = Trace.inject() + # Update the context_carrier and leave the SpanStatus as ACTIVE. + dag_run.context_carrier = carrier + self.active_spans.set(dag_run.run_id, dr_span) + + tis = dag_run.get_task_instances(session=session) + + # At this point, any tis will have been adopted by the current scheduler, + # and ti.queued_by_job_id will point to the current id. + # Any tis that have been executed by the unhealthy scheduler, will need a new span + # so that it can be associated with the new dag_run span. + tis_needing_spans = [ + ti + for ti in tis + # If it has started and there is a reference on the active_spans dict, + # then it was started by the current scheduler. + if ti.start_date is not None and self.active_spans.get(ti.key) is None + ] + + dr_context = Trace.extract(dag_run.context_carrier) + for ti in tis_needing_spans: + ti_span = Trace.start_child_span( + span_name=f"{ti.task_id}_recreated", + parent_context=dr_context, + start_time=ti.queued_dttm, + start_as_current=False, + ) + ti_carrier = Trace.inject() + ti.context_carrier = ti_carrier + + if ti.state in State.finished: + self.set_ti_span_attrs(span=ti_span, state=ti.state, ti=ti) + ti_span.end(end_time=datetime_to_nano(ti.end_date)) + ti.span_status = SpanStatus.ENDED + else: + ti.span_status = SpanStatus.ACTIVE + self.active_spans.set(ti.key, ti_span) + def _run_scheduler_loop(self) -> None: """ Harvest DAG parsing results, queue tasks, and perform executor heartbeat; the actual scheduler loop. @@ -1087,6 +1266,8 @@ def _run_scheduler_loop(self) -> None: ) with create_session() as session: + self._end_spans_of_externally_ended_ops(session) + # This will schedule for as many executors as possible. num_queued_tis = self._do_scheduling(session) @@ -1723,6 +1904,19 @@ def _schedule_dag_run( "The DAG disappeared before verifying integrity: %s. Skipping.", dag_run.dag_id ) return callback + + if ( + dag_run.scheduled_by_job_id is not None + and dag_run.scheduled_by_job_id != self.job.id + and self.active_spans.get(dag_run.run_id) is None + ): + # If the dag_run has been previously scheduled by another job and there is no active span, + # then check if the job is still healthy. + # If it's not healthy, then recreate the spans. + self._recreate_unhealthy_scheduler_spans_if_needed(dag_run, session) + + dag_run.scheduled_by_job_id = self.job.id + # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False) diff --git a/airflow/migrations/versions/0056_3_0_0_add_new_otel_span_fields.py b/airflow/migrations/versions/0056_3_0_0_add_new_otel_span_fields.py new file mode 100644 index 0000000000000..e8190c3cf6d0e --- /dev/null +++ b/airflow/migrations/versions/0056_3_0_0_add_new_otel_span_fields.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add new otel span fields. + +Revision ID: 70206b887482 +Revises: e39a26ac59f6 +Create Date: 2025-01-15 13:09:35.906137 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.utils.sqlalchemy import ExtendedJSON + +# revision identifiers, used by Alembic. +revision = "70206b887482" +down_revision = "e39a26ac59f6" +branch_labels = None +depends_on = None +airflow_version = "3.0.0" + + +def upgrade(): + """Apply add new otel span fields.""" + op.add_column("dag_run", sa.Column("scheduled_by_job_id", sa.Integer, nullable=True)) + op.add_column("dag_run", sa.Column("context_carrier", ExtendedJSON, nullable=True)) + op.add_column("dag_run", sa.Column("span_status", sa.String(250), nullable=False)) + + op.add_column("task_instance", sa.Column("context_carrier", ExtendedJSON, nullable=True)) + op.add_column("task_instance", sa.Column("span_status", sa.String(250), nullable=False)) + + +def downgrade(): + """Unapply add new otel span fields.""" + op.drop_column("dag_run", "scheduled_by_job_id") + op.drop_column("dag_run", "context_carrier") + op.drop_column("dag_run", "span_status") + + op.drop_column("task_instance", "context_carrier") + op.drop_column("task_instance", "span_status") diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 8060901f948a3..56e81def28d9c 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -21,7 +21,15 @@ import os from collections import defaultdict from collections.abc import Iterable, Iterator, Sequence -from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + NamedTuple, + TypeVar, + Union, + overload, +) import re2 from sqlalchemy import ( @@ -48,6 +56,7 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import declared_attr, joinedload, relationship, synonym, validates from sqlalchemy.sql.expression import case, false, select, true from sqlalchemy.sql.functions import coalesce @@ -69,20 +78,23 @@ from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES -from airflow.traces.tracer import Trace +from airflow.traces.tracer import EmptySpan, Trace from airflow.utils import timezone from airflow.utils.dates import datetime_to_nano from airflow.utils.helpers import chunks, is_container, prune_dict from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.retries import retry_db_transaction from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, with_row_locks +from airflow.utils.span_status import SpanStatus +from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, nulls_first, with_row_locks from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import NOTSET, DagRunTriggeredByType, DagRunType if TYPE_CHECKING: from datetime import datetime + from opentelemetry.sdk.trace import Span from sqlalchemy.orm import Query, Session from airflow.models.dag import DAG @@ -123,6 +135,8 @@ class DagRun(Base, LoggingMixin): external trigger (i.e. manual runs). """ + active_spans = ThreadSafeDict() + __tablename__ = "dag_run" id = Column(Integer, primary_key=True) @@ -168,6 +182,11 @@ class DagRun(Base, LoggingMixin): dag_version = relationship("DagVersion", back_populates="dag_runs") bundle_version = Column(StringID()) + scheduled_by_job_id = Column(Integer) + # Span context carrier, used for context propagation. + context_carrier = Column(MutableDict.as_mutable(ExtendedJSON)) + span_status = Column(String(250), default=SpanStatus.NOT_STARTED, nullable=False) + # Remove this `if` after upgrading Sphinx-AutoAPI if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ: dag: DAG | None @@ -266,6 +285,8 @@ def __init__( self.clear_number = 0 self.triggered_by = triggered_by self.dag_version = dag_version + self.scheduled_by_job_id = None + self.context_carrier = {} super().__init__() def __repr__(self): @@ -291,6 +312,10 @@ def validate_run_id(self, key: str, run_id: str) -> str | None: def stats_tags(self) -> dict[str, str]: return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type}) + @classmethod + def set_active_spans(cls, active_spans: ThreadSafeDict): + cls.active_spans = active_spans + def get_state(self): return self._state @@ -836,6 +861,145 @@ def is_effective_leaf(task): leaf_tis = {ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED} return leaf_tis + def set_dagrun_span_attrs(self, span: Span | EmptySpan, dagv: DagVersion): + if self._state == DagRunState.FAILED: + span.set_attribute("airflow.dag_run.error", True) + + attribute_value_type = Union[ + str, + bool, + int, + float, + Sequence[str], + Sequence[bool], + Sequence[int], + Sequence[float], + ] + + # Explicitly set the value type to Union[...] to avoid a mypy error. + attributes: dict[str, attribute_value_type] = { + "airflow.category": "DAG runs", + "airflow.dag_run.dag_id": str(self.dag_id), + "airflow.dag_run.logical_date": str(self.logical_date), + "airflow.dag_run.run_id": str(self.run_id), + "airflow.dag_run.queued_at": str(self.queued_at), + "airflow.dag_run.run_start_date": str(self.start_date), + "airflow.dag_run.run_end_date": str(self.end_date), + "airflow.dag_run.run_duration": str( + (self.end_date - self.start_date).total_seconds() if self.start_date and self.end_date else 0 + ), + "airflow.dag_run.state": str(self._state), + "airflow.dag_run.external_trigger": str(self.external_trigger), + "airflow.dag_run.run_type": str(self.run_type), + "airflow.dag_run.data_interval_start": str(self.data_interval_start), + "airflow.dag_run.data_interval_end": str(self.data_interval_end), + "airflow.dag_version.version": str(dagv.version if dagv else None), + "airflow.dag_run.conf": str(self.conf), + } + if span.is_recording(): + span.add_event(name="airflow.dag_run.queued", timestamp=datetime_to_nano(self.queued_at)) + span.add_event(name="airflow.dag_run.started", timestamp=datetime_to_nano(self.start_date)) + span.add_event(name="airflow.dag_run.ended", timestamp=datetime_to_nano(self.end_date)) + span.set_attributes(attributes) + + def start_dr_spans_if_needed(self, tis: list[TI]): + # If there is no value in active_spans, then the span hasn't already been started. + if self.active_spans is not None and self.active_spans.get(self.run_id) is None: + if self.span_status == SpanStatus.NOT_STARTED or self.span_status == SpanStatus.NEEDS_CONTINUANCE: + dr_span = None + continue_ti_spans = False + if self.span_status == SpanStatus.NOT_STARTED: + dr_span = Trace.start_root_span( + span_name=f"{self.dag_id}", + component="dag", + start_time=self.queued_at, # This is later converted to nano. + start_as_current=False, + ) + elif self.span_status == SpanStatus.NEEDS_CONTINUANCE: + # Use the existing context_carrier to set the initial dag_run span as the parent. + parent_context = Trace.extract(self.context_carrier) + with Trace.start_child_span( + span_name="new_scheduler", parent_context=parent_context + ) as s: + s.set_attribute("trace_status", "continued") + + dr_span = Trace.start_child_span( + span_name=f"{self.dag_id}_continued", + parent_context=parent_context, + component="dag", + # No start time + start_as_current=False, + ) + # After this span is started, the context_carrier will be replaced by the new one. + # New task span will use this span as the parent. + continue_ti_spans = True + carrier = Trace.inject() + self.context_carrier = carrier + self.span_status = SpanStatus.ACTIVE + # Set the span in a synchronized dictionary, so that the variable can be used to end the span. + self.active_spans.set(self.run_id, dr_span) + self.log.debug( + "DagRun span has been started and the injected context_carrier is: %s", + self.context_carrier, + ) + # Start TI spans that also need continuance. + if continue_ti_spans: + new_dagrun_context = Trace.extract(self.context_carrier) + for ti in tis: + if ti.span_status == SpanStatus.NEEDS_CONTINUANCE: + ti_span = Trace.start_child_span( + span_name=f"{ti.task_id}_continued", + parent_context=new_dagrun_context, + start_as_current=False, + ) + ti_carrier = Trace.inject() + ti.context_carrier = ti_carrier + ti.span_status = SpanStatus.ACTIVE + self.active_spans.set(ti.key, ti_span) + else: + self.log.info( + "Found span_status '%s', while updating state for dag_run '%s'", + self.span_status, + self.run_id, + ) + + def end_dr_span_if_needed(self, dagv: DagVersion): + if self.active_spans is not None: + active_span = self.active_spans.get(self.run_id) + if active_span is not None: + self.log.debug( + "Found active span with span_id: %s, for dag_id: %s, run_id: %s, state: %s", + active_span.get_span_context().span_id, + self.dag_id, + self.run_id, + self.state, + ) + + self.set_dagrun_span_attrs(span=active_span, dagv=dagv) + active_span.end(end_time=datetime_to_nano(self.end_date)) + # Remove the span from the dict. + self.active_spans.delete(self.run_id) + self.span_status = SpanStatus.ENDED + else: + if self.span_status == SpanStatus.ACTIVE: + # Another scheduler has started the span. + # Update the DB SpanStatus to notify the owner to end it. + self.span_status = SpanStatus.SHOULD_END + elif self.span_status == SpanStatus.NEEDS_CONTINUANCE: + # This is a corner case where the scheduler exited gracefully + # while the dag_run was almost done. + # Since it reached this point, the dag has finished but there has been no time + # to create a new span for the current scheduler. + # There is no need for more spans, update the status on the db. + self.span_status = SpanStatus.ENDED + else: + self.log.debug( + "No active span has been found for dag_id: %s, run_id: %s, state: %s", + self.dag_id, + self.run_id, + self.state, + ) + @provide_session def update_state( self, session: Session = NEW_SESSION, execute_callbacks: bool = True @@ -962,6 +1126,9 @@ def recalculate(self) -> _UnfinishedStates: # finally, if the leaves aren't done, the dag is still running else: + # It might need to start TI spans as well. + self.start_dr_spans_if_needed(tis=tis) + self.set_state(DagRunState.RUNNING) if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS: @@ -992,7 +1159,7 @@ def recalculate(self) -> _UnfinishedStates: dagv.version if dagv else None, ) - self._trace_dagrun(dagv) + self.end_dr_span_if_needed(dagv=dagv) session.flush() @@ -1004,35 +1171,6 @@ def recalculate(self) -> _UnfinishedStates: return schedulable_tis, callback - def _trace_dagrun(self, dagv) -> None: - with Trace.start_span_from_dagrun(dagrun=self) as span: - if self._state == DagRunState.FAILED: - span.set_attribute("error", True) - attributes = { - "category": "DAG runs", - "dag_id": self.dag_id, - "logical_date": str(self.logical_date), - "run_id": self.run_id, - "queued_at": str(self.queued_at), - "run_start_date": str(self.start_date), - "run_end_date": str(self.end_date), - "run_duration": (self.end_date - self.start_date).total_seconds() - if self.start_date and self.end_date - else 0, - "state": str(self._state), - "external_trigger": self.external_trigger, - "run_type": str(self.run_type), - "data_interval_start": str(self.data_interval_start), - "data_interval_end": str(self.data_interval_end), - "dag_version": str(dagv.version if dagv else None), - "conf": str(self.conf), - } - if span.is_recording(): - span.add_event(name="queued", timestamp=datetime_to_nano(self.queued_at)) - span.add_event(name="started", timestamp=datetime_to_nano(self.start_date)) - span.add_event(name="ended", timestamp=datetime_to_nano(self.end_date)) - span.set_attributes(attributes) - @provide_session def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision: tis = self.get_task_instances(session=session, state=State.task_states) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2f3fa4e8fb4a9..16d790c66b8df 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -114,7 +114,6 @@ from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS -from airflow.traces.tracer import Trace from airflow.utils import timezone from airflow.utils.context import ( ConnectionAccessor, @@ -133,6 +132,7 @@ from airflow.utils.platform import getuser from airflow.utils.retries import run_with_db_retries from airflow.utils.session import NEW_SESSION, create_session, provide_session +from airflow.utils.span_status import SpanStatus from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.task_group import MappedTaskGroup @@ -807,6 +807,8 @@ def _set_ti_attrs(target, source, include_dag_run=False): target.next_method = source.next_method target.next_kwargs = source.next_kwargs target.dag_version_id = source.dag_version_id + target.context_carrier = source.context_carrier + target.span_status = source.span_status if include_dag_run: target.logical_date = source.logical_date @@ -828,6 +830,9 @@ def _set_ti_attrs(target, source, include_dag_run=False): target.dag_run.dag_version_id = source.dag_run.dag_version_id target.dag_run.updated_at = source.dag_run.updated_at target.dag_run.log_template_id = source.dag_run.log_template_id + target.dag_run.scheduled_by_job_id = source.dag_run.scheduled_by_job_id + target.dag_run.context_carrier = source.dag_run.context_carrier + target.dag_run.span_status = source.dag_run.span_status def _refresh_from_db( @@ -857,7 +862,31 @@ def _refresh_from_db( ) if ti: - _set_ti_attrs(task_instance, ti, include_dag_run=False) + # Check if the ti is detached or the dag_run relationship isn't loaded. + # If the scheduler that started the dag_run has exited (gracefully or forcefully), + # there will be changes to the dag_run span context_carrier. + # It's best to include the dag_run whenever possible, so that the ti will contain the updates. + task_instance_inspector = inspect(task_instance) + is_task_instance_bound_to_session = task_instance_inspector.session is not None + + # If the check is false, then it will try load the dag_run relationship from the task_instance + # and it will fail with this error: + # + # sqlalchemy.orm.exc.DetachedInstanceError: Parent instance + # is not bound to a Session; lazy load operation of attribute 'dag_run' cannot proceed + if is_task_instance_bound_to_session: + ti_inspector = inspect(ti) + dr_inspector = inspect(ti.dag_run) + + is_ti_attached = not ti_inspector.detached + is_dr_attached = not dr_inspector.detached + is_dr_loaded = "dag_run" not in ti_inspector.unloaded + + include_dag_run = is_ti_attached and is_dr_attached and is_dr_loaded + else: + include_dag_run = False + + _set_ti_attrs(task_instance, ti, include_dag_run=include_dag_run) else: task_instance.state = None @@ -1129,29 +1158,6 @@ def _handle_failure( if not test_mode: TaskInstance.save_to_db(failure_context["ti"], session) - with Trace.start_span_from_taskinstance(ti=task_instance) as span: - span.set_attributes( - { - # ---- error info ---- - "error": "true", - "error_msg": str(error), - "force_fail": force_fail, - # ---- common info ---- - "category": "DAG runs", - "task_id": task_instance.task_id, - "dag_id": task_instance.dag_id, - "state": task_instance.state, - "start_date": str(task_instance.start_date), - "end_date": str(task_instance.end_date), - "duration": task_instance.duration, - "executor_config": str(task_instance.executor_config), - "logical_date": str(task_instance.logical_date), - "hostname": task_instance.hostname, - "operator": str(task_instance.operator), - } - ) - span.set_attribute("log_url", task_instance.log_url) - def _refresh_from_task( *, task_instance: TaskInstance, task: Operator, pool_override: str | None = None @@ -1708,6 +1714,8 @@ class TaskInstance(Base, LoggingMixin): executor_config = Column(ExecutorConfigType(pickler=dill)) updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow) rendered_map_index = Column(String(250)) + context_carrier = Column(MutableDict.as_mutable(ExtendedJSON)) + span_status = Column(String(250), default=SpanStatus.NOT_STARTED, nullable=False) external_executor_id = Column(StringID()) @@ -1819,6 +1827,7 @@ def __init__( self.raw = False # can be changed when calling 'run' self.test_mode = False + self.context_carrier = {} def __hash__(self): return hash((self.task_id, self.dag_id, self.run_id, self.map_index)) @@ -3761,6 +3770,7 @@ def __init__( dag_id: str, task_id: str, run_id: str, + queued_dttm: datetime | None, start_date: datetime | None, end_date: datetime | None, try_number: int, @@ -3773,11 +3783,15 @@ def __init__( key: TaskInstanceKey, run_as_user: str | None = None, priority_weight: int | None = None, + parent_context_carrier: dict | None = None, + context_carrier: dict | None = None, + span_status: str | None = None, ): self.dag_id = dag_id self.task_id = task_id self.run_id = run_id self.map_index = map_index + self.queued_dttm = queued_dttm self.start_date = start_date self.end_date = end_date self.try_number = try_number @@ -3789,6 +3803,9 @@ def __init__( self.priority_weight = priority_weight self.queue = queue self.key = key + self.parent_context_carrier = parent_context_carrier + self.context_carrier = context_carrier + self.span_status = span_status def __repr__(self) -> str: attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items()) @@ -3806,6 +3823,7 @@ def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance: task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index, + queued_dttm=ti.queued_dttm, start_date=ti.start_date, end_date=ti.end_date, try_number=ti.try_number, @@ -3817,6 +3835,12 @@ def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance: key=ti.key, run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None, priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None, + # Inspect the ti, to check if the 'dag_run' relationship is loaded. + parent_context_carrier=ti.dag_run.context_carrier + if "dag_run" not in inspect(ti).unloaded + else None, + context_carrier=ti.context_carrier if hasattr(ti, "context_carrier") else None, + span_status=ti.span_status, ) diff --git a/airflow/models/taskinstancehistory.py b/airflow/models/taskinstancehistory.py index 9ac11cad7dba5..01a930a7f7c34 100644 --- a/airflow/models/taskinstancehistory.py +++ b/airflow/models/taskinstancehistory.py @@ -38,6 +38,7 @@ from airflow.models.base import Base, StringID from airflow.utils import timezone from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.span_status import SpanStatus from airflow.utils.sqlalchemy import ( ExecutorConfigType, ExtendedJSON, @@ -83,6 +84,8 @@ class TaskInstanceHistory(Base): executor_config = Column(ExecutorConfigType(pickler=dill)) updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow) rendered_map_index = Column(String(250)) + context_carrier = Column(MutableDict.as_mutable(ExtendedJSON)) + span_status = Column(String(250), default=SpanStatus.NOT_STARTED, nullable=False) external_executor_id = Column(StringID()) trigger_id = Column(Integer) diff --git a/airflow/traces/otel_tracer.py b/airflow/traces/otel_tracer.py index b172645850353..60e64b738462c 100644 --- a/airflow/traces/otel_tracer.py +++ b/airflow/traces/otel_tracer.py @@ -19,32 +19,31 @@ import logging import random +from typing import TYPE_CHECKING +import pendulum from opentelemetry import trace -from opentelemetry.context import create_key +from opentelemetry.context import attach, create_key from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import HOST_NAME, SERVICE_NAME, Resource -from opentelemetry.sdk.trace import Span, Tracer as OpenTelemetryTracer, TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter +from opentelemetry.sdk.trace import Span, SpanProcessor, Tracer as OpenTelemetryTracer, TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor from opentelemetry.sdk.trace.id_generator import IdGenerator from opentelemetry.trace import Link, NonRecordingSpan, SpanContext, TraceFlags, Tracer +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID from airflow.configuration import conf -from airflow.traces import ( - TRACEPARENT, - TRACESTATE, -) from airflow.traces.utils import ( - gen_dag_span_id, - gen_span_id, - gen_trace_id, parse_traceparent, parse_tracestate, ) from airflow.utils.dates import datetime_to_nano from airflow.utils.net import get_hostname +if TYPE_CHECKING: + from opentelemetry.context.context import Context + log = logging.getLogger(__name__) _NEXT_ID = create_key("next_id") @@ -57,29 +56,65 @@ class OtelTrace: When OTEL is enabled, the Trace class will be replaced by this class. """ - def __init__(self, span_exporter: OTLPSpanExporter, tag_string: str | None = None): + def __init__( + self, + span_exporter: OTLPSpanExporter, + use_simple_processor: bool, + tag_string: str | None = None, + ): self.span_exporter = span_exporter - self.span_processor = BatchSpanProcessor(self.span_exporter) + self.use_simple_processor = use_simple_processor + if self.use_simple_processor: + # With a BatchSpanProcessor, spans are exported at an interval. + # A task can run fast and finish before spans have enough time to get exported to the collector. + # When creating spans from inside a task, a SimpleSpanProcessor needs to be used because + # it exports the spans immediately after they are created. + log.info("(__init__) - [SimpleSpanProcessor] is being used") + self.span_processor: SpanProcessor = SimpleSpanProcessor(self.span_exporter) + else: + log.info("(__init__) - [BatchSpanProcessor] is being used") + self.span_processor = BatchSpanProcessor(self.span_exporter) self.tag_string = tag_string self.otel_service = conf.get("traces", "otel_service") + self.resource = Resource.create( + attributes={HOST_NAME: get_hostname(), SERVICE_NAME: self.otel_service} + ) - def get_tracer( - self, component: str, trace_id: int | None = None, span_id: int | None = None - ) -> OpenTelemetryTracer | Tracer: - """Tracer that will use special AirflowOtelIdGenerator to control producing certain span and trace id.""" - resource = Resource.create(attributes={HOST_NAME: get_hostname(), SERVICE_NAME: self.otel_service}) + def get_otel_tracer_provider( + self, trace_id: int | None = None, span_id: int | None = None + ) -> TracerProvider: + """ + Tracer that will use special AirflowOtelIdGenerator to control producing certain span and trace id. + + It can be used to get a tracer and directly create spans, or for auto-instrumentation. + """ if trace_id or span_id: # in case where trace_id or span_id was given tracer_provider = TracerProvider( - resource=resource, id_generator=AirflowOtelIdGenerator(span_id=span_id, trace_id=trace_id) + resource=self.resource, + id_generator=AirflowOtelIdGenerator(span_id=span_id, trace_id=trace_id), ) else: - tracer_provider = TracerProvider(resource=resource) + tracer_provider = TracerProvider(resource=self.resource) debug = conf.getboolean("traces", "otel_debugging_on") if debug is True: log.info("[ConsoleSpanExporter] is being used") - tracer_provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter())) - tracer_provider.add_span_processor(self.span_processor) + if self.use_simple_processor: + log.info("[SimpleSpanProcessor] is being used") + span_processor_for_tracer_prov: SpanProcessor = SimpleSpanProcessor(ConsoleSpanExporter()) + else: + log.info("[BatchSpanProcessor] is being used") + span_processor_for_tracer_prov = BatchSpanProcessor(ConsoleSpanExporter()) + else: + span_processor_for_tracer_prov = self.span_processor + + tracer_provider.add_span_processor(span_processor_for_tracer_prov) + return tracer_provider + + def get_tracer( + self, component: str, trace_id: int | None = None, span_id: int | None = None + ) -> OpenTelemetryTracer | Tracer: + tracer_provider = self.get_otel_tracer_provider(trace_id=trace_id, span_id=span_id) tracer = tracer_provider.get_tracer(component) """ Tracer will produce a single ID value if value is provided. Note that this is one-time only, so any @@ -130,106 +165,130 @@ def start_span( ) return span - def start_span_from_dagrun( - self, dagrun, span_name: str | None = None, component: str = "dagrun", links=None + def start_root_span( + self, + span_name: str, + component: str | None = None, + links=None, + start_time=None, + start_as_current: bool = True, ): - """Produce a span from dag run.""" - # check if dagrun has configs - conf = dagrun.conf - trace_id = int(gen_trace_id(dag_run=dagrun, as_int=True)) - span_id = int(gen_dag_span_id(dag_run=dagrun, as_int=True)) - - tracer = self.get_tracer(component=component, span_id=span_id, trace_id=trace_id) + """Start a root span.""" + # If no context is passed to the new span, + # then it will try to get the context of the current active span. + # Due to that, the context parameter can't be empty. + # It needs an invalid context in order to declare the new span as root. + invalid_span_ctx = SpanContext( + trace_id=INVALID_TRACE_ID, span_id=INVALID_SPAN_ID, is_remote=True, trace_flags=TraceFlags(0x01) + ) + invalid_ctx = trace.set_span_in_context(NonRecordingSpan(invalid_span_ctx)) - tag_string = self.tag_string if self.tag_string else "" - tag_string = tag_string + ("," + conf.get(TRACESTATE) if (conf and conf.get(TRACESTATE)) else "") + if links is None: + _links = [] + else: + _links = links - if span_name is None: - span_name = dagrun.dag_id + return self._new_span( + span_name=span_name, + parent_context=invalid_ctx, + component=component, + links=_links, + start_time=start_time, + start_as_current=start_as_current, + ) - _links = gen_links_from_kv_list(links) if links else [] + def start_child_span( + self, + span_name: str, + parent_context: Context | None = None, + component: str | None = None, + links=None, + start_time=None, + start_as_current: bool = True, + ): + """Start a child span.""" + if parent_context is None: + # If no context is passed, then use the current. + parent_span_context = trace.get_current_span().get_span_context() + parent_context = trace.set_span_in_context(NonRecordingSpan(parent_span_context)) + else: + context_val = next(iter(parent_context.values())) + parent_span_context = None + if isinstance(context_val, NonRecordingSpan): + parent_span_context = context_val.get_span_context() - _links.append( - Link( - context=trace.get_current_span().get_span_context(), - attributes={"meta.annotation_type": "link", "from": "parenttrace"}, + if links is None: + _links = [] + else: + _links = links + + if parent_span_context is not None: + _links.append( + Link( + context=parent_span_context, + attributes={"meta.annotation_type": "link", "from": "parenttrace"}, + ) ) - ) - - if conf and conf.get(TRACEPARENT): - # add the trace parent as the link - _links.append(gen_link_from_traceparent(conf.get(TRACEPARENT))) - span_ctx = SpanContext( - trace_id=INVALID_TRACE_ID, span_id=INVALID_SPAN_ID, is_remote=True, trace_flags=TraceFlags(0x01) - ) - ctx = trace.set_span_in_context(NonRecordingSpan(span_ctx)) - span = tracer.start_as_current_span( - name=span_name, - context=ctx, + return self._new_span( + span_name=span_name, + parent_context=parent_context, + component=component, links=_links, - start_time=datetime_to_nano(dagrun.queued_at), - attributes=parse_tracestate(tag_string), + start_time=start_time, + start_as_current=start_as_current, ) - return span - def start_span_from_taskinstance( + def _new_span( self, - ti, - span_name: str | None = None, - component: str = "taskinstance", - child: bool = False, + span_name: str, + parent_context: Context | None = None, + component: str | None = None, links=None, + start_time=None, + start_as_current: bool = True, ): - """ - Create and start span from given task instance. + if component is None: + component = self.otel_service - Essentially the span represents the ti itself if child == True, it will create a 'child' span under the given span. - """ - dagrun = ti.dag_run - conf = dagrun.conf - trace_id = int(gen_trace_id(dag_run=dagrun, as_int=True)) - span_id = int(gen_span_id(ti=ti, as_int=True)) - if span_name is None: - span_name = ti.task_id + tracer = self.get_tracer(component=component) - parent_id = span_id if child else int(gen_dag_span_id(dag_run=dagrun, as_int=True)) + if start_time is None: + start_time = pendulum.now("UTC") - span_ctx = SpanContext( - trace_id=trace_id, span_id=parent_id, is_remote=True, trace_flags=TraceFlags(0x01) - ) + if links is None: + links = [] - _links = gen_links_from_kv_list(links) if links else [] - - _links.append( - Link( - context=SpanContext( - trace_id=trace.get_current_span().get_span_context().trace_id, - span_id=span_id, - is_remote=True, - trace_flags=TraceFlags(0x01), - ), - attributes={"meta.annotation_type": "link", "from": "parenttrace"}, + if start_as_current: + span = tracer.start_as_current_span( + name=span_name, + context=parent_context, + links=links, + start_time=datetime_to_nano(start_time), ) - ) - - if child is False: - tracer = self.get_tracer(component=component, span_id=span_id, trace_id=trace_id) else: - tracer = self.get_tracer(component=component) + span = tracer.start_span( + name=span_name, + context=parent_context, + links=links, + start_time=datetime_to_nano(start_time), + ) + current_span_ctx = trace.set_span_in_context(NonRecordingSpan(span.get_span_context())) + # We have to manually make the span context as the active context. + # If the span needs to be injected into the carrier, then this is needed to make sure + # that the injected context will point to the span context that was just created. + attach(current_span_ctx) + return span - tag_string = self.tag_string if self.tag_string else "" - tag_string = tag_string + ("," + conf.get(TRACESTATE) if (conf and conf.get(TRACESTATE)) else "") + def inject(self) -> dict: + """Inject the current span context into a carrier and return it.""" + carrier: dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier) + return carrier - ctx = trace.set_span_in_context(NonRecordingSpan(span_ctx)) - span = tracer.start_as_current_span( - name=span_name, - context=ctx, - links=_links, - start_time=datetime_to_nano(ti.queued_dttm), - attributes=parse_tracestate(tag_string), - ) - return span + def extract(self, carrier: dict) -> Context: + """Extract the span context from a provided carrier.""" + return TraceContextTextMapPropagator().extract(carrier) def gen_context(trace_id: int, span_id: int): @@ -265,7 +324,7 @@ def gen_link_from_traceparent(traceparent: str): return Link(context=span_ctx, attributes={"meta.annotation_type": "link", "from": "traceparent"}) -def get_otel_tracer(cls) -> OtelTrace: +def get_otel_tracer(cls, use_simple_processor: bool = False) -> OtelTrace: """Get OTEL tracer from airflow configuration.""" host = conf.get("traces", "otel_host") port = conf.getint("traces", "otel_port") @@ -275,7 +334,16 @@ def get_otel_tracer(cls) -> OtelTrace: protocol = "https" if ssl_active else "http" endpoint = f"{protocol}://{host}:{port}/v1/traces" log.info("[OTLPSpanExporter] Connecting to OpenTelemetry Collector at %s", endpoint) - return OtelTrace(span_exporter=OTLPSpanExporter(endpoint=endpoint), tag_string=tag_string) + log.info("Should use simple processor: %s", use_simple_processor) + return OtelTrace( + span_exporter=OTLPSpanExporter(endpoint=endpoint), + use_simple_processor=use_simple_processor, + tag_string=tag_string, + ) + + +def get_otel_tracer_for_task(cls) -> OtelTrace: + return get_otel_tracer(cls, use_simple_processor=True) class AirflowOtelIdGenerator(IdGenerator): diff --git a/airflow/traces/tracer.py b/airflow/traces/tracer.py index 86b20e6441a02..66a4a76e9bcab 100644 --- a/airflow/traces/tracer.py +++ b/airflow/traces/tracer.py @@ -60,6 +60,7 @@ class EmptyContext: """If no Tracer is configured, EmptyContext is used as a fallback.""" trace_id = 1 + span_id = 1 class EmptySpan: @@ -151,27 +152,31 @@ def get_current_span(self): raise NotImplementedError() @classmethod - def start_span_from_dagrun( - cls, - dagrun, - span_name=None, - service_name=None, - component=None, - links=None, - ): - """Start a span from dagrun.""" + def start_root_span(cls, span_name=None, component=None, start_time=None, start_as_current=True): + """Start a root span.""" raise NotImplementedError() @classmethod - def start_span_from_taskinstance( + def start_child_span( cls, - ti, span_name=None, + parent_context=None, component=None, - child=False, links=None, + start_time=None, + start_as_current=True, ): - """Start a span from taskinstance.""" + """Start a child span.""" + raise NotImplementedError() + + @classmethod + def inject(cls) -> dict: + """Inject the current span context into a carrier and return it.""" + raise NotImplementedError() + + @classmethod + def extract(cls, carrier) -> EmptyContext: + """Extract the span context from a provided carrier.""" raise NotImplementedError() @@ -212,29 +217,35 @@ def get_current_span(self) -> EmptySpan: return EMPTY_SPAN @classmethod - def start_span_from_dagrun( - cls, - dagrun, - span_name=None, - service_name=None, - component=None, - links=None, + def start_root_span( + cls, span_name=None, component=None, start_time=None, start_as_current=True ) -> EmptySpan: - """Start a span from dagrun.""" + """Start a root span.""" return EMPTY_SPAN @classmethod - def start_span_from_taskinstance( + def start_child_span( cls, - ti, span_name=None, + parent_context=None, component=None, - child=False, links=None, + start_time=None, + start_as_current=True, ) -> EmptySpan: - """Start a span from taskinstance.""" + """Start a child span.""" return EMPTY_SPAN + @classmethod + def inject(cls): + """Inject the current span context into a carrier and return it.""" + return {} + + @classmethod + def extract(cls, carrier) -> EmptyContext: + """Extract the span context from a provided carrier.""" + return EMPTY_CTX + class _TraceMeta(type): factory: Callable[[], Tracer] | None = None diff --git a/airflow/utils/dates.py b/airflow/utils/dates.py index 70d9ee8345025..422c662f93640 100644 --- a/airflow/utils/dates.py +++ b/airflow/utils/dates.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +import calendar + cron_presets: dict[str, str] = { "@hourly": "0 * * * *", "@daily": "0 0 * * *", @@ -30,5 +32,11 @@ def datetime_to_nano(datetime) -> int | None: """Convert datetime to nanoseconds.""" if datetime: - return int(datetime.timestamp() * 1000000000) + if datetime.tzinfo is None: + # There is no timezone info, handle it the same as UTC. + timestamp = calendar.timegm(datetime.timetuple()) + datetime.microsecond / 1e6 + else: + # The datetime is timezone-aware. Use timestamp directly. + timestamp = datetime.timestamp() + return int(timestamp * 1e9) return None diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 1a1eb6f4d3500..d7b95d0ce4e9f 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -94,7 +94,7 @@ class MappedClassProtocol(Protocol): "2.9.2": "686269002441", "2.10.0": "22ed7efa9da2", "2.10.3": "5f2621c13b39", - "3.0.0": "e39a26ac59f6", + "3.0.0": "70206b887482", } diff --git a/airflow/utils/span_status.py b/airflow/utils/span_status.py new file mode 100644 index 0000000000000..4d018e72aac87 --- /dev/null +++ b/airflow/utils/span_status.py @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from enum import Enum + + +class SpanStatus(str, Enum): + """All possible statuses for a span.""" + + NOT_STARTED = "not_started" + ACTIVE = "active" + ENDED = "ended" + SHOULD_END = "should_end" + NEEDS_CONTINUANCE = "needs_continuance" + + def __str__(self) -> str: + return self.value diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index 917af7c1f1da3..818d0e97576dd 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -28,6 +28,7 @@ from packaging import version from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_ from sqlalchemy.dialects import mysql +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.types import JSON, Text, TypeDecorator from airflow.configuration import conf @@ -113,7 +114,10 @@ class ExtendedJSON(TypeDecorator): should_evaluate_none = True def load_dialect_impl(self, dialect) -> TypeEngine: - return dialect.type_descriptor(JSON) + if dialect.name == "postgresql": + return dialect.type_descriptor(JSONB) + else: + return dialect.type_descriptor(JSON) def process_bind_param(self, value, dialect): from airflow.serialization.serialized_objects import BaseSerialization diff --git a/airflow/utils/thread_safe_dict.py b/airflow/utils/thread_safe_dict.py new file mode 100644 index 0000000000000..ecd1aed202c62 --- /dev/null +++ b/airflow/utils/thread_safe_dict.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import threading + + +class ThreadSafeDict: + """Dictionary that uses a lock during operations, to ensure thread safety.""" + + def __init__(self): + self.sync_dict = {} + self.thread_lock = threading.Lock() + + def set(self, key, value): + with self.thread_lock: + self.sync_dict[key] = value + + def get(self, key): + with self.thread_lock: + return self.sync_dict.get(key) + + def delete(self, key): + with self.thread_lock: + if key in self.sync_dict: + del self.sync_dict[key] + + def clear(self): + with self.thread_lock: + self.sync_dict.clear() + + def get_all(self): + with self.thread_lock: + # Return a copy to avoid exposing the internal dictionary. + return self.sync_dict.copy() diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index 5626cf9708b0b..a736de2b0791e 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -cb858681fdc7a596db20c1c5dbf93812fd011a6df1e0b5322a49a51c8476bb93 \ No newline at end of file +2cc5f17792e8dbbb270482d8d902bcb279a6eaf6b15c14aa443aadfdfdf53b65 \ No newline at end of file diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index 3fa6369924caa..ada7ea06cd653 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg @@ -121,8 +121,8 @@ callback_data - [JSON] - NOT NULL + [JSONB] + NOT NULL callback_type @@ -649,24 +649,24 @@ dagrun_asset_event - -dagrun_asset_event - -dag_run_id - - [INTEGER] - NOT NULL - -event_id - - [INTEGER] - NOT NULL + +dagrun_asset_event + +dag_run_id + + [INTEGER] + NOT NULL + +event_id + + [INTEGER] + NOT NULL asset_event--dagrun_asset_event - -0..N + +0..N 1 @@ -709,114 +709,123 @@ task_instance - -task_instance + +task_instance + +id + + [UUID] + NOT NULL + +context_carrier + + [JSONB] -id - - [UUID] - NOT NULL +custom_operator_name + + [VARCHAR(1000)] -custom_operator_name - - [VARCHAR(1000)] +dag_id + + [VARCHAR(250)] + NOT NULL -dag_id - - [VARCHAR(250)] - NOT NULL +dag_version_id + + [UUID] -dag_version_id - - [UUID] +duration + + [DOUBLE_PRECISION] -duration - - [DOUBLE_PRECISION] +end_date + + [TIMESTAMP] -end_date - - [TIMESTAMP] +executor + + [VARCHAR(1000)] -executor - - [VARCHAR(1000)] +executor_config + + [BYTEA] -executor_config - - [BYTEA] +external_executor_id + + [VARCHAR(250)] -external_executor_id - - [VARCHAR(250)] +hostname + + [VARCHAR(1000)] -hostname - - [VARCHAR(1000)] +last_heartbeat_at + + [TIMESTAMP] -last_heartbeat_at - - [TIMESTAMP] +map_index + + [INTEGER] + NOT NULL -map_index - - [INTEGER] - NOT NULL +max_tries + + [INTEGER] -max_tries - - [INTEGER] +next_kwargs + + [JSONB] -next_kwargs - - [JSON] +next_method + + [VARCHAR(1000)] -next_method - - [VARCHAR(1000)] +operator + + [VARCHAR(1000)] -operator - - [VARCHAR(1000)] +pid + + [INTEGER] -pid - - [INTEGER] +pool + + [VARCHAR(256)] + NOT NULL -pool - - [VARCHAR(256)] - NOT NULL +pool_slots + + [INTEGER] + NOT NULL -pool_slots - - [INTEGER] - NOT NULL +priority_weight + + [INTEGER] -priority_weight - - [INTEGER] +queue + + [VARCHAR(256)] -queue - - [VARCHAR(256)] +queued_by_job_id + + [INTEGER] -queued_by_job_id - - [INTEGER] +queued_dttm + + [TIMESTAMP] -queued_dttm - - [TIMESTAMP] +rendered_map_index + + [VARCHAR(250)] -rendered_map_index - - [VARCHAR(250)] +run_id + + [VARCHAR(250)] + NOT NULL -run_id - - [VARCHAR(250)] - NOT NULL +span_status + + [VARCHAR(250)] + NOT NULL start_date @@ -858,93 +867,93 @@ trigger--task_instance - -0..N + +0..N {0,1} task_reschedule - -task_reschedule + +task_reschedule + +id + + [INTEGER] + NOT NULL + +dag_id + + [VARCHAR(250)] + NOT NULL -id - - [INTEGER] - NOT NULL +duration + + [INTEGER] + NOT NULL -dag_id - - [VARCHAR(250)] - NOT NULL +end_date + + [TIMESTAMP] + NOT NULL -duration - - [INTEGER] - NOT NULL +map_index + + [INTEGER] + NOT NULL -end_date - - [TIMESTAMP] - NOT NULL +reschedule_date + + [TIMESTAMP] + NOT NULL -map_index - - [INTEGER] - NOT NULL +run_id + + [VARCHAR(250)] + NOT NULL -reschedule_date - - [TIMESTAMP] - NOT NULL +start_date + + [TIMESTAMP] + NOT NULL -run_id - - [VARCHAR(250)] - NOT NULL +task_id + + [VARCHAR(250)] + NOT NULL -start_date - - [TIMESTAMP] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -try_number - - [INTEGER] - NOT NULL +try_number + + [INTEGER] + NOT NULL task_instance--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 @@ -984,30 +993,30 @@ task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 @@ -1037,7 +1046,7 @@ keys - [JSON] + [JSONB] length @@ -1047,30 +1056,30 @@ task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 @@ -1120,210 +1129,219 @@ task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance_note - -task_instance_note + +task_instance_note + +dag_id + + [VARCHAR(250)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL -dag_id - - [VARCHAR(250)] - NOT NULL +run_id + + [VARCHAR(250)] + NOT NULL -map_index - - [INTEGER] - NOT NULL +task_id + + [VARCHAR(250)] + NOT NULL -run_id - - [VARCHAR(250)] - NOT NULL +content + + [VARCHAR(1000)] -task_id - - [VARCHAR(250)] - NOT NULL +created_at + + [TIMESTAMP] + NOT NULL -content - - [VARCHAR(1000)] +updated_at + + [TIMESTAMP] + NOT NULL -created_at - - [TIMESTAMP] - NOT NULL - -updated_at - - [TIMESTAMP] - NOT NULL - -user_id - - [VARCHAR(128)] +user_id + + [VARCHAR(128)] task_instance--task_instance_note - -0..N -1 + +0..N +1 task_instance--task_instance_note - -0..N -1 + +0..N +1 task_instance--task_instance_note - -0..N -1 + +0..N +1 task_instance--task_instance_note - -0..N -1 + +0..N +1 task_instance_history - -task_instance_history + +task_instance_history + +id + + [INTEGER] + NOT NULL + +context_carrier + + [JSONB] -id - - [INTEGER] - NOT NULL +custom_operator_name + + [VARCHAR(1000)] -custom_operator_name - - [VARCHAR(1000)] +dag_id + + [VARCHAR(250)] + NOT NULL -dag_id - - [VARCHAR(250)] - NOT NULL +dag_version_id + + [UUID] -dag_version_id - - [UUID] +duration + + [DOUBLE_PRECISION] -duration - - [DOUBLE_PRECISION] +end_date + + [TIMESTAMP] -end_date - - [TIMESTAMP] +executor + + [VARCHAR(1000)] -executor - - [VARCHAR(1000)] +executor_config + + [BYTEA] -executor_config - - [BYTEA] +external_executor_id + + [VARCHAR(250)] -external_executor_id - - [VARCHAR(250)] +hostname + + [VARCHAR(1000)] -hostname - - [VARCHAR(1000)] +map_index + + [INTEGER] + NOT NULL -map_index - - [INTEGER] - NOT NULL +max_tries + + [INTEGER] -max_tries - - [INTEGER] +next_kwargs + + [JSONB] -next_kwargs - - [JSON] +next_method + + [VARCHAR(1000)] -next_method - - [VARCHAR(1000)] +operator + + [VARCHAR(1000)] -operator - - [VARCHAR(1000)] +pid + + [INTEGER] -pid - - [INTEGER] +pool + + [VARCHAR(256)] + NOT NULL -pool - - [VARCHAR(256)] - NOT NULL +pool_slots + + [INTEGER] + NOT NULL -pool_slots - - [INTEGER] - NOT NULL +priority_weight + + [INTEGER] -priority_weight - - [INTEGER] +queue + + [VARCHAR(256)] -queue - - [VARCHAR(256)] +queued_by_job_id + + [INTEGER] -queued_by_job_id - - [INTEGER] +queued_dttm + + [TIMESTAMP] -queued_dttm - - [TIMESTAMP] +rendered_map_index + + [VARCHAR(250)] -rendered_map_index - - [VARCHAR(250)] +run_id + + [VARCHAR(250)] + NOT NULL -run_id - - [VARCHAR(250)] - NOT NULL +span_status + + [VARCHAR(250)] + NOT NULL start_date @@ -1366,30 +1384,30 @@ task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 @@ -1731,369 +1749,382 @@ deadline - -deadline - -id - - [UUID] - NOT NULL - -callback - - [VARCHAR(500)] - NOT NULL - -callback_kwargs - - [JSON] - -dag_id - - [VARCHAR(250)] - -dagrun_id - - [INTEGER] - -deadline - - [TIMESTAMP] - NOT NULL + +deadline + +id + + [UUID] + NOT NULL + +callback + + [VARCHAR(500)] + NOT NULL + +callback_kwargs + + [JSON] + +dag_id + + [VARCHAR(250)] + +dagrun_id + + [INTEGER] + +deadline + + [TIMESTAMP] + NOT NULL dag--deadline - -0..N + +0..N {0,1} dag_version--task_instance - -0..N -{0,1} + +0..N +{0,1} dag_run - -dag_run - -id - - [INTEGER] - NOT NULL - -backfill_id - - [INTEGER] - -bundle_version - - [VARCHAR(250)] - -clear_number - - [INTEGER] - NOT NULL - -conf - - [JSONB] - -creating_job_id - - [INTEGER] - -dag_id - - [VARCHAR(250)] - NOT NULL - -dag_version_id - - [UUID] - -data_interval_end - - [TIMESTAMP] - -data_interval_start - - [TIMESTAMP] - -end_date - - [TIMESTAMP] - -external_trigger - - [BOOLEAN] - -last_scheduling_decision - - [TIMESTAMP] - -log_template_id - - [INTEGER] - -logical_date - - [TIMESTAMP] - NOT NULL - -queued_at - - [TIMESTAMP] - -run_id - - [VARCHAR(250)] - NOT NULL - -run_type - - [VARCHAR(50)] - NOT NULL - -start_date - - [TIMESTAMP] - -state - - [VARCHAR(50)] - -triggered_by - - [VARCHAR(50)] - -updated_at - - [TIMESTAMP] + +dag_run + +id + + [INTEGER] + NOT NULL + +backfill_id + + [INTEGER] + +bundle_version + + [VARCHAR(250)] + +clear_number + + [INTEGER] + NOT NULL + +conf + + [JSONB] + +context_carrier + + [JSONB] + +creating_job_id + + [INTEGER] + +dag_id + + [VARCHAR(250)] + NOT NULL + +dag_version_id + + [UUID] + +data_interval_end + + [TIMESTAMP] + +data_interval_start + + [TIMESTAMP] + +end_date + + [TIMESTAMP] + +external_trigger + + [BOOLEAN] + +last_scheduling_decision + + [TIMESTAMP] + +log_template_id + + [INTEGER] + +logical_date + + [TIMESTAMP] + NOT NULL + +queued_at + + [TIMESTAMP] + +run_id + + [VARCHAR(250)] + NOT NULL + +run_type + + [VARCHAR(50)] + NOT NULL + +scheduled_by_job_id + + [INTEGER] + +span_status + + [VARCHAR(250)] + NOT NULL + +start_date + + [TIMESTAMP] + +state + + [VARCHAR(50)] + +triggered_by + + [VARCHAR(50)] + +updated_at + + [TIMESTAMP] dag_version--dag_run - -0..N -{0,1} + +0..N +{0,1} dag_code - -dag_code - -id - - [UUID] - NOT NULL - -dag_id - - [VARCHAR(250)] - NOT NULL - -dag_version_id - - [UUID] - NOT NULL - -fileloc - - [VARCHAR(2000)] - NOT NULL - -last_updated - - [TIMESTAMP] - NOT NULL - -source_code - - [TEXT] - NOT NULL - -source_code_hash - - [VARCHAR(32)] - NOT NULL + +dag_code + +id + + [UUID] + NOT NULL + +dag_id + + [VARCHAR(250)] + NOT NULL + +dag_version_id + + [UUID] + NOT NULL + +fileloc + + [VARCHAR(2000)] + NOT NULL + +last_updated + + [TIMESTAMP] + NOT NULL + +source_code + + [TEXT] + NOT NULL + +source_code_hash + + [VARCHAR(32)] + NOT NULL dag_version--dag_code - -0..N -1 + +0..N +1 serialized_dag - -serialized_dag - -id - - [UUID] - NOT NULL - -created_at - - [TIMESTAMP] - NOT NULL - -dag_hash - - [VARCHAR(32)] - NOT NULL - -dag_id - - [VARCHAR(250)] - NOT NULL - -dag_version_id - - [UUID] - NOT NULL - -data - - [JSON] - -data_compressed - - [BYTEA] + +serialized_dag + +id + + [UUID] + NOT NULL + +created_at + + [TIMESTAMP] + NOT NULL + +dag_hash + + [VARCHAR(32)] + NOT NULL + +dag_id + + [VARCHAR(250)] + NOT NULL + +dag_version_id + + [UUID] + NOT NULL + +data + + [JSON] + +data_compressed + + [BYTEA] dag_version--serialized_dag - -0..N -1 + +0..N +1 dag_run--dagrun_asset_event - -0..N -1 + +0..N +1 dag_run--task_instance - -0..N -1 + +0..N +1 dag_run--task_instance - -0..N -1 + +0..N +1 dag_run--deadline - -0..N -{0,1} + +0..N +{0,1} backfill_dag_run - -backfill_dag_run - -id - - [INTEGER] - NOT NULL - -backfill_id - - [INTEGER] - NOT NULL - -dag_run_id - - [INTEGER] - -exception_reason - - [VARCHAR(250)] - -logical_date - - [TIMESTAMP] - NOT NULL - -sort_ordinal - - [INTEGER] - NOT NULL + +backfill_dag_run + +id + + [INTEGER] + NOT NULL + +backfill_id + + [INTEGER] + NOT NULL + +dag_run_id + + [INTEGER] + +exception_reason + + [VARCHAR(250)] + +logical_date + + [TIMESTAMP] + NOT NULL + +sort_ordinal + + [INTEGER] + NOT NULL dag_run--backfill_dag_run - -0..N -{0,1} + +0..N +{0,1} dag_run_note - -dag_run_note - -dag_run_id - - [INTEGER] - NOT NULL - -content - - [VARCHAR(1000)] - -created_at - - [TIMESTAMP] - NOT NULL - -updated_at - - [TIMESTAMP] - NOT NULL - -user_id - - [VARCHAR(128)] + +dag_run_note + +dag_run_id + + [INTEGER] + NOT NULL + +content + + [VARCHAR(1000)] + +created_at + + [TIMESTAMP] + NOT NULL + +updated_at + + [TIMESTAMP] + NOT NULL + +user_id + + [VARCHAR(128)] dag_run--dag_run_note - -1 -1 + +1 +1 dag_run--task_reschedule - -0..N -1 + +0..N +1 dag_run--task_reschedule - -0..N -1 + +0..N +1 @@ -2124,9 +2155,9 @@ log_template--dag_run - -0..N -{0,1} + +0..N +{0,1} @@ -2190,16 +2221,16 @@ backfill--dag_run - -0..N -{0,1} + +0..N +{0,1} backfill--backfill_dag_run - -0..N -1 + +0..N +1 diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index 62013ff8f799c..b04e4c48314e0 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``e39a26ac59f6`` (head) | ``38770795785f`` | ``3.0.0`` | remove pickled data from dagrun table. | +| ``70206b887482`` (head) | ``e39a26ac59f6`` | ``3.0.0`` | Add new otel span fields. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``e39a26ac59f6`` | ``38770795785f`` | ``3.0.0`` | remove pickled data from dagrun table. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``38770795785f`` | ``5c9c0231baa2`` | ``3.0.0`` | Add asset reference models. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/scripts/ci/docker-compose/integration-otel.yml b/scripts/ci/docker-compose/integration-otel.yml index 7a635c17c7d22..d2cd759ad7f59 100644 --- a/scripts/ci/docker-compose/integration-otel.yml +++ b/scripts/ci/docker-compose/integration-otel.yml @@ -67,6 +67,7 @@ services: airflow: environment: + - INTEGRATION_OTEL=true - AIRFLOW__METRICS__OTEL_ON=True - AIRFLOW__METRICS__OTEL_HOST=breeze-otel-collector - AIRFLOW__METRICS__OTEL_PORT=4318 diff --git a/tests/api_fastapi/common/test_exceptions.py b/tests/api_fastapi/common/test_exceptions.py index 6751aff20c725..0ea8463e6baca 100644 --- a/tests/api_fastapi/common/test_exceptions.py +++ b/tests/api_fastapi/common/test_exceptions.py @@ -186,7 +186,7 @@ def test_handle_single_column_unique_constraint_error(self, session, table, expe status_code=status.HTTP_409_CONFLICT, detail={ "reason": "Unique constraint violation", - "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, (SELECT max(log_template.id) AS max_1 \nFROM log_template), ?, ?, ?, ?, ?)", + "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version, scheduled_by_job_id, context_carrier, span_status) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, (SELECT max(log_template.id) AS max_1 \nFROM log_template), ?, ?, ?, ?, ?, ?, ?, ?)", "orig_error": "UNIQUE constraint failed: dag_run.dag_id, dag_run.run_id", }, ), @@ -194,7 +194,7 @@ def test_handle_single_column_unique_constraint_error(self, session, table, expe status_code=status.HTTP_409_CONFLICT, detail={ "reason": "Unique constraint violation", - "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %s, %s, %s, %s, %s)", + "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version, scheduled_by_job_id, context_carrier, span_status) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %s, %s, %s, %s, %s, %s, %s, %s)", "orig_error": "(1062, \"Duplicate entry 'test_dag_id-test_run_id' for key 'dag_run.dag_run_dag_id_run_id_key'\")", }, ), @@ -202,7 +202,7 @@ def test_handle_single_column_unique_constraint_error(self, session, table, expe status_code=status.HTTP_409_CONFLICT, detail={ "reason": "Unique constraint violation", - "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version) VALUES (%(dag_id)s, %(queued_at)s, %(logical_date)s, %(start_date)s, %(end_date)s, %(state)s, %(run_id)s, %(creating_job_id)s, %(external_trigger)s, %(run_type)s, %(triggered_by)s, %(conf)s, %(data_interval_start)s, %(data_interval_end)s, %(last_scheduling_decision)s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %(updated_at)s, %(clear_number)s, %(backfill_id)s, %(dag_version_id)s, %(bundle_version)s) RETURNING dag_run.id", + "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version, scheduled_by_job_id, context_carrier, span_status) VALUES (%(dag_id)s, %(queued_at)s, %(logical_date)s, %(start_date)s, %(end_date)s, %(state)s, %(run_id)s, %(creating_job_id)s, %(external_trigger)s, %(run_type)s, %(triggered_by)s, %(conf)s, %(data_interval_start)s, %(data_interval_end)s, %(last_scheduling_decision)s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %(updated_at)s, %(clear_number)s, %(backfill_id)s, %(dag_version_id)s, %(bundle_version)s, %(scheduled_by_job_id)s, %(context_carrier)s, %(span_status)s) RETURNING dag_run.id", "orig_error": 'duplicate key value violates unique constraint "dag_run_dag_id_run_id_key"\nDETAIL: Key (dag_id, run_id)=(test_dag_id, test_run_id) already exists.\n', }, ), diff --git a/tests/core/test_otel_tracer.py b/tests/core/test_otel_tracer.py index 666e78e2dc84f..6de8e933b990f 100644 --- a/tests/core/test_otel_tracer.py +++ b/tests/core/test_otel_tracer.py @@ -19,13 +19,17 @@ import json import logging from datetime import datetime -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest +from opentelemetry.sdk import util from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from airflow.traces import TRACEPARENT, TRACESTATE, otel_tracer, utils -from airflow.traces.tracer import Trace +from airflow.configuration import conf +from airflow.traces import otel_tracer +from airflow.traces.otel_tracer import OtelTrace +from airflow.traces.tracer import EmptyTrace, Trace +from airflow.utils.dates import datetime_to_nano from tests_common.test_utils.config import env_vars @@ -36,6 +40,26 @@ def name(): class TestOtelTrace: + def test_get_otel_tracer_from_trace_metaclass(self): + """Test that `Trace.some_method()`, uses an `OtelTrace` instance when otel is configured.""" + conf.add_section("traces") + conf.set("traces", "otel_on", "True") + conf.set("traces", "otel_debugging_on", "True") + + tracer = otel_tracer.get_otel_tracer(Trace) + assert tracer.use_simple_processor is False + + assert isinstance(Trace.factory(), EmptyTrace) + + Trace.configure_factory() + assert isinstance(Trace.factory(), OtelTrace) + + task_tracer = otel_tracer.get_otel_tracer_for_task(Trace) + assert task_tracer.use_simple_processor is True + + task_tracer.get_otel_tracer_provider() + assert task_tracer.use_simple_processor is True + @patch("opentelemetry.sdk.trace.export.ConsoleSpanExporter") @patch("airflow.traces.otel_tracer.conf") def test_tracer(self, conf_a, exporter): @@ -90,37 +114,28 @@ def test_dag_tracer(self, conf_a, exporter): exporter.return_value = in_mem_exporter now = datetime.now() - dag_run = MagicMock() - - parent_trace_id = "0af7651916cd43dd8448eb211c80319c" - parent_span_id = "b7ad6b7169203331" - - dag_run.conf = { - TRACEPARENT: f"00-{parent_trace_id}-{parent_span_id}-01", - TRACESTATE: "key1=val1,key2=val2", - } - dag_run.dag_id = "dag_id" - dag_run.run_id = "run_id" - dag_run.dag_hash = "hashcode" - dag_run.run_type = "manual" - dag_run.queued_at = now - dag_run.start_date = now tracer = otel_tracer.get_otel_tracer(Trace) - with tracer.start_span_from_dagrun(dagrun=dag_run) as s1: + with tracer.start_root_span(span_name="span1", start_time=now) as s1: with tracer.start_span(span_name="span2") as s2: s2.set_attribute("attr2", "val2") + span2 = json.loads(s2.to_json()) span1 = json.loads(s1.to_json()) - assert span1["context"]["trace_id"] != f"0x{parent_trace_id}" - assert span1["links"][1]["context"]["trace_id"] == f"0x{parent_trace_id}" - assert span1["links"][1]["context"]["span_id"] == f"0x{parent_span_id}" + + # The otel sdk, accepts an int for the start_time, and converts it to an iso string, + # using `util.ns_to_iso_str()`. + nano_time = datetime_to_nano(now) + assert span1["start_time"] == util.ns_to_iso_str(nano_time) + # Same trace_id + assert span1["context"]["trace_id"] == span2["context"]["trace_id"] + assert span1["context"]["span_id"] == span2["parent_id"] @patch("opentelemetry.sdk.trace.export.ConsoleSpanExporter") @patch("airflow.traces.otel_tracer.conf") - def test_traskinstance_tracer(self, conf_a, exporter): + def test_context_propagation(self, conf_a, exporter): # necessary to speed up the span to be emitted with env_vars({"OTEL_BSP_SCHEDULE_DELAY": "1"}): - log = logging.getLogger("TestOtelTrace.test_taskinstance_tracer") + log = logging.getLogger("TestOtelTrace.test_context_propagation") log.setLevel(logging.DEBUG) conf_a.get.return_value = "abc" conf_a.getint.return_value = 123 @@ -131,27 +146,38 @@ def test_traskinstance_tracer(self, conf_a, exporter): in_mem_exporter = InMemorySpanExporter() exporter.return_value = in_mem_exporter - now = datetime.now() - # magic mock - ti = MagicMock() - ti.dag_run.conf = {} - ti.task_id = "task_id" - ti.start_date = now - ti.dag_run.dag_id = "dag_id" - ti.dag_run.run_id = "run_id" - ti.dag_run.dag_hash = "hashcode" - ti.dag_run.run_type = "manual" - ti.dag_run.queued_at = now - ti.dag_run.start_date = now + # Method that represents another service which is + # - getting the carrier + # - extracting the context + # - using the context to create a new span + # The new span should be associated with the span from the injected context carrier. + def _task_func(otel_tr, carrier): + parent_context = otel_tr.extract(carrier) + + with otel_tr.start_child_span(span_name="sub_span", parent_context=parent_context) as span: + span.set_attribute("attr2", "val2") + json_span = json.loads(span.to_json()) + return json_span tracer = otel_tracer.get_otel_tracer(Trace) - with tracer.start_span_from_taskinstance(ti=ti, span_name="mydag") as s1: - with tracer.start_span(span_name="span2") as s2: - s2.set_attribute("attr2", "val2") - span2 = json.loads(s2.to_json()) - span1 = json.loads(s1.to_json()) - log.info(span1) - log.info(span2) - assert span1["context"]["trace_id"] == f"0x{utils.gen_trace_id(ti.dag_run)}" - assert span1["context"]["span_id"] == f"0x{utils.gen_span_id(ti)}" + root_span = tracer.start_root_span(span_name="root_span", start_as_current=False) + # The context is available, it can be injected into the carrier. + context_carrier = tracer.inject() + + # Some function that uses the carrier to create a new span. + json_span2 = _task_func(otel_tr=tracer, carrier=context_carrier) + + json_span1 = json.loads(root_span.to_json()) + # Manually end the span. + root_span.end() + + log.info(json_span1) + log.info(json_span2) + # Verify that span1 is a root span. + assert json_span1["parent_id"] is None + # Check span2 parent_id to verify that it's a child of span1. + assert json_span2["parent_id"] == json_span1["context"]["span_id"] + # The trace_id and the span_id are randomly generated by the otel sdk. + # Both spans should belong to the same trace. + assert json_span1["context"]["trace_id"] == json_span2["context"]["trace_id"] diff --git a/tests/integration/executors/test_celery_executor.py b/tests/integration/executors/test_celery_executor.py index 0f9f0b45ae9c1..e43bd88a06481 100644 --- a/tests/integration/executors/test_celery_executor.py +++ b/tests/integration/executors/test_celery_executor.py @@ -218,7 +218,7 @@ def fake_execute_command(): ) when = datetime.now() value_tuple = ( - "command", + ["command"], 1, None, SimpleTaskInstance.from_ti(ti=TaskInstance(task=task, run_id=None)), @@ -256,7 +256,7 @@ def test_retry_on_error_sending_task(self, caplog): ) when = datetime.now() value_tuple = ( - "command", + ["command"], 1, None, SimpleTaskInstance.from_ti(ti=TaskInstance(task=task, run_id=None)), diff --git a/tests/integration/otel/__init__.py b/tests/integration/otel/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/integration/otel/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/integration/otel/dags/__init__.py b/tests/integration/otel/dags/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/integration/otel/dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/integration/otel/dags/otel_test_dag.py b/tests/integration/otel/dags/otel_test_dag.py new file mode 100644 index 0000000000000..92b03dcd1f0cd --- /dev/null +++ b/tests/integration/otel/dags/otel_test_dag.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from datetime import datetime + +from opentelemetry import trace + +from airflow import DAG +from airflow.providers.standard.operators.python import PythonOperator +from airflow.traces import otel_tracer +from airflow.traces.tracer import Trace + +logger = logging.getLogger("airflow.otel_test_dag") + +args = { + "owner": "airflow", + "start_date": datetime(2024, 9, 1), + "retries": 0, +} + +# DAG definition. +with DAG( + "otel_test_dag", + default_args=args, + schedule=None, + catchup=False, +) as dag: + # Tasks. + def task1_func(**dag_context): + logger.info("Starting Task_1.") + + ti = dag_context["ti"] + context_carrier = ti.context_carrier + + otel_task_tracer = otel_tracer.get_otel_tracer_for_task(Trace) + tracer_provider = otel_task_tracer.get_otel_tracer_provider() + + if context_carrier is not None: + logger.info("Found ti.context_carrier: %s.", context_carrier) + logger.info("Extracting the span context from the context_carrier.") + parent_context = otel_task_tracer.extract(context_carrier) + with otel_task_tracer.start_child_span( + span_name=f"{ti.task_id}_sub_span1", + parent_context=parent_context, + component="dag", + ) as s1: + s1.set_attribute("attr1", "val1") + logger.info("From task sub_span1.") + + with otel_task_tracer.start_child_span(f"{ti.task_id}_sub_span2") as s2: + s2.set_attribute("attr2", "val2") + logger.info("From task sub_span2.") + + tracer = trace.get_tracer("trace_test.tracer", tracer_provider=tracer_provider) + with tracer.start_as_current_span(name=f"{ti.task_id}_sub_span3") as s3: + s3.set_attribute("attr3", "val3") + logger.info("From task sub_span3.") + + with otel_task_tracer.start_child_span( + span_name=f"{ti.task_id}_sub_span4", + parent_context=parent_context, + component="dag", + ) as s4: + s4.set_attribute("attr4", "val4") + logger.info("From task sub_span4.") + + logger.info("Task_1 finished.") + + def task2_func(): + logger.info("Starting Task_2.") + for i in range(3): + logger.info("Task_2, iteration '%d'.", i) + logger.info("Task_2 finished.") + + # Task operators. + t1 = PythonOperator( + task_id="task_1", + python_callable=task1_func, + ) + + t2 = PythonOperator( + task_id="task_2", + python_callable=task2_func, + ) + + # Dependencies. + t1 >> t2 diff --git a/tests/integration/otel/dags/otel_test_dag_with_pause.py b/tests/integration/otel/dags/otel_test_dag_with_pause.py new file mode 100644 index 0000000000000..1c46f7f7cc16a --- /dev/null +++ b/tests/integration/otel/dags/otel_test_dag_with_pause.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +import os +from datetime import datetime + +from opentelemetry import trace +from sqlalchemy import select + +from airflow import DAG +from airflow.models import TaskInstance +from airflow.providers.standard.operators.python import PythonOperator +from airflow.traces import otel_tracer +from airflow.traces.tracer import Trace +from airflow.utils.session import create_session + +logger = logging.getLogger("airflow.otel_test_dag_with_pause") + +args = { + "owner": "airflow", + "start_date": datetime(2024, 9, 2), + "retries": 0, +} + +# DAG definition. +with DAG( + "otel_test_dag_with_pause", + default_args=args, + schedule=None, + catchup=False, +) as dag: + # Tasks. + def task1_func(**dag_context): + logger.info("Starting Task_1.") + + ti = dag_context["ti"] + context_carrier = ti.context_carrier + + otel_task_tracer = otel_tracer.get_otel_tracer_for_task(Trace) + tracer_provider = otel_task_tracer.get_otel_tracer_provider() + + if context_carrier is not None: + logger.info("Found ti.context_carrier: %s.", context_carrier) + logger.info("Extracting the span context from the context_carrier.") + + # If the task takes too long to execute, then the ti should be read from the db + # to make sure that the initial context_carrier is the same. + with create_session() as session: + session_ti: TaskInstance = session.scalars( + select(TaskInstance).where( + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + ) + ).one() + context_carrier = session_ti.context_carrier + + parent_context = Trace.extract(context_carrier) + with otel_task_tracer.start_child_span( + span_name=f"{ti.task_id}_sub_span1", + parent_context=parent_context, + component="dag", + ) as s1: + s1.set_attribute("attr1", "val1") + logger.info("From task sub_span1.") + + with otel_task_tracer.start_child_span(f"{ti.task_id}_sub_span2") as s2: + s2.set_attribute("attr2", "val2") + logger.info("From task sub_span2.") + + tracer = trace.get_tracer("trace_test.tracer", tracer_provider=tracer_provider) + with tracer.start_as_current_span(name=f"{ti.task_id}_sub_span3") as s3: + s3.set_attribute("attr3", "val3") + logger.info("From task sub_span3.") + + with create_session() as session: + session_ti: TaskInstance = session.scalars( + select(TaskInstance).where( + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + ) + ).one() + context_carrier = session_ti.context_carrier + parent_context = Trace.extract(context_carrier) + + with otel_task_tracer.start_child_span( + span_name=f"{ti.task_id}_sub_span4", + parent_context=parent_context, + component="dag", + ) as s4: + s4.set_attribute("attr4", "val4") + logger.info("From task sub_span4.") + + logger.info("Task_1 finished.") + + def paused_task_func(): + logger.info("Starting Paused_task.") + + dag_folder = os.path.dirname(os.path.abspath(__file__)) + control_file = os.path.join(dag_folder, "dag_control.txt") + + # Create the file and write 'pause' to it. + with open(control_file, "w") as file: + file.write("pause") + + # Pause execution until the word 'pause' is replaced on the file. + while True: + # If there is an exception, then writing to the file failed. Let it exit. + file_contents = None + with open(control_file) as file: + file_contents = file.read() + + if "pause" in file_contents: + logger.info("Task has been paused.") + continue + else: + logger.info("Resuming task execution.") + # Break the loop and finish with the task execution. + break + + # Cleanup the control file. + if os.path.exists(control_file): + os.remove(control_file) + print("Control file has been cleaned up.") + + logger.info("Paused_task finished.") + + def task2_func(): + logger.info("Starting Task_2.") + for i in range(3): + logger.info("Task_2, iteration '%d'.", i) + logger.info("Task_2 finished.") + + # Task operators. + t1 = PythonOperator( + task_id="task_1", + python_callable=task1_func, + ) + + pause = PythonOperator( + task_id="paused_task", + python_callable=paused_task_func, + ) + + t2 = PythonOperator( + task_id="task_2", + python_callable=task2_func, + ) + + # Dependencies. + t1 >> pause >> t2 diff --git a/tests/integration/otel/dags/otel_test_dag_with_pause_in_task.py b/tests/integration/otel/dags/otel_test_dag_with_pause_in_task.py new file mode 100644 index 0000000000000..9119634c0d3fc --- /dev/null +++ b/tests/integration/otel/dags/otel_test_dag_with_pause_in_task.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +import os +from datetime import datetime + +from opentelemetry import trace +from sqlalchemy import select + +from airflow import DAG +from airflow.models import TaskInstance +from airflow.providers.standard.operators.python import PythonOperator +from airflow.traces import otel_tracer +from airflow.traces.tracer import Trace +from airflow.utils.session import create_session + +logger = logging.getLogger("airflow.otel_test_dag_with_pause_in_task") + +args = { + "owner": "airflow", + "start_date": datetime(2024, 9, 2), + "retries": 0, +} + +# DAG definition. +with DAG( + "otel_test_dag_with_pause_in_task", + default_args=args, + schedule=None, + catchup=False, +) as dag: + # Tasks. + def task1_func(**dag_context): + logger.info("Starting Task_1.") + + ti = dag_context["ti"] + context_carrier = ti.context_carrier + + dag_folder = os.path.dirname(os.path.abspath(__file__)) + control_file = os.path.join(dag_folder, "dag_control.txt") + + # Create the file and write 'pause' to it. + with open(control_file, "w") as file: + file.write("pause") + + # Pause execution until the word 'pause' is replaced on the file. + while True: + # If there is an exception, then writing to the file failed. Let it exit. + file_contents = None + with open(control_file) as file: + file_contents = file.read() + + if "pause" in file_contents: + logger.info("Task has been paused.") + continue + else: + logger.info("Resuming task execution.") + # Break the loop and finish with the task execution. + break + + otel_task_tracer = otel_tracer.get_otel_tracer_for_task(Trace) + tracer_provider = otel_task_tracer.get_otel_tracer_provider() + + if context_carrier is not None: + logger.info("Found ti.context_carrier: %s.", context_carrier) + logger.info("Extracting the span context from the context_carrier.") + + # If the task takes too long to execute, then the ti should be read from the db + # to make sure that the initial context_carrier is the same. + with create_session() as session: + session_ti: TaskInstance = session.scalars( + select(TaskInstance).where( + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + ) + ).one() + context_carrier = session_ti.context_carrier + + parent_context = Trace.extract(context_carrier) + with otel_task_tracer.start_child_span( + span_name=f"{ti.task_id}_sub_span1", + parent_context=parent_context, + component="dag", + ) as s1: + s1.set_attribute("attr1", "val1") + logger.info("From task sub_span1.") + + with otel_task_tracer.start_child_span(f"{ti.task_id}_sub_span2") as s2: + s2.set_attribute("attr2", "val2") + logger.info("From task sub_span2.") + + tracer = trace.get_tracer("trace_test.tracer", tracer_provider=tracer_provider) + with tracer.start_as_current_span(name=f"{ti.task_id}_sub_span3") as s3: + s3.set_attribute("attr3", "val3") + logger.info("From task sub_span3.") + + with create_session() as session: + session_ti: TaskInstance = session.scalars( + select(TaskInstance).where( + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + ) + ).one() + context_carrier = session_ti.context_carrier + parent_context = Trace.extract(context_carrier) + + with otel_task_tracer.start_child_span( + span_name=f"{ti.task_id}_sub_span4", + parent_context=parent_context, + component="dag", + ) as s4: + s4.set_attribute("attr4", "val4") + logger.info("From task sub_span4.") + + # Cleanup the control file. + if os.path.exists(control_file): + os.remove(control_file) + print("Control file has been cleaned up.") + + logger.info("Task_1 finished.") + + def task2_func(): + logger.info("Starting Task_2.") + for i in range(3): + logger.info("Task_2, iteration '%d'.", i) + logger.info("Task_2 finished.") + + # Task operators. + t1 = PythonOperator( + task_id="task_1", + python_callable=task1_func, + ) + + t2 = PythonOperator( + task_id="task_2", + python_callable=task2_func, + ) + + # Dependencies. + t1 >> t2 diff --git a/tests/integration/otel/test_otel.py b/tests/integration/otel/test_otel.py new file mode 100644 index 0000000000000..b515f03266a94 --- /dev/null +++ b/tests/integration/otel/test_otel.py @@ -0,0 +1,1225 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +import os +import signal +import subprocess +import time +from typing import Any + +import pendulum +import pytest + +from airflow.executors import executor_loader +from airflow.executors.executor_utils import ExecutorName +from airflow.models import DAG, DagBag, DagRun +from airflow.models.serialized_dag import SerializedDagModel +from airflow.models.taskinstance import TaskInstance +from airflow.utils.session import create_session +from airflow.utils.span_status import SpanStatus +from airflow.utils.state import State + +from tests.integration.otel.test_utils import ( + assert_parent_children_spans, + assert_parent_children_spans_for_non_root, + assert_span_name_belongs_to_root_span, + assert_span_not_in_children_spans, + dump_airflow_metadata_db, + extract_spans_from_output, + get_parent_child_dict, +) + +log = logging.getLogger("integration.otel.test_otel") + + +def unpause_trigger_dag_and_get_run_id(dag_id: str) -> str: + unpause_command = ["airflow", "dags", "unpause", dag_id] + + # Unpause the dag using the cli. + subprocess.run(unpause_command, check=True, env=os.environ.copy()) + + execution_date = pendulum.now("UTC") + run_id = f"manual__{execution_date.isoformat()}" + + trigger_command = [ + "airflow", + "dags", + "trigger", + dag_id, + "--run-id", + run_id, + "--exec-date", + execution_date.isoformat(), + ] + + # Trigger the dag using the cli. + subprocess.run(trigger_command, check=True, env=os.environ.copy()) + + return run_id + + +def wait_for_dag_run_and_check_span_status(dag_id: str, run_id: str, max_wait_time: int, span_status: str): + # max_wait_time, is the timeout for the DAG run to complete. The value is in seconds. + start_time = time.time() + + while time.time() - start_time < max_wait_time: + with create_session() as session: + dag_run = ( + session.query(DagRun) + .filter( + DagRun.dag_id == dag_id, + DagRun.run_id == run_id, + ) + .first() + ) + + if dag_run is None: + time.sleep(5) + continue + + dag_run_state = dag_run.state + log.debug("DAG Run state: %s.", dag_run_state) + + dag_run_span_status = dag_run.span_status + log.debug("DAG Run span status: %s.", dag_run_span_status) + + if dag_run_state in [State.SUCCESS, State.FAILED]: + break + + assert ( + dag_run_state == State.SUCCESS + ), f"Dag run did not complete successfully. Final state: {dag_run_state}." + + assert ( + dag_run_span_status == span_status + ), f"Dag run span status isn't {span_status} as expected.Actual status: {dag_run_span_status}." + + +def check_dag_run_state_and_span_status(dag_id: str, run_id: str, state: str, span_status: str): + with create_session() as session: + dag_run = ( + session.query(DagRun) + .filter( + DagRun.dag_id == dag_id, + DagRun.run_id == run_id, + ) + .first() + ) + + assert dag_run.state == state, f"Dag Run state isn't {state}. State: {dag_run.state}" + assert ( + dag_run.span_status == span_status + ), f"Dag Run span_status isn't {span_status}. Span_status: {dag_run.span_status}" + + +def check_ti_state_and_span_status(task_id: str, run_id: str, state: str, span_status: str): + with create_session() as session: + ti = ( + session.query(TaskInstance) + .filter( + TaskInstance.task_id == task_id, + TaskInstance.run_id == run_id, + ) + .first() + ) + + assert ti.state == state, f"Task instance state isn't {state}. State: {ti.state}" + assert ( + ti.span_status == span_status + ), f"Task instance span_status isn't {span_status}. Span_status: {ti.span_status}" + + +def check_spans_with_continuance(output: str, dag: DAG, continuance_for_t1: bool = True): + # Get a list of lines from the captured output. + output_lines = output.splitlines() + + # Filter the output, create a json obj for each span and then store them into dictionaries. + # One dictionary with only the root spans, and one with all the captured spans (root and otherwise). + root_span_dict, span_dict = extract_spans_from_output(output_lines) + # Generate a dictionary with parent child relationships. + # This is done by comparing the span_id of each root span with the parent_id of each non-root span. + parent_child_dict = get_parent_child_dict(root_span_dict, span_dict) + + # The span hierarchy for dag 'otel_test_dag_with_pause_in_task' is + # dag span + # |_ task_1 span + # |_ scheduler_exited span + # |_ new_scheduler span + # |_ dag span (continued) + # |_ task_1 span (continued) + # |_ sub_span_1 + # |_ sub_span_2 + # |_ sub_span_3 + # |_ sub_span_4 + # |_ task_2 span + # + # If there is no continuance for task_1, then the span hierarchy is + # dag span + # |_ task_1 span + # |_ sub_span_1 + # |_ sub_span_2 + # |_ sub_span_3 + # |_ sub_span_4 + # |_ scheduler_exited span + # |_ new_scheduler span + # |_ dag span (continued) + # |_ task_2 span + + dag_id = dag.dag_id + + task_instance_ids = dag.task_ids + task1_id = task_instance_ids[0] + task2_id = task_instance_ids[1] + + dag_root_span_name = f"{dag_id}" + + dag_root_span_children_names = [ + f"{task1_id}", + "current_scheduler_exited", + "new_scheduler", + f"{dag_id}_continued", + ] + + if continuance_for_t1: + dag_continued_span_children_names = [ + f"{task1_id}_continued", + f"{task2_id}", + ] + else: + dag_continued_span_children_names = [ + f"{task2_id}", + ] + + task1_span_children_names = [ + f"{task1_id}_sub_span1", + f"{task1_id}_sub_span4", + ] + + # Single element lists. + task1_sub_span1_children_span_names = [f"{task1_id}_sub_span2"] + task1_sub_span2_children_span_names = [f"{task1_id}_sub_span3"] + + assert_span_name_belongs_to_root_span( + root_span_dict=root_span_dict, span_name=dag_root_span_name, should_succeed=True + ) + + # Check direct children of the root span. + assert_parent_children_spans( + parent_child_dict=parent_child_dict, + root_span_dict=root_span_dict, + parent_name=dag_root_span_name, + children_names=dag_root_span_children_names, + ) + + # Use a span name that exists, but it's not a direct child. + assert_span_not_in_children_spans( + parent_child_dict=parent_child_dict, + root_span_dict=root_span_dict, + span_dict=span_dict, + parent_name=dag_root_span_name, + child_name=f"{task1_id}_continued", + span_exists=True, + ) + + # Use a span name that doesn't exist at all. + assert_span_not_in_children_spans( + parent_child_dict=parent_child_dict, + root_span_dict=root_span_dict, + span_dict=span_dict, + parent_name=dag_root_span_name, + child_name=f"{task1_id}_non_existent", + span_exists=False, + ) + + # Check children of the continued dag span. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{dag_id}_continued", + children_names=dag_continued_span_children_names, + ) + + if continuance_for_t1: + # Check children of the continued task1 span. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{task1_id}_continued", + children_names=task1_span_children_names, + ) + else: + # Check children of the task1 span. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{task1_id}", + children_names=task1_span_children_names, + ) + + # Check children of task1 sub span1. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{task1_id}_sub_span1", + children_names=task1_sub_span1_children_span_names, + ) + + # Check children of task1 sub span2. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{task1_id}_sub_span2", + children_names=task1_sub_span2_children_span_names, + ) + + +def check_spans_without_continuance( + output: str, dag: DAG, is_recreated: bool = False, check_t1_sub_spans: bool = True +): + recreated_suffix = "_recreated" if is_recreated else "" + + # Get a list of lines from the captured output. + output_lines = output.splitlines() + + # Filter the output, create a json obj for each span and then store them into dictionaries. + # One dictionary with only the root spans, and one with all the captured spans (root and otherwise). + root_span_dict, span_dict = extract_spans_from_output(output_lines) + # Generate a dictionary with parent child relationships. + # This is done by comparing the span_id of each root span with the parent_id of each non-root span. + parent_child_dict = get_parent_child_dict(root_span_dict, span_dict) + + # Any spans generated under a task, are children of the task span. + # The span hierarchy for dag 'otel_test_dag' is + # dag span + # |_ task_1 span + # |_ sub_span_1 + # |_ sub_span_2 + # |_ sub_span_3 + # |_ sub_span_4 + # |_ task_2 span + # + # In case task_1 has finished running and the span is recreated, + # the sub spans are lost and can't be recreated. The span hierarchy will be + # dag span + # |_ task_1 span + # |_ task_2 span + + dag_id = dag.dag_id + + task_instance_ids = dag.task_ids + task1_id = task_instance_ids[0] + task2_id = task_instance_ids[1] + + # Based on the current tests, only the root span and the task1 span will be recreated. + # TODO: Adjust accordingly, if there are more tests in the future + # that require other spans to be recreated as well. + dag_root_span_name = f"{dag_id}{recreated_suffix}" + + dag_root_span_children_names = [ + f"{task1_id}{recreated_suffix}", + f"{task2_id}", + ] + + task1_span_children_names = [ + f"{task1_id}_sub_span1", + f"{task1_id}_sub_span4", + ] + + # Single element lists. + task1_sub_span1_children_span_names = [f"{task1_id}_sub_span2"] + task1_sub_span2_children_span_names = [f"{task1_id}_sub_span3"] + + assert_span_name_belongs_to_root_span( + root_span_dict=root_span_dict, span_name=dag_root_span_name, should_succeed=True + ) + + # Check direct children of the root span. + assert_parent_children_spans( + parent_child_dict=parent_child_dict, + root_span_dict=root_span_dict, + parent_name=dag_root_span_name, + children_names=dag_root_span_children_names, + ) + + # Use a span name that exists, but it's not a direct child. + assert_span_not_in_children_spans( + parent_child_dict=parent_child_dict, + root_span_dict=root_span_dict, + span_dict=span_dict, + parent_name=dag_root_span_name, + child_name=f"{task1_id}_sub_span1", + span_exists=True, + ) + + # Use a span name that doesn't exist at all. + assert_span_not_in_children_spans( + parent_child_dict=parent_child_dict, + root_span_dict=root_span_dict, + span_dict=span_dict, + parent_name=dag_root_span_name, + child_name=f"{task1_id}_non_existent", + span_exists=False, + ) + + if check_t1_sub_spans: + # Check children of the task1 span. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{task1_id}{recreated_suffix}", + children_names=task1_span_children_names, + ) + + # Check children of task1 sub span1. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{task1_id}_sub_span1", + children_names=task1_sub_span1_children_span_names, + ) + + # Check children of task1 sub span2. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{task1_id}_sub_span2", + children_names=task1_sub_span2_children_span_names, + ) + + +def check_spans_for_paused_dag( + output: str, dag: DAG, is_recreated: bool = False, check_t1_sub_spans: bool = True +): + recreated_suffix = "_recreated" if is_recreated else "" + + # Get a list of lines from the captured output. + output_lines = output.splitlines() + + # Filter the output, create a json obj for each span and then store them into dictionaries. + # One dictionary with only the root spans, and one with all the captured spans (root and otherwise). + root_span_dict, span_dict = extract_spans_from_output(output_lines) + # Generate a dictionary with parent child relationships. + # This is done by comparing the span_id of each root span with the parent_id of each non-root span. + parent_child_dict = get_parent_child_dict(root_span_dict, span_dict) + + # Any spans generated under a task, are children of the task span. + # The span hierarchy for dag 'otel_test_dag_with_pause' is + # dag span + # |_ task_1 span + # |_ sub_span_1 + # |_ sub_span_2 + # |_ sub_span_3 + # |_ sub_span_4 + # |_ paused_task span + # |_ task_2 span + # + # In case task_1 has finished running and the span is recreated, + # the sub spans are lost and can't be recreated. The span hierarchy will be + # dag span + # |_ task_1 span + # |_ paused_task span + # |_ task_2 span + + dag_id = dag.dag_id + + task_instance_ids = dag.task_ids + task1_id = task_instance_ids[0] + paused_task_id = task_instance_ids[1] + task2_id = task_instance_ids[2] + + # Based on the current tests, only the root span and the task1 span will be recreated. + # TODO: Adjust accordingly, if there are more tests in the future + # that require other spans to be recreated as well. + dag_root_span_name = f"{dag_id}{recreated_suffix}" + + dag_root_span_children_names = [ + f"{task1_id}{recreated_suffix}", + f"{paused_task_id}{recreated_suffix}", + f"{task2_id}", + ] + + task1_span_children_names = [ + f"{task1_id}_sub_span1", + f"{task1_id}_sub_span4", + ] + + # Single element lists. + task1_sub_span1_children_span_names = [f"{task1_id}_sub_span2"] + task1_sub_span2_children_span_names = [f"{task1_id}_sub_span3"] + + assert_span_name_belongs_to_root_span( + root_span_dict=root_span_dict, span_name=dag_root_span_name, should_succeed=True + ) + + # Check direct children of the root span. + assert_parent_children_spans( + parent_child_dict=parent_child_dict, + root_span_dict=root_span_dict, + parent_name=dag_root_span_name, + children_names=dag_root_span_children_names, + ) + + # Use a span name that exists, but it's not a direct child. + assert_span_not_in_children_spans( + parent_child_dict=parent_child_dict, + root_span_dict=root_span_dict, + span_dict=span_dict, + parent_name=dag_root_span_name, + child_name=f"{task1_id}_sub_span1", + span_exists=True, + ) + + # Use a span name that doesn't exist at all. + assert_span_not_in_children_spans( + parent_child_dict=parent_child_dict, + root_span_dict=root_span_dict, + span_dict=span_dict, + parent_name=dag_root_span_name, + child_name=f"{task1_id}_non_existent", + span_exists=False, + ) + + if check_t1_sub_spans: + # Check children of the task1 span. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{task1_id}{recreated_suffix}", + children_names=task1_span_children_names, + ) + + # Check children of task1 sub span1. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{task1_id}_sub_span1", + children_names=task1_sub_span1_children_span_names, + ) + + # Check children of task1 sub span2. + assert_parent_children_spans_for_non_root( + span_dict=span_dict, + parent_name=f"{task1_id}_sub_span2", + children_names=task1_sub_span2_children_span_names, + ) + + +def print_output_for_dag_tis(dag: DAG): + with create_session() as session: + tis: list[TaskInstance] = dag.get_task_instances(session=session) + + for ti in tis: + print_ti_output(ti) + + +def print_ti_output(ti: TaskInstance): + from airflow.utils.log.log_reader import TaskLogReader + + task_log_reader = TaskLogReader() + if task_log_reader.supports_read: + metadata: dict[str, Any] = {} + logs, metadata = task_log_reader.read_log_chunks(ti, ti.try_number, metadata) + if ti.hostname in dict(logs[0]): + output = ( + str(dict(logs[0])[ti.hostname]) + .replace("\\n", "\n") + .replace("{log.py:232} WARNING - {", "\n{") + ) + while metadata["end_of_log"] is False: + logs, metadata = task_log_reader.read_log_chunks(ti, ti.try_number - 1, metadata) + if ti.hostname in dict(logs[0]): + output = output + str(dict(logs[0])[ti.hostname]).replace("\\n", "\n") + # Logging the output is enough for capfd to capture it. + log.info(format(output)) + + +@pytest.mark.integration("redis") +@pytest.mark.backend("postgres") +class TestOtelIntegration: + """ + This test is using a ConsoleSpanExporter so that it can capture + the spans from the stdout and run assertions on them. + + It can also be used with otel and jaeger for manual testing. + To export the spans to otel and visualize them with jaeger, + - start breeze with '--integration otel' + - run on the shell 'export use_otel=true' + - run the test + - check 'http://localhost:36686/' + + To get a db dump on the stdout, run 'export log_level=debug'. + """ + + test_dir = os.path.dirname(os.path.abspath(__file__)) + dag_folder = os.path.join(test_dir, "dags") + + use_otel = os.getenv("use_otel", default="false") + log_level = os.getenv("log_level", default="none") + + celery_command_args = [ + "celery", + "--app", + "airflow.providers.celery.executors.celery_executor.app", + "worker", + "--concurrency", + "1", + "--loglevel", + "INFO", + ] + + scheduler_command_args = [ + "airflow", + "scheduler", + ] + + dags: dict[str, DAG] = {} + + @classmethod + def setup_class(cls): + os.environ["AIRFLOW__TRACES__OTEL_ON"] = "True" + os.environ["AIRFLOW__TRACES__OTEL_HOST"] = "breeze-otel-collector" + os.environ["AIRFLOW__TRACES__OTEL_PORT"] = "4318" + if cls.use_otel != "true": + os.environ["AIRFLOW__TRACES__OTEL_DEBUGGING_ON"] = "True" + + os.environ["AIRFLOW__SCHEDULER__STANDALONE_DAG_PROCESSOR"] = "False" + os.environ["AIRFLOW__SCHEDULER__PROCESSOR_POLL_INTERVAL"] = "2" + + # The heartrate is determined by the conf "AIRFLOW__SCHEDULER__SCHEDULER_HEARTBEAT_SEC". + # By default, the heartrate is 5 seconds. Every iteration of the scheduler loop, checks the + # time passed since the last heartbeat and if it was longer than the 5 second heartrate, + # it performs a heartbeat update. + # If there hasn't been a heartbeat for an amount of time longer than the + # SCHEDULER_HEALTH_CHECK_THRESHOLD, then the scheduler is considered unhealthy. + # Approximately, there is a scheduler heartbeat every 5-6 seconds. Set the threshold to 15. + os.environ["AIRFLOW__SCHEDULER__SCHEDULER_HEALTH_CHECK_THRESHOLD"] = "15" + + os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = f"{cls.dag_folder}" + + os.environ["AIRFLOW__CORE__LOAD_EXAMPLES"] = "False" + os.environ["AIRFLOW__CORE__UNIT_TEST_MODE"] = "False" + + if cls.log_level == "debug": + log.setLevel(logging.DEBUG) + + @classmethod + def serialize_and_get_dags(cls) -> dict[str, DAG]: + log.info("Serializing Dags from directory %s", cls.dag_folder) + # Load DAGs from the dag directory. + dag_bag = DagBag(dag_folder=cls.dag_folder, include_examples=False) + + dag_ids = dag_bag.dag_ids + assert len(dag_ids) == 3 + + dag_dict: dict[str, DAG] = {} + with create_session() as session: + for dag_id in dag_ids: + dag = dag_bag.get_dag(dag_id) + dag_dict[dag_id] = dag + + assert dag is not None, f"DAG with ID {dag_id} not found." + + # Sync the DAG to the database. + dag.sync_to_db(session=session) + # Manually serialize the dag and write it to the db to avoid a db error. + SerializedDagModel.write_dag(dag, session=session) + + session.commit() + + return dag_dict + + @pytest.fixture + def celery_worker_env_vars(self, monkeypatch): + os.environ["AIRFLOW__CORE__EXECUTOR"] = "CeleryExecutor" + executor_name = ExecutorName( + module_path="airflow.providers.celery.executors.celery_executor.CeleryExecutor", + alias="CeleryExecutor", + ) + monkeypatch.setattr(executor_loader, "_alias_to_executors", {"CeleryExecutor": executor_name}) + + @pytest.fixture(autouse=True) + def reset_db(self): + reset_command = ["airflow", "db", "reset", "--yes"] + + # Reset the db using the cli. + subprocess.run(reset_command, check=True, env=os.environ.copy()) + + migrate_command = ["airflow", "db", "migrate"] + subprocess.run(migrate_command, check=True, env=os.environ.copy()) + + self.dags = self.serialize_and_get_dags() + + def test_same_scheduler_processing_the_entire_dag( + self, monkeypatch, celery_worker_env_vars, capfd, session + ): + """The same scheduler will start and finish the dag processing.""" + celery_worker_process = None + scheduler_process = None + try: + # Start the processes here and not as fixtures or in a common setup, + # so that the test can capture their output. + celery_worker_process, scheduler_process = self.start_worker_and_scheduler1() + + dag_id = "otel_test_dag" + + assert len(self.dags) > 0 + dag = self.dags[dag_id] + + assert dag is not None + + run_id = unpause_trigger_dag_and_get_run_id(dag_id=dag_id) + + wait_for_dag_run_and_check_span_status( + dag_id=dag_id, run_id=run_id, max_wait_time=90, span_status=SpanStatus.ENDED + ) + + # The ti span_status is updated while processing the executor events, + # which is after the dag_run state has been updated. + time.sleep(10) + + with create_session() as session: + tis: list[TaskInstance] = dag.get_task_instances(session=session) + + for ti in tis: + check_ti_state_and_span_status( + task_id=ti.task_id, run_id=run_id, state=State.SUCCESS, span_status=SpanStatus.ENDED + ) + print_ti_output(ti) + finally: + if self.log_level == "debug": + with create_session() as session: + dump_airflow_metadata_db(session) + + # Terminate the processes. + celery_worker_process.terminate() + celery_worker_process.wait() + + celery_status = celery_worker_process.poll() + assert ( + celery_status is not None + ), "The celery worker process status is None, which means that it hasn't terminated as expected." + + scheduler_process.terminate() + scheduler_process.wait() + + scheduler_status = scheduler_process.poll() + assert ( + scheduler_status is not None + ), "The scheduler_1 process status is None, which means that it hasn't terminated as expected." + + out, err = capfd.readouterr() + log.info("out-start --\n%s\n-- out-end", out) + log.info("err-start --\n%s\n-- err-end", err) + + if self.use_otel != "true": + # Dag run should have succeeded. Test the spans from the output. + check_spans_without_continuance(output=out, dag=dag) + + def test_scheduler_change_after_the_first_task_finishes( + self, monkeypatch, celery_worker_env_vars, capfd, session + ): + """ + The scheduler thread will be paused after the first task ends and a new scheduler process + will handle the rest of the dag processing. The paused thread will be resumed afterwards. + """ + + celery_worker_process = None + scheduler_process_1 = None + scheduler_process_2 = None + try: + # Start the processes here and not as fixtures or in a common setup, + # so that the test can capture their output. + celery_worker_process, scheduler_process_1 = self.start_worker_and_scheduler1() + + dag_id = "otel_test_dag" + dag = self.dags[dag_id] + + run_id = unpause_trigger_dag_and_get_run_id(dag_id=dag_id) + + with create_session() as session: + tis: list[TaskInstance] = dag.get_task_instances(session=session) + + task_1 = tis[0] + + while True: + with create_session() as session: + ti = ( + session.query(TaskInstance) + .filter( + TaskInstance.task_id == task_1.task_id, + TaskInstance.run_id == task_1.run_id, + ) + .first() + ) + + if ti is None: + continue + + # Wait until the task has been finished. + if ti.state in State.finished: + break + + with capfd.disabled(): + # When the scheduler1 thread is paused, capfd keeps trying to read the + # file descriptors for the process and ends up freezing the test. + # Temporarily disable capfd to avoid that. + scheduler_process_1.send_signal(signal.SIGSTOP) + + scheduler_process_2 = subprocess.Popen( + self.scheduler_command_args, + env=os.environ.copy(), + stdout=None, + stderr=None, + ) + + check_dag_run_state_and_span_status( + dag_id=dag_id, run_id=run_id, state=State.RUNNING, span_status=SpanStatus.ACTIVE + ) + + # Wait for scheduler2 to be up and running. + time.sleep(10) + + wait_for_dag_run_and_check_span_status( + dag_id=dag_id, run_id=run_id, max_wait_time=120, span_status=SpanStatus.SHOULD_END + ) + + scheduler_process_1.send_signal(signal.SIGCONT) + + # Wait for the scheduler to start again and continue running. + time.sleep(10) + + wait_for_dag_run_and_check_span_status( + dag_id=dag_id, run_id=run_id, max_wait_time=30, span_status=SpanStatus.ENDED + ) + + print_output_for_dag_tis(dag=dag) + finally: + if self.log_level == "debug": + with create_session() as session: + dump_airflow_metadata_db(session) + + # Terminate the processes. + celery_worker_process.terminate() + celery_worker_process.wait() + + scheduler_process_1.terminate() + scheduler_process_1.wait() + + scheduler_process_2.terminate() + scheduler_process_2.wait() + + out, err = capfd.readouterr() + log.info("out-start --\n%s\n-- out-end", out) + log.info("err-start --\n%s\n-- err-end", err) + + if self.use_otel != "true": + # Dag run should have succeeded. Test the spans in the output. + check_spans_without_continuance(output=out, dag=dag) + + def test_scheduler_change_in_the_middle_of_first_task_until_the_end( + self, monkeypatch, celery_worker_env_vars, capfd, session + ): + """ + The scheduler that starts the dag run, will be paused and a new scheduler process will handle + the rest of the dag processing. The paused thread will be resumed so that the test + can check that it properly handles the spans. + + A txt file will be used for signaling the test and the dag in order to make sure that + the 1st scheduler is handled accordingly while the first task is executing and that + the 2nd scheduler picks up the task and dag processing. + The steps will be + - The dag starts running, creates the file with a signal word and waits until the word is changed. + - The test checks if the file exist, stops the scheduler, starts a new scheduler and updates the file. + - The dag gets the update and continues until the task is finished. + At this point, the second scheduler should handle the rest of the dag processing. + """ + + celery_worker_process = None + scheduler_process_1 = None + scheduler_process_2 = None + try: + # Start the processes here and not as fixtures or in a common setup, + # so that the test can capture their output. + celery_worker_process, scheduler_process_1 = self.start_worker_and_scheduler1() + + dag_id = "otel_test_dag_with_pause_in_task" + dag = self.dags[dag_id] + + run_id = unpause_trigger_dag_and_get_run_id(dag_id=dag_id) + + # Control file path. + control_file = os.path.join(self.dag_folder, "dag_control.txt") + + while True: + try: + with open(control_file) as file: + file_contents = file.read() + + if "pause" in file_contents: + log.info("Control file exists and the task has been paused.") + break + else: + continue + except FileNotFoundError: + print("Control file not found. Waiting...") + time.sleep(1) + continue + + with capfd.disabled(): + # When the scheduler1 thread is paused, capfd keeps trying to read the + # file descriptors for the process and ends up freezing the test. + # Temporarily disable capfd to avoid that. + scheduler_process_1.send_signal(signal.SIGSTOP) + + scheduler_process_2 = subprocess.Popen( + self.scheduler_command_args, + env=os.environ.copy(), + stdout=None, + stderr=None, + ) + + # Wait for scheduler2 to be up and running. + time.sleep(10) + + check_dag_run_state_and_span_status( + dag_id=dag_id, run_id=run_id, state=State.RUNNING, span_status=SpanStatus.ACTIVE + ) + + # Rewrite the file to unpause the dag. + with open(control_file, "w") as file: + file.write("continue") + + # Scheduler2 should finish processing the dag and set the status + # so that scheduler1 can end the spans when it is resumed. + wait_for_dag_run_and_check_span_status( + dag_id=dag_id, run_id=run_id, max_wait_time=120, span_status=SpanStatus.SHOULD_END + ) + + scheduler_process_1.send_signal(signal.SIGCONT) + + # Wait for the scheduler to start again and continue running. + time.sleep(10) + + # Scheduler1 should end the spans and update the status. + wait_for_dag_run_and_check_span_status( + dag_id=dag_id, run_id=run_id, max_wait_time=30, span_status=SpanStatus.ENDED + ) + + print_output_for_dag_tis(dag=dag) + finally: + if self.log_level == "debug": + with create_session() as session: + dump_airflow_metadata_db(session) + + # Terminate the processes. + celery_worker_process.terminate() + celery_worker_process.wait() + + scheduler_process_1.terminate() + scheduler_process_1.wait() + + scheduler_process_2.terminate() + scheduler_process_2.wait() + + out, err = capfd.readouterr() + log.info("out-start --\n%s\n-- out-end", out) + log.info("err-start --\n%s\n-- err-end", err) + + if self.use_otel != "true": + # Dag run should have succeeded. Test the spans in the output. + check_spans_without_continuance(output=out, dag=dag) + + def test_scheduler_exits_gracefully_in_the_middle_of_the_first_task( + self, monkeypatch, celery_worker_env_vars, capfd, session + ): + """ + The scheduler that starts the dag run will be stopped, while the first task is executing, + and start a new scheduler will be started. That way, the new process will pick up the dag processing. + The initial scheduler will exit gracefully. + """ + + celery_worker_process = None + scheduler_process_1 = None + scheduler_process_2 = None + try: + # Start the processes here and not as fixtures or in a common setup, + # so that the test can capture their output. + celery_worker_process, scheduler_process_1 = self.start_worker_and_scheduler1() + + dag_id = "otel_test_dag_with_pause_in_task" + dag = self.dags[dag_id] + + run_id = unpause_trigger_dag_and_get_run_id(dag_id=dag_id) + + # Control file path. + control_file = os.path.join(self.dag_folder, "dag_control.txt") + + while True: + try: + with open(control_file) as file: + file_contents = file.read() + + if "pause" in file_contents: + log.info("Control file exists and the task has been paused.") + break + else: + continue + except FileNotFoundError: + print("Control file not found. Waiting...") + time.sleep(1) + continue + + # Since, we are past the loop, then the file exists and the dag has been paused. + # Terminate scheduler1 and start scheduler2. + scheduler_process_1.terminate() + + check_dag_run_state_and_span_status( + dag_id=dag_id, run_id=run_id, state=State.RUNNING, span_status=SpanStatus.NEEDS_CONTINUANCE + ) + + scheduler_process_2 = subprocess.Popen( + self.scheduler_command_args, + env=os.environ.copy(), + stdout=None, + stderr=None, + ) + + # Wait for scheduler2 to be up and running. + time.sleep(10) + + # Rewrite the file to unpause the dag. + with open(control_file, "w") as file: + file.write("continue") + + wait_for_dag_run_and_check_span_status( + dag_id=dag_id, run_id=run_id, max_wait_time=120, span_status=SpanStatus.ENDED + ) + + print_output_for_dag_tis(dag=dag) + finally: + if self.log_level == "debug": + with create_session() as session: + dump_airflow_metadata_db(session) + + # Terminate the processes. + celery_worker_process.terminate() + celery_worker_process.wait() + + scheduler_process_1.wait() + + scheduler_process_2.terminate() + scheduler_process_2.wait() + + out, err = capfd.readouterr() + log.info("out-start --\n%s\n-- out-end", out) + log.info("err-start --\n%s\n-- err-end", err) + + if self.use_otel != "true": + # Dag run should have succeeded. Test the spans in the output. + check_spans_with_continuance(output=out, dag=dag) + + def test_scheduler_exits_forcefully_in_the_middle_of_the_first_task( + self, monkeypatch, celery_worker_env_vars, capfd, session + ): + """ + The first scheduler will exit forcefully while the first task is running, + so that it won't have time end any active spans. + """ + + celery_worker_process = None + scheduler_process_2 = None + try: + # Start the processes here and not as fixtures or in a common setup, + # so that the test can capture their output. + celery_worker_process, scheduler_process_1 = self.start_worker_and_scheduler1() + + dag_id = "otel_test_dag_with_pause_in_task" + dag = self.dags[dag_id] + + run_id = unpause_trigger_dag_and_get_run_id(dag_id=dag_id) + + # Control file path. + control_file = os.path.join(self.dag_folder, "dag_control.txt") + + while True: + try: + with open(control_file) as file: + file_contents = file.read() + + if "pause" in file_contents: + log.info("Control file exists and the task has been paused.") + break + else: + continue + except FileNotFoundError: + print("Control file not found. Waiting...") + time.sleep(1) + continue + + # Since, we are past the loop, then the file exists and the dag has been paused. + # Kill scheduler1 and start scheduler2. + scheduler_process_1.send_signal(signal.SIGKILL) + + # The process shouldn't have changed the span_status. + check_dag_run_state_and_span_status( + dag_id=dag_id, run_id=run_id, state=State.RUNNING, span_status=SpanStatus.ACTIVE + ) + + # Wait so that the health threshold passes and scheduler1 is considered unhealthy. + time.sleep(15) + + scheduler_process_2 = subprocess.Popen( + self.scheduler_command_args, + env=os.environ.copy(), + stdout=None, + stderr=None, + ) + + # Wait for scheduler2 to be up and running. + time.sleep(10) + + # Rewrite the file to unpause the dag. + with open(control_file, "w") as file: + file.write("continue") + + wait_for_dag_run_and_check_span_status( + dag_id=dag_id, run_id=run_id, max_wait_time=120, span_status=SpanStatus.ENDED + ) + + print_output_for_dag_tis(dag=dag) + finally: + if self.log_level == "debug": + with create_session() as session: + dump_airflow_metadata_db(session) + + # Terminate the processes. + celery_worker_process.terminate() + celery_worker_process.wait() + + scheduler_process_2.terminate() + scheduler_process_2.wait() + + out, err = capfd.readouterr() + log.info("out-start --\n%s\n-- out-end", out) + log.info("err-start --\n%s\n-- err-end", err) + + if self.use_otel != "true": + # Dag run should have succeeded. Test the spans in the output. + check_spans_without_continuance(output=out, dag=dag, is_recreated=True) + + def test_scheduler_exits_forcefully_after_the_first_task_finishes( + self, monkeypatch, celery_worker_env_vars, capfd, session + ): + """ + The first scheduler will exit forcefully after the first task finishes, + so that it won't have time to end any active spans. + In this scenario, the sub-spans for the first task will be lost. + The only way to retrieve them, would be to re-run the task. + """ + + celery_worker_process = None + scheduler_process_2 = None + try: + # Start the processes here and not as fixtures or in a common setup, + # so that the test can capture their output. + celery_worker_process, scheduler_process_1 = self.start_worker_and_scheduler1() + + dag_id = "otel_test_dag_with_pause" + dag = self.dags[dag_id] + + run_id = unpause_trigger_dag_and_get_run_id(dag_id=dag_id) + + # Control file path. + control_file = os.path.join(self.dag_folder, "dag_control.txt") + + while True: + try: + with open(control_file) as file: + file_contents = file.read() + + if "pause" in file_contents: + log.info("Control file exists and the task has been paused.") + break + else: + continue + except FileNotFoundError: + print("Control file not found. Waiting...") + time.sleep(1) + continue + + # Since, we are past the loop, then the file exists and the dag has been paused. + # Kill scheduler1 and start scheduler2. + scheduler_process_1.send_signal(signal.SIGKILL) + + # The process shouldn't have changed the span_status. + check_dag_run_state_and_span_status( + dag_id=dag_id, run_id=run_id, state=State.RUNNING, span_status=SpanStatus.ACTIVE + ) + + # Rewrite the file to unpause the dag. + with open(control_file, "w") as file: + file.write("continue") + + time.sleep(15) + # The task should be adopted. + + scheduler_process_2 = subprocess.Popen( + self.scheduler_command_args, + env=os.environ.copy(), + stdout=None, + stderr=None, + ) + + # Wait for scheduler2 to be up and running. + time.sleep(10) + + wait_for_dag_run_and_check_span_status( + dag_id=dag_id, run_id=run_id, max_wait_time=120, span_status=SpanStatus.ENDED + ) + + print_output_for_dag_tis(dag=dag) + finally: + if self.log_level == "debug": + with create_session() as session: + dump_airflow_metadata_db(session) + + # Terminate the processes. + celery_worker_process.terminate() + celery_worker_process.wait() + + scheduler_process_2.terminate() + scheduler_process_2.wait() + + out, err = capfd.readouterr() + log.info("out-start --\n%s\n-- out-end", out) + log.info("err-start --\n%s\n-- err-end", err) + + if self.use_otel != "true": + # Dag run should have succeeded. Test the spans in the output. + check_spans_for_paused_dag(output=out, dag=dag, is_recreated=True, check_t1_sub_spans=False) + + def start_worker_and_scheduler1(self): + celery_worker_process = subprocess.Popen( + self.celery_command_args, + env=os.environ.copy(), + stdout=None, + stderr=None, + ) + + scheduler_process = subprocess.Popen( + self.scheduler_command_args, + env=os.environ.copy(), + stdout=None, + stderr=None, + ) + + # Wait to ensure both processes have started. + time.sleep(10) + + return celery_worker_process, scheduler_process diff --git a/tests/integration/otel/test_utils.py b/tests/integration/otel/test_utils.py new file mode 100644 index 0000000000000..8a60768c0b932 --- /dev/null +++ b/tests/integration/otel/test_utils.py @@ -0,0 +1,774 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import logging +import pprint + +from sqlalchemy import inspect + +from airflow.models import Base + +log = logging.getLogger("integration.otel.test_utils") + + +def dump_airflow_metadata_db(session): + inspector = inspect(session.bind) + all_tables = inspector.get_table_names() + + # dump with the entire db + db_dump = {} + + log.debug("\n-----START_airflow_db_dump-----\n") + + for table_name in all_tables: + log.debug("\nDumping table: %s", table_name) + table = Base.metadata.tables.get(table_name) + if table is not None: + query = session.query(table) + results = [dict(row) for row in query.all()] + db_dump[table_name] = results + # Pretty-print the table contents + if table_name == "connection": + filtered_results = [row for row in results if row.get("conn_id") == "airflow_db"] + pprint.pprint({table_name: filtered_results}, width=120) + else: + pprint.pprint({table_name: results}, width=120) + else: + log.debug("Table %s not found in metadata.", table_name) + + log.debug("\nAirflow metadata database dump complete.") + log.debug("\n-----END_airflow_db_dump-----\n") + + +def extract_spans_from_output(output_lines: list): + """ + For a given list of ConsoleSpanExporter output lines, it extracts the json spans + and creates two dictionaries. + :return: root spans dict (key: root_span_id - value: root_span), spans dict (key: span_id - value: span) + """ + span_dict = {} + root_span_dict = {} + total_lines = len(output_lines) + index = 0 + + while index < total_lines: + line = output_lines[index].strip() + # The start and the end of the json object, don't have any indentation. + # We can use that to identify them. + if line.startswith("{") and line == "{": # Json start. + # Get all the lines and append them until we reach the end. + json_lines = [line] + index += 1 + while index < total_lines: + line = output_lines[index] + # The 'command' line uses single quotes, and it results in an error when parsing the json. + # It's not needed when checking for spans. So instead of formatting it properly, just skip it. + if '"command":' not in line: + json_lines.append(line) + if line.strip().startswith("}") and line == "}": # Json end. + # Since, this is the end of the object, break the loop. + break + index += 1 + # Create a formatted json string and then convert the string to a python dict. + json_str = "\n".join(json_lines) + try: + span = json.loads(json_str) + span_id = span["context"]["span_id"] + span_dict[span_id] = span + + if span["parent_id"] is None: + # This is a root span, add it to the root_span_map as well. + root_span_id = span["context"]["span_id"] + root_span_dict[root_span_id] = span + + except json.JSONDecodeError as e: + log.error("Failed to parse JSON span: %s", e) + log.error("Failed JSON string:") + log.error(json_str) + else: + index += 1 + + return root_span_dict, span_dict + + +def get_id_for_a_given_name(span_dict: dict, span_name: str): + for span_id, span in span_dict.items(): + if span["name"] == span_name: + return span_id + return None + + +def get_parent_child_dict(root_span_dict, span_dict): + """ + Create a dictionary with parent-child span relationships. + :return: key: root_span_id - value: list of child spans + """ + parent_child_dict = {} + for root_span_id, root_span in root_span_dict.items(): + # Compare each 'root_span_id' with each 'parent_id' from the span_dict. + # If there is a match, then the span in the span_dict, is a child. + # For every root span, create a list of child spans. + child_span_list = [] + for span_id, span in span_dict.items(): + if root_span_id == span_id: + # It's the same span, skip. + continue + # If the parent_id matches the root_span_id and if the trace_id is the same. + if ( + span["parent_id"] == root_span_id + and root_span["context"]["trace_id"] == span["context"]["trace_id"] + ): + child_span_list.append(span) + parent_child_dict[root_span_id] = child_span_list + return parent_child_dict + + +def get_child_list_for_non_root(span_dict: dict, span_name: str): + """ + Get a list of children spans for a parent span that isn't also a root span. + e.g. a task span with sub-spans, is a parent span but not a root span. + :return: list of spans + """ + parent_span_id = get_id_for_a_given_name(span_dict=span_dict, span_name=span_name) + parent_span = span_dict.get(parent_span_id) + + if parent_span is None: + return [] + + child_span_list = [] + for span_id, span in span_dict.items(): + if span_id == parent_span_id: + # It's the same span, skip. + continue + if ( + span["parent_id"] == parent_span_id + and span["context"]["trace_id"] == parent_span["context"]["trace_id"] + ): + child_span_list.append(span) + + return child_span_list + + +def assert_parent_name_and_get_id(root_span_dict: dict, span_name: str): + parent_id = get_id_for_a_given_name(root_span_dict, span_name) + + assert parent_id is not None, f"Parent span '{span_name}' wasn't found." + + return parent_id + + +def assert_span_name_belongs_to_root_span(root_span_dict: dict, span_name: str, should_succeed: bool): + """Check that a given span name belongs to a root span.""" + log.info("Checking that '%s' is a root span.", span_name) + # Check if any root span has the specified span_name + name_exists = any(root_span.get("name", None) == span_name for root_span in root_span_dict.values()) + + # Assert based on the should_succeed flag + if should_succeed: + assert name_exists, f"Expected span '{span_name}' to belong to a root span, but it does not." + log.info("Span '%s' belongs to a root span, as expected.", span_name) + else: + assert not name_exists, f"Expected span '{span_name}' not to belong to a root span, but it does." + log.info("Span '%s' doesn't belong to a root span, as expected.", span_name) + + +def assert_parent_children_spans( + parent_child_dict: dict, root_span_dict: dict, parent_name: str, children_names: list[str] +): + """Check that all spans in a given list are children of a given root span name.""" + log.info("Checking that spans '%s' are children of root span '%s'.", children_names, parent_name) + # Iterate the root_span_dict, to get the span_id for the parent_name. + parent_id = assert_parent_name_and_get_id(root_span_dict=root_span_dict, span_name=parent_name) + + # Use the root span_id to get the children ids. + child_span_list = parent_child_dict[parent_id] + + # For each children id, get the entry from the span_dict. + names_from_dict = [] + for child_span in child_span_list: + name = child_span["name"] + names_from_dict.append(name) + + # Assert that all given children names match the names from the dictionary. + for name in children_names: + assert ( + name in names_from_dict + ), f"Span name '{name}' wasn't found in children span names. It's not a child of span '{parent_name}'." + + +def assert_parent_children_spans_for_non_root(span_dict: dict, parent_name: str, children_names: list[str]): + """Check that all spans in a given list are children of a given non-root span name.""" + log.info("Checking that spans '%s' are children of span '%s'.", children_names, parent_name) + child_span_list = get_child_list_for_non_root(span_dict=span_dict, span_name=parent_name) + + # For each children id, get the entry from the span_dict. + names_from_dict = [] + for child_span in child_span_list: + name = child_span["name"] + names_from_dict.append(name) + + # Assert that all given children names match the names from the dictionary. + for name in children_names: + assert ( + name in names_from_dict + ), f"Span name '{name}' wasn't found in children span names. It's not a child of span '{parent_name}'." + + +def assert_span_not_in_children_spans( + parent_child_dict: dict, + root_span_dict: dict, + span_dict: dict, + parent_name: str, + child_name: str, + span_exists: bool, +): + """Check that a span for a given name, doesn't belong to the children of a given root span name.""" + log.info("Checking that span '%s' is not a child of span '%s'.", child_name, parent_name) + # Iterate the root_span_dict, to get the span_id for the parent_name. + parent_id = assert_parent_name_and_get_id(root_span_dict=root_span_dict, span_name=parent_name) + + # Use the root span_id to get the children ids. + child_span_id_list = parent_child_dict[parent_id] + + child_id = get_id_for_a_given_name(span_dict=span_dict, span_name=child_name) + + if span_exists: + assert child_id is not None, f"Span '{child_name}' should exist but it doesn't." + assert ( + child_id not in child_span_id_list + ), f"Span '{child_name}' shouldn't be a child of span '{parent_name}', but it is." + else: + assert child_id is None, f"Span '{child_name}' shouldn't exist but it does." + + +class TestUtilsUnit: + # The method that extracts the spans from the output, + # counts that there is no indentation on the cli, when a span starts and finishes. + example_output = """ +{ + "name": "test_dag", + "context": { + "trace_id": "0x01f441c9c53e793e8808c77939ddbf36", + "span_id": "0x779a3a331684439e", + "trace_state": "[]" + }, + "kind": "SpanKind.INTERNAL", + "parent_id": null, + "start_time": "2024-11-30T14:01:21.738052Z", + "end_time": "2024-11-30T14:01:36.541442Z", + "status": { + "status_code": "UNSET" + }, + "attributes": { + "airflow.category": "DAG runs", + "airflow.dag_run.dag_id": "otel_test_dag_with_pause", + "airflow.dag_run.logical_date": "2024-11-30 14:01:15+00:00", + "airflow.dag_run.run_id": "manual__2024-11-30T14:01:15.333003+00:00", + "airflow.dag_run.queued_at": "2024-11-30 14:01:21.738052+00:00", + "airflow.dag_run.run_start_date": "2024-11-30 14:01:22.192655+00:00", + "airflow.dag_run.run_end_date": "2024-11-30 14:01:36.541442+00:00", + "airflow.dag_run.run_duration": "14.348787", + "airflow.dag_run.state": "success", + "airflow.dag_run.external_trigger": "True", + "airflow.dag_run.run_type": "manual", + "airflow.dag_run.data_interval_start": "2024-11-30 14:01:15+00:00", + "airflow.dag_run.data_interval_end": "2024-11-30 14:01:15+00:00", + "airflow.dag_version.version": "2", + "airflow.dag_run.conf": "{}" + }, + "events": [ + { + "name": "airflow.dag_run.queued", + "timestamp": "2024-11-30T14:01:21.738052Z", + "attributes": {} + }, + { + "name": "airflow.dag_run.started", + "timestamp": "2024-11-30T14:01:22.192655Z", + "attributes": {} + }, + { + "name": "airflow.dag_run.ended", + "timestamp": "2024-11-30T14:01:36.541442Z", + "attributes": {} + } + ], + "links": [], + "resource": { + "attributes": { + "telemetry.sdk.language": "python", + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.version": "1.27.0", + "host.name": "351295342ba2", + "service.name": "Airflow" + }, + "schema_url": "" + } +} +{ + "name": "task_1", + "context": { + "trace_id": "0x01f441c9c53e793e8808c77939ddbf36", + "span_id": "0xba9f48dcfac5d40a", + "trace_state": "[]" + }, + "kind": "SpanKind.INTERNAL", + "parent_id": "0x779a3a331684439e", + "start_time": "2024-11-30T14:01:22.220785Z", + "end_time": "2024-11-30T14:01:34.339423Z", + "status": { + "status_code": "UNSET" + }, + "attributes": { + "airflow.category": "scheduler", + "airflow.task.task_id": "task_1", + "airflow.task.dag_id": "otel_test_dag_with_pause", + "airflow.task.state": "success", + "airflow.task.start_date": "2024-11-30 14:01:23.468047+00:00", + "airflow.task.end_date": "2024-11-30 14:01:34.339423+00:00", + "airflow.task.duration": 10.871376, + "airflow.task.executor_config": "{}", + "airflow.task.logical_date": "2024-11-30 14:01:15+00:00", + "airflow.task.hostname": "351295342ba2", + "airflow.task.log_url": "http://localhost:8080/dags/otel_test_dag_with_pause/grid?dag_run_id=manual__2024-11-30T14%3A01%3A15.333003%2B00%3A00&task_id=task_1&base_date=2024-11-30T14%3A01%3A15%2B0000&tab=logs", + "airflow.task.operator": "PythonOperator", + "airflow.task.try_number": 1, + "airflow.task.executor_state": "success", + "airflow.task.pool": "default_pool", + "airflow.task.queue": "default", + "airflow.task.priority_weight": 2, + "airflow.task.queued_dttm": "2024-11-30 14:01:22.216965+00:00", + "airflow.task.queued_by_job_id": 1, + "airflow.task.pid": 1748 + }, + "events": [ + { + "name": "task to trigger", + "timestamp": "2024-11-30T14:01:22.220873Z", + "attributes": { + "command": "['airflow', 'tasks', 'run', 'otel_test_dag_with_pause', 'task_1', 'manual__2024-11-30T14:01:15.333003+00:00', '--local', '--subdir', 'DAGS_FOLDER/otel_test_dag_with_pause.py', '--carrier', '{\"traceparent\": \"00-01f441c9c53e793e8808c77939ddbf36-ba9f48dcfac5d40a-01\"}']", + "conf": "{}" + } + }, + { + "name": "airflow.task.queued", + "timestamp": "2024-11-30T14:01:22.216965Z", + "attributes": {} + }, + { + "name": "airflow.task.started", + "timestamp": "2024-11-30T14:01:23.468047Z", + "attributes": {} + }, + { + "name": "airflow.task.ended", + "timestamp": "2024-11-30T14:01:34.339423Z", + "attributes": {} + } + ], + "links": [ + { + "context": { + "trace_id": "0x01f441c9c53e793e8808c77939ddbf36", + "span_id": "0x779a3a331684439e", + "trace_state": "[]" + }, + "attributes": { + "meta.annotation_type": "link", + "from": "parenttrace" + } + } + ], + "resource": { + "attributes": { + "telemetry.sdk.language": "python", + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.version": "1.27.0", + "host.name": "351295342ba2", + "service.name": "Airflow" + }, + "schema_url": "" + } +} +{ + "name": "start_new_processes", + "context": { + "trace_id": "0x3f6d11237d2b2b8cb987e7ec923a4dc4", + "span_id": "0x0b133494760fa56d", + "trace_state": "[]" + }, + "kind": "SpanKind.INTERNAL", + "parent_id": "0xcf656e5db2b777be", + "start_time": "2024-11-30T14:01:29.316313Z", + "end_time": "2024-11-30T14:01:29.316397Z", + "status": { + "status_code": "UNSET" + }, + "attributes": {}, + "events": [], + "links": [], + "resource": { + "attributes": { + "telemetry.sdk.language": "python", + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.version": "1.27.0", + "host.name": "351295342ba2", + "service.name": "Airflow" + }, + "schema_url": "" + } +} +{ + "name": "task_2", + "context": { + "trace_id": "0x01f441c9c53e793e8808c77939ddbf36", + "span_id": "0xe573c104743b6d34", + "trace_state": "[]" + }, + "kind": "SpanKind.INTERNAL", + "parent_id": "0x779a3a331684439e", + "start_time": "2024-11-30T14:01:34.698666Z", + "end_time": "2024-11-30T14:01:36.002687Z", + "status": { + "status_code": "UNSET" + }, + "attributes": { + "airflow.category": "scheduler", + "airflow.task.task_id": "task_2", + "airflow.task.dag_id": "otel_test_dag_with_pause", + "airflow.task.state": "success", + "airflow.task.start_date": "2024-11-30 14:01:35.872318+00:00", + "airflow.task.end_date": "2024-11-30 14:01:36.002687+00:00", + "airflow.task.duration": 0.130369, + "airflow.task.executor_config": "{}", + "airflow.task.logical_date": "2024-11-30 14:01:15+00:00", + "airflow.task.hostname": "351295342ba2", + "airflow.task.log_url": "http://localhost:8080/dags/otel_test_dag_with_pause/grid?dag_run_id=manual__2024-11-30T14%3A01%3A15.333003%2B00%3A00&task_id=task_2&base_date=2024-11-30T14%3A01%3A15%2B0000&tab=logs", + "airflow.task.operator": "PythonOperator", + "airflow.task.try_number": 1, + "airflow.task.executor_state": "success", + "airflow.task.pool": "default_pool", + "airflow.task.queue": "default", + "airflow.task.priority_weight": 1, + "airflow.task.queued_dttm": "2024-11-30 14:01:34.694842+00:00", + "airflow.task.queued_by_job_id": 3, + "airflow.task.pid": 1950 + }, + "events": [ + { + "name": "task to trigger", + "timestamp": "2024-11-30T14:01:34.698810Z", + "attributes": { + "command": "['airflow', 'tasks', 'run', 'otel_test_dag_with_pause', 'task_2', 'manual__2024-11-30T14:01:15.333003+00:00', '--local', '--subdir', 'DAGS_FOLDER/otel_test_dag_with_pause.py', '--carrier', '{\"traceparent\": \"00-01f441c9c53e793e8808c77939ddbf36-e573c104743b6d34-01\"}']", + "conf": "{}" + } + }, + { + "name": "airflow.task.queued", + "timestamp": "2024-11-30T14:01:34.694842Z", + "attributes": {} + }, + { + "name": "airflow.task.started", + "timestamp": "2024-11-30T14:01:35.872318Z", + "attributes": {} + }, + { + "name": "airflow.task.ended", + "timestamp": "2024-11-30T14:01:36.002687Z", + "attributes": {} + } + ], + "links": [ + { + "context": { + "trace_id": "0x01f441c9c53e793e8808c77939ddbf36", + "span_id": "0x779a3a331684439e", + "trace_state": "[]" + }, + "attributes": { + "meta.annotation_type": "link", + "from": "parenttrace" + } + } + ], + "resource": { + "attributes": { + "telemetry.sdk.language": "python", + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.version": "1.27.0", + "host.name": "351295342ba2", + "service.name": "Airflow" + }, + "schema_url": "" + } +} +{ + "name": "task_1_sub_span", + "context": { + "trace_id": "0x01f441c9c53e793e8808c77939ddbf36", + "span_id": "0x7fc9e2289c7df4b8", + "trace_state": "[]" + }, + "kind": "SpanKind.INTERNAL", + "parent_id": "0xba9f48dcfac5d40a", + "start_time": "2024-11-30T14:01:34.321996Z", + "end_time": "2024-11-30T14:01:34.324249Z", + "status": { + "status_code": "UNSET" + }, + "attributes": { + "attr1": "val1" + }, + "events": [], + "links": [ + { + "context": { + "trace_id": "0x01f441c9c53e793e8808c77939ddbf36", + "span_id": "0xba9f48dcfac5d40a", + "trace_state": "[]" + }, + "attributes": { + "meta.annotation_type": "link", + "from": "parenttrace" + } + } + ], + "resource": { + "attributes": { + "telemetry.sdk.language": "python", + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.version": "1.27.0", + "host.name": "351295342ba2", + "service.name": "Airflow" + }, + "schema_url": "" + } +} +{ + "name": "emit_metrics", + "context": { + "trace_id": "0x3f6d11237d2b2b8cb987e7ec923a4dc4", + "span_id": "0xa19a88e8dac9645b", + "trace_state": "[]" + }, + "kind": "SpanKind.INTERNAL", + "parent_id": "0xcf656e5db2b777be", + "start_time": "2024-11-30T14:01:29.315255Z", + "end_time": "2024-11-30T14:01:29.315290Z", + "status": { + "status_code": "UNSET" + }, + "attributes": { + "total_parse_time": 0.9342440839973278, + "dag_bag_size": 2, + "import_errors": 0 + }, + "events": [], + "links": [], + "resource": { + "attributes": { + "telemetry.sdk.language": "python", + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.version": "1.27.0", + "host.name": "351295342ba2", + "service.name": "Airflow" + }, + "schema_url": "" + } +} +{ + "name": "dag_parsing_loop", + "context": { + "trace_id": "0x3f6d11237d2b2b8cb987e7ec923a4dc4", + "span_id": "0xcf656e5db2b777be", + "trace_state": "[]" + }, + "kind": "SpanKind.INTERNAL", + "parent_id": null, + "start_time": "2024-11-30T14:01:28.382690Z", + "end_time": "2024-11-30T14:01:29.316499Z", + "status": { + "status_code": "UNSET" + }, + "attributes": {}, + "events": [ + { + "name": "heartbeat", + "timestamp": "2024-11-30T14:01:29.313549Z", + "attributes": {} + }, + { + "name": "_kill_timed_out_processors", + "timestamp": "2024-11-30T14:01:29.314763Z", + "attributes": {} + }, + { + "name": "prepare_file_path_queue", + "timestamp": "2024-11-30T14:01:29.315300Z", + "attributes": {} + }, + { + "name": "start_new_processes", + "timestamp": "2024-11-30T14:01:29.315941Z", + "attributes": {} + }, + { + "name": "collect_results", + "timestamp": "2024-11-30T14:01:29.316409Z", + "attributes": {} + }, + { + "name": "print_stat", + "timestamp": "2024-11-30T14:01:29.316432Z", + "attributes": {} + } + ], + "links": [], + "resource": { + "attributes": { + "telemetry.sdk.language": "python", + "telemetry.sdk.name": "opentelemetry", + "telemetry.sdk.version": "1.27.0", + "host.name": "351295342ba2", + "service.name": "Airflow" + }, + "schema_url": "" + } +} + """ + + # In the example output, there are two parent child relationships. + # + # test_dag + # |_ task_1 + # |_ task_1_sub_span + # |_ task_2 + # + # dag_parsing_loop + # |_ emit_metrics + # |_ start_new_processes + + def test_extract_spans_from_output(self): + output_lines = self.example_output.splitlines() + root_span_dict, span_dict = extract_spans_from_output(output_lines) + + assert len(root_span_dict) == 2 + assert len(span_dict) == 7 + + expected_root_span_names = ["test_dag", "dag_parsing_loop"] + actual_root_span_names = [] + for key, value in root_span_dict.items(): + assert key == value["context"]["span_id"] + assert value["parent_id"] is None + actual_root_span_names.append(value["name"]) + + assert sorted(actual_root_span_names) == sorted(expected_root_span_names) + + expected_span_names = [ + "test_dag", + "task_1", + "task_1_sub_span", + "task_2", + "dag_parsing_loop", + "emit_metrics", + "start_new_processes", + ] + actual_span_names = [] + for key, value in span_dict.items(): + assert key == value["context"]["span_id"] + actual_span_names.append(value["name"]) + + assert sorted(actual_span_names) == sorted(expected_span_names) + + def test_get_id_for_a_given_name(self): + output_lines = self.example_output.splitlines() + root_span_dict, span_dict = extract_spans_from_output(output_lines) + + span_name_to_test = "test_dag" + + span_id = get_id_for_a_given_name(span_dict, span_name_to_test) + + # Get the id from the two dictionaries, and then cross-reference the name. + span_from_root_dict = root_span_dict.get(span_id) + span_from_dict = span_dict.get(span_id) + + assert span_from_root_dict is not None + assert span_from_dict is not None + + assert span_name_to_test == span_from_root_dict["name"] + assert span_name_to_test == span_from_dict["name"] + + def test_get_parent_child_dict(self): + output_lines = self.example_output.splitlines() + root_span_dict, span_dict = extract_spans_from_output(output_lines) + + parent_child_dict = get_parent_child_dict(root_span_dict, span_dict) + + # There are two root spans. The dictionary should also have length equal to two. + assert len(parent_child_dict) == 2 + + assert sorted(root_span_dict.keys()) == sorted(parent_child_dict.keys()) + + for root_span_id, child_spans in parent_child_dict.items(): + # Both root spans have two direct child spans. + assert len(child_spans) == 2 + + root_span = root_span_dict.get(root_span_id) + root_span_trace_id = root_span["context"]["trace_id"] + + expected_child_span_names = [] + if root_span["name"] == "test_dag": + expected_child_span_names.extend(["task_1", "task_2"]) + elif root_span["name"] == "dag_parsing_loop": + expected_child_span_names.extend(["emit_metrics", "start_new_processes"]) + + actual_child_span_names = [] + + for child_span in child_spans: + # root_span_id should be the parent. + assert root_span_id == child_span["parent_id"] + # all spans should have the same trace_id. + assert root_span_trace_id == child_span["context"]["trace_id"] + actual_child_span_names.append(child_span["name"]) + + assert sorted(actual_child_span_names) == sorted(expected_child_span_names) + + def test_get_child_list_for_non_root(self): + output_lines = self.example_output.splitlines() + root_span_dict, span_dict = extract_spans_from_output(output_lines) + + span_name_to_test = "task_1" + span_id = get_id_for_a_given_name(span_dict, span_name_to_test) + + assert span_name_to_test == span_dict.get(span_id)["name"] + + # The span isn't a root span. + assert span_id not in root_span_dict.keys() + assert span_id in span_dict.keys() + + expected_child_span_names = ["task_1_sub_span"] + actual_child_span_names = [] + + task_1_child_spans = get_child_list_for_non_root(span_dict, "task_1") + + for span in task_1_child_spans: + actual_child_span_names.append(span["name"]) + + assert sorted(actual_child_span_names) == sorted(expected_child_span_names) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 0ea27f97bb81d..52a33df58e62c 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -68,9 +68,12 @@ from airflow.sdk.definitions.asset import Asset from airflow.serialization.serialized_objects import SerializedDAG from airflow.timetables.base import DataInterval +from airflow.traces.tracer import Trace from airflow.utils import timezone from airflow.utils.session import create_session, provide_session +from airflow.utils.span_status import SpanStatus from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import DagRunType from tests.listeners import dag_listener @@ -2447,6 +2450,194 @@ def test_runs_are_created_after_max_active_runs_was_reached(self, dag_maker, ses dag_runs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(dag_runs) == 2 + @pytest.mark.parametrize( + "ti_state, final_ti_span_status", + [(State.SUCCESS, SpanStatus.ENDED), (State.RUNNING, SpanStatus.ACTIVE)], + ) + def test_recreate_unhealthy_scheduler_spans_if_needed(self, ti_state, final_ti_span_status, dag_maker): + with dag_maker( + dag_id="test_recreate_unhealthy_scheduler_spans_if_needed", + start_date=DEFAULT_DATE, + max_active_runs=1, + dagrun_timeout=datetime.timedelta(seconds=60), + ): + EmptyOperator(task_id="dummy") + + session = settings.Session() + + old_job = Job() + old_job.id = 1 + old_job.job_type = SchedulerJobRunner.job_type + + session.add(old_job) + session.commit() + + assert old_job.is_alive() is False + + new_job = Job() + new_job.id = 2 + new_job.job_type = SchedulerJobRunner.job_type + + self.job_runner = SchedulerJobRunner(job=new_job, subdir=os.devnull) + self.job_runner.active_spans = ThreadSafeDict() + assert len(self.job_runner.active_spans.get_all()) == 0 + + dr = dag_maker.create_dagrun(external_trigger=True) + dr.state = State.RUNNING + dr.span_status = SpanStatus.ACTIVE + dr.scheduled_by_job_id = old_job.id + + ti = dr.get_task_instances(session=session)[0] + ti.state = ti_state + ti.start_date = timezone.utcnow() + ti.span_status = SpanStatus.ACTIVE + ti.queued_by_job_id = old_job.id + session.merge(ti) + session.merge(dr) + session.commit() + + # Given + assert dr.scheduled_by_job_id != self.job_runner.job.id + assert dr.scheduled_by_job_id == old_job.id + assert dr.run_id is not None + assert dr.state == State.RUNNING + assert dr.span_status == SpanStatus.ACTIVE + assert self.job_runner.active_spans.get(dr.run_id) is None + + assert self.job_runner.active_spans.get(ti.key) is None + assert ti.state == ti_state + assert ti.span_status == SpanStatus.ACTIVE + + # When + self.job_runner._recreate_unhealthy_scheduler_spans_if_needed(dr, session) + + # Then + assert self.job_runner.active_spans.get(dr.run_id) is not None + + if final_ti_span_status == SpanStatus.ACTIVE: + assert self.job_runner.active_spans.get(ti.key) is not None + assert len(self.job_runner.active_spans.get_all()) == 2 + else: + assert self.job_runner.active_spans.get(ti.key) is None + assert len(self.job_runner.active_spans.get_all()) == 1 + + assert dr.span_status == SpanStatus.ACTIVE + assert ti.span_status == final_ti_span_status + + def test_end_spans_of_externally_ended_ops(self, dag_maker): + with dag_maker( + dag_id="test_end_spans_of_externally_ended_ops", + start_date=DEFAULT_DATE, + max_active_runs=1, + dagrun_timeout=datetime.timedelta(seconds=60), + ): + EmptyOperator(task_id="dummy") + + session = settings.Session() + + job = Job() + job.id = 1 + job.job_type = SchedulerJobRunner.job_type + + self.job_runner = SchedulerJobRunner(job=job, subdir=os.devnull) + self.job_runner.active_spans = ThreadSafeDict() + assert len(self.job_runner.active_spans.get_all()) == 0 + + dr = dag_maker.create_dagrun(external_trigger=True) + dr.state = State.SUCCESS + dr.span_status = SpanStatus.SHOULD_END + + ti = dr.get_task_instances(session=session)[0] + ti.state = State.SUCCESS + ti.span_status = SpanStatus.SHOULD_END + ti.context_carrier = {} + session.merge(ti) + session.merge(dr) + session.commit() + + dr_span = Trace.start_root_span(span_name="dag_run_span", start_as_current=False) + ti_span = Trace.start_child_span(span_name="ti_span", start_as_current=False) + + self.job_runner.active_spans.set(dr.run_id, dr_span) + self.job_runner.active_spans.set(ti.key, ti_span) + + # Given + assert dr.span_status == SpanStatus.SHOULD_END + assert ti.span_status == SpanStatus.SHOULD_END + + assert self.job_runner.active_spans.get(dr.run_id) is not None + assert self.job_runner.active_spans.get(ti.key) is not None + + # When + self.job_runner._end_spans_of_externally_ended_ops(session) + + # Then + assert dr.span_status == SpanStatus.ENDED + assert ti.span_status == SpanStatus.ENDED + + assert self.job_runner.active_spans.get(dr.run_id) is None + assert self.job_runner.active_spans.get(ti.key) is None + + @pytest.mark.parametrize( + "state, final_span_status", + [(State.SUCCESS, SpanStatus.ENDED), (State.RUNNING, SpanStatus.NEEDS_CONTINUANCE)], + ) + def test_end_active_spans(self, state, final_span_status, dag_maker): + with dag_maker( + dag_id="test_end_active_spans", + start_date=DEFAULT_DATE, + max_active_runs=1, + dagrun_timeout=datetime.timedelta(seconds=60), + ): + EmptyOperator(task_id="dummy") + + session = settings.Session() + + job = Job() + job.id = 1 + job.job_type = SchedulerJobRunner.job_type + + self.job_runner = SchedulerJobRunner(job=job, subdir=os.devnull) + self.job_runner.active_spans = ThreadSafeDict() + assert len(self.job_runner.active_spans.get_all()) == 0 + + dr = dag_maker.create_dagrun(external_trigger=True) + dr.state = state + dr.span_status = SpanStatus.ACTIVE + + ti = dr.get_task_instances(session=session)[0] + ti.state = state + ti.span_status = SpanStatus.ACTIVE + ti.context_carrier = {} + session.merge(ti) + session.merge(dr) + session.commit() + + dr_span = Trace.start_root_span(span_name="dag_run_span", start_as_current=False) + ti_span = Trace.start_child_span(span_name="ti_span", start_as_current=False) + + self.job_runner.active_spans.set(dr.run_id, dr_span) + self.job_runner.active_spans.set(ti.key, ti_span) + + # Given + assert dr.span_status == SpanStatus.ACTIVE + assert ti.span_status == SpanStatus.ACTIVE + + assert self.job_runner.active_spans.get(dr.run_id) is not None + assert self.job_runner.active_spans.get(ti.key) is not None + assert len(self.job_runner.active_spans.get_all()) == 2 + + # When + self.job_runner._end_active_spans(session) + + # Then + assert dr.span_status == final_span_status + assert ti.span_status == final_span_status + + assert self.job_runner.active_spans.get(dr.run_id) is None + assert self.job_runner.active_spans.get(ti.key) is None + assert len(self.job_runner.active_spans.get_all()) == 0 + def test_dagrun_timeout_verify_max_active_runs(self, dag_maker): """ Test if a a dagrun will not be scheduled if max_dag_runs diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 96c06d279bc3b..9b9a997930a9f 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -32,6 +32,7 @@ from airflow.decorators import setup, task, task_group, teardown from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG, DagModel +from airflow.models.dag_version import DagVersion from airflow.models.dagrun import DagRun, DagRunNote from airflow.models.taskinstance import TaskInstance, TaskInstanceNote, clear_task_instances from airflow.models.taskmap import TaskMap @@ -43,7 +44,9 @@ from airflow.stats import Stats from airflow.triggers.base import StartTriggerArgs from airflow.utils import timezone +from airflow.utils.span_status import SpanStatus from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType @@ -464,6 +467,162 @@ def test_on_success_callback_when_task_skipped(self, session): assert dag_run.state == DagRunState.SUCCESS mock_on_success.assert_called_once() + def test_start_dr_spans_if_needed_new_span(self, testing_dag_bundle, session): + dag = DAG( + dag_id="test_start_dr_spans_if_needed_new_span", + schedule=datetime.timedelta(days=1), + start_date=datetime.datetime(2017, 1, 1), + ) + DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) + + dag_task1 = EmptyOperator(task_id="test_task1", dag=dag) + dag_task2 = EmptyOperator(task_id="test_task2", dag=dag) + dag_task1.set_downstream(dag_task2) + + initial_task_states = { + "test_task1": TaskInstanceState.QUEUED, + "test_task2": TaskInstanceState.QUEUED, + } + + # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) + + active_spans = ThreadSafeDict() + dag_run.set_active_spans(active_spans) + + tis = dag_run.get_task_instances() + + assert dag_run.active_spans is not None + assert dag_run.active_spans.get(dag_run.run_id) is None + assert dag_run.span_status == SpanStatus.NOT_STARTED + + dag_run.start_dr_spans_if_needed(tis=tis) + + assert dag_run.span_status == SpanStatus.ACTIVE + assert dag_run.active_spans.get(dag_run.run_id) is not None + + def test_start_dr_spans_if_needed_span_with_continuance(self, testing_dag_bundle, session): + dag = DAG( + dag_id="test_start_dr_spans_if_needed_span_with_continuance", + schedule=datetime.timedelta(days=1), + start_date=datetime.datetime(2017, 1, 1), + ) + DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) + + dag_task1 = EmptyOperator(task_id="test_task1", dag=dag) + dag_task2 = EmptyOperator(task_id="test_task2", dag=dag) + dag_task1.set_downstream(dag_task2) + + initial_task_states = { + "test_task1": TaskInstanceState.RUNNING, + "test_task2": TaskInstanceState.QUEUED, + } + + # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) + + active_spans = ThreadSafeDict() + dag_run.set_active_spans(active_spans) + + dag_run.span_status = SpanStatus.NEEDS_CONTINUANCE + + tis = dag_run.get_task_instances() + + first_ti = tis[0] + first_ti.span_status = SpanStatus.NEEDS_CONTINUANCE + + assert dag_run.active_spans is not None + assert dag_run.active_spans.get(dag_run.run_id) is None + assert dag_run.active_spans.get(first_ti.key) is None + assert dag_run.span_status == SpanStatus.NEEDS_CONTINUANCE + assert first_ti.span_status == SpanStatus.NEEDS_CONTINUANCE + + dag_run.start_dr_spans_if_needed(tis=tis) + + assert dag_run.span_status == SpanStatus.ACTIVE + assert first_ti.span_status == SpanStatus.ACTIVE + assert dag_run.active_spans.get(dag_run.run_id) is not None + assert dag_run.active_spans.get(first_ti.key) is not None + + def test_end_dr_span_if_needed(self, testing_dag_bundle, session): + dag = DAG( + dag_id="test_end_dr_span_if_needed", + schedule=datetime.timedelta(days=1), + start_date=datetime.datetime(2017, 1, 1), + ) + DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) + + dag_task1 = EmptyOperator(task_id="test_task1", dag=dag) + dag_task2 = EmptyOperator(task_id="test_task2", dag=dag) + dag_task1.set_downstream(dag_task2) + + initial_task_states = { + "test_task1": TaskInstanceState.SUCCESS, + "test_task2": TaskInstanceState.SUCCESS, + } + + # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) + + active_spans = ThreadSafeDict() + dag_run.set_active_spans(active_spans) + + from airflow.traces.tracer import Trace + + dr_span = Trace.start_root_span(span_name="test_span", start_as_current=False) + + active_spans.set(dag_run.run_id, dr_span) + + assert dag_run.active_spans is not None + assert dag_run.active_spans.get(dag_run.run_id) is not None + + dag_version = DagVersion.get_latest_version(dag.dag_id) + dag_run.end_dr_span_if_needed(dagv=dag_version) + + assert dag_run.span_status == SpanStatus.ENDED + assert dag_run.active_spans.get(dag_run.run_id) is None + + def test_end_dr_span_if_needed_with_span_from_another_scheduler(self, testing_dag_bundle, session): + dag = DAG( + dag_id="test_end_dr_span_if_needed_with_span_from_another_scheduler", + schedule=datetime.timedelta(days=1), + start_date=datetime.datetime(2017, 1, 1), + ) + DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) + + dag_task1 = EmptyOperator(task_id="test_task1", dag=dag) + dag_task2 = EmptyOperator(task_id="test_task2", dag=dag) + dag_task1.set_downstream(dag_task2) + + initial_task_states = { + "test_task1": TaskInstanceState.SUCCESS, + "test_task2": TaskInstanceState.SUCCESS, + } + + # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) + + active_spans = ThreadSafeDict() + dag_run.set_active_spans(active_spans) + + dag_run.span_status = SpanStatus.ACTIVE + + assert dag_run.active_spans is not None + assert dag_run.active_spans.get(dag_run.run_id) is None + + dag_version = DagVersion.get_latest_version(dag.dag_id) + dag_run.end_dr_span_if_needed(dagv=dag_version) + + assert dag_run.span_status == SpanStatus.SHOULD_END + def test_dagrun_update_state_with_handle_callback_success(self, testing_dag_bundle, session): def on_success_callable(context): assert context["dag_run"].dag_id == "test_dagrun_update_state_with_handle_callback_success" diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 29506c6676118..5a53dcec0c52c 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -93,6 +93,7 @@ from airflow.utils.db import merge_conn from airflow.utils.module_loading import qualname from airflow.utils.session import create_session, provide_session +from airflow.utils.span_status import SpanStatus from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session @@ -4010,6 +4011,8 @@ def test_refresh_from_db(self, create_task_instance): "updated_at": None, "task_display_name": "Test Refresh from DB Task", "dag_version_id": None, + "context_carrier": {}, + "span_status": SpanStatus.ENDED, } # Make sure we aren't missing any new value in our expected_values list. expected_keys = {f"task_instance.{key}" for key in expected_values} diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 44c4316058171..f82c093252d9d 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -1106,6 +1106,8 @@ def test_task_instances(admin_client): "unixname": getuser(), "updated_at": DEFAULT_DATE.isoformat(), "dag_version_id": None, + "context_carrier": None, + "span_status": "not_started", }, "run_after_loop": { "custom_operator_name": None, @@ -1143,6 +1145,8 @@ def test_task_instances(admin_client): "unixname": getuser(), "updated_at": DEFAULT_DATE.isoformat(), "dag_version_id": None, + "context_carrier": None, + "span_status": "not_started", }, "run_this_last": { "custom_operator_name": None, @@ -1180,6 +1184,8 @@ def test_task_instances(admin_client): "unixname": getuser(), "updated_at": DEFAULT_DATE.isoformat(), "dag_version_id": None, + "context_carrier": None, + "span_status": "not_started", }, "runme_0": { "custom_operator_name": None, @@ -1217,6 +1223,8 @@ def test_task_instances(admin_client): "unixname": getuser(), "updated_at": DEFAULT_DATE.isoformat(), "dag_version_id": None, + "context_carrier": None, + "span_status": "not_started", }, "runme_1": { "custom_operator_name": None, @@ -1254,6 +1262,8 @@ def test_task_instances(admin_client): "unixname": getuser(), "updated_at": DEFAULT_DATE.isoformat(), "dag_version_id": None, + "context_carrier": None, + "span_status": "not_started", }, "runme_2": { "custom_operator_name": None, @@ -1291,6 +1301,8 @@ def test_task_instances(admin_client): "unixname": getuser(), "updated_at": DEFAULT_DATE.isoformat(), "dag_version_id": None, + "context_carrier": None, + "span_status": "not_started", }, "this_will_skip": { "custom_operator_name": None, @@ -1328,5 +1340,7 @@ def test_task_instances(admin_client): "unixname": getuser(), "updated_at": DEFAULT_DATE.isoformat(), "dag_version_id": None, + "context_carrier": None, + "span_status": "not_started", }, }