Skip to content

Commit

Permalink
Add bundle_name to ParseImportError
Browse files Browse the repository at this point in the history
This PR adds bundle_name to ParseImportError. Future work would make
the filename relative to the bundle path and that means we need to include
bundle_name as part of the ParseImportError so that if two DAG files are
having the same filename, we could differentiate them by the bundle they belong.
  • Loading branch information
ephraimbuddy committed Jan 13, 2025
1 parent 90eae56 commit b1a8ced
Show file tree
Hide file tree
Showing 15 changed files with 2,141 additions and 2,026 deletions.
5 changes: 4 additions & 1 deletion airflow/api/common/delete_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session =
# This handles the case when the dag_id is changed in the file
session.execute(
delete(ParseImportError)
.where(ParseImportError.filename == dag.fileloc)
.where(
ParseImportError.filename == dag.fileloc,
ParseImportError.bundle_name == dag.get_bundle_name(session),
)
.execution_options(synchronize_session="fetch")
)

Expand Down
27 changes: 21 additions & 6 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING

from sqlalchemy import func, select
from sqlalchemy import func, select, tuple_

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound, PermissionDenied
Expand Down Expand Up @@ -61,7 +61,9 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) ->
readable_dag_ids = security.get_readable_dags()
file_dag_ids = {
dag_id[0]
for dag_id in session.query(DagModel.dag_id).filter(DagModel.fileloc == error.filename).all()
for dag_id in session.query(DagModel.dag_id)
.filter(DagModel.fileloc == error.filename, DagModel.bundle_name == error.bundle_name)
.all()
}

# Can the user read any DAGs in the file?
Expand Down Expand Up @@ -98,9 +100,17 @@ def get_import_errors(
if not can_read_all_dags:
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
readable_dag_ids = security.get_readable_dags()
dagfiles_stmt = select(DagModel.fileloc).distinct().where(DagModel.dag_id.in_(readable_dag_ids))
query = query.where(ParseImportError.filename.in_(dagfiles_stmt))
count_query = count_query.where(ParseImportError.filename.in_(dagfiles_stmt))
dagfiles_stmt = (
select(DagModel.fileloc, DagModel.bundle_name)
.distinct()
.where(DagModel.dag_id.in_(readable_dag_ids))
)
query = query.where(
tuple_(ParseImportError.filename, ParseImportError.bundle_name).in_(dagfiles_stmt)
)
count_query = count_query.where(
tuple_(ParseImportError.filename, ParseImportError.bundle_name).in_(dagfiles_stmt)
)

total_entries = session.scalars(count_query).one()
import_errors = session.scalars(query.offset(offset).limit(limit)).all()
Expand All @@ -109,7 +119,12 @@ def get_import_errors(
for import_error in import_errors:
# Check if user has read access to all the DAGs defined in the file
file_dag_ids = (
session.query(DagModel.dag_id).filter(DagModel.fileloc == import_error.filename).all()
session.query(DagModel.dag_id)
.filter(
DagModel.fileloc == import_error.filename,
DagModel.bundle_name == import_error.bundle_name,
)
.all()
)
requests: Sequence[IsAuthorizedDagRequest] = [
{
Expand Down
1 change: 1 addition & 0 deletions airflow/api_connexion/schemas/error_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Meta:
import_error_id = auto_field("id", dump_only=True)
timestamp = auto_field(format="iso", dump_only=True)
filename = auto_field(dump_only=True)
bundle_name = auto_field(dump_only=True)
stack_trace = auto_field("stacktrace", dump_only=True)


Expand Down
1 change: 1 addition & 0 deletions airflow/api_fastapi/core_api/routes/public/import_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_import_errors(
"id",
"timestamp",
"filename",
"bundle_name",
"stacktrace",
],
ParseImportError,
Expand Down
34 changes: 26 additions & 8 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,35 +241,53 @@ def _update_dag_warnings(
session.merge(warning_to_add)


def _update_import_errors(files_parsed: set[str], import_errors: dict[str, str], session: Session):
def _update_import_errors(
files_parsed: set[str], bundle_name: str, import_errors: dict[str, str], session: Session
):
from airflow.listeners.listener import get_listener_manager

# We can remove anything from files parsed in this batch that doesn't have an error. We need to remove old
# errors (i.e. from files that are removed) separately

session.execute(delete(ParseImportError).where(ParseImportError.filename.in_(list(files_parsed))))
session.execute(
delete(ParseImportError).where(
ParseImportError.filename.in_(list(files_parsed)), ParseImportError.bundle_name == bundle_name
)
)

existing_import_error_files = set(session.scalars(select(ParseImportError.filename)))
existing_import_error_files = set(
session.execute(select(ParseImportError.filename, ParseImportError.bundle_name))
)

# Add the errors of the processed files
for filename, stacktrace in import_errors.items():
if filename in existing_import_error_files:
session.query(ParseImportError).where(ParseImportError.filename == filename).update(
{"filename": filename, "timestamp": utcnow(), "stacktrace": stacktrace},
if (filename, bundle_name) in existing_import_error_files:
session.query(ParseImportError).where(
ParseImportError.filename == filename, ParseImportError.bundle_name == bundle_name
).update(
{
"filename": filename,
"bundle_name": bundle_name,
"timestamp": utcnow(),
"stacktrace": stacktrace,
},
)
# sending notification when an existing dag import error occurs
get_listener_manager().hook.on_existing_dag_import_error(filename=filename, stacktrace=stacktrace)
else:
session.add(
ParseImportError(
filename=filename,
bundle_name=bundle_name,
timestamp=utcnow(),
stacktrace=stacktrace,
)
)
# sending notification when a new dag import error occurs
get_listener_manager().hook.on_new_dag_import_error(filename=filename, stacktrace=stacktrace)
session.query(DagModel).filter(DagModel.fileloc == filename).update({"has_import_errors": True})
session.query(DagModel).filter(
DagModel.fileloc == filename, DagModel.bundle_name == bundle_name
).update({"has_import_errors": True})


def update_dag_parsing_results_in_db(
Expand Down Expand Up @@ -314,7 +332,6 @@ def update_dag_parsing_results_in_db(
try:
DAG.bulk_write_to_db(bundle_name, bundle_version, dags, session=session)
# Write Serialized DAGs to DB, capturing errors
# Write Serialized DAGs to DB, capturing errors
for dag in dags:
serialize_errors.extend(_serialize_dag_capturing_errors(dag, session))
except OperationalError:
Expand All @@ -332,6 +349,7 @@ def update_dag_parsing_results_in_db(
good_dag_filelocs = {dag.fileloc for dag in dags if dag.fileloc not in import_errors}
_update_import_errors(
files_parsed=good_dag_filelocs,
bundle_name=bundle_name,
import_errors=import_errors,
session=session,
)
Expand Down
7 changes: 5 additions & 2 deletions airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

import attrs
from setproctitle import setproctitle
from sqlalchemy import delete, select, update
from sqlalchemy import and_, delete, select, update
from tabulate import tabulate
from uuid6 import uuid7

Expand Down Expand Up @@ -756,7 +756,10 @@ def clear_nonexistent_import_errors(self, session=NEW_SESSION):

if self._file_paths:
query = query.where(
ParseImportError.filename.notin_([f.path for f in self._file_paths]),
and_(
ParseImportError.filename.notin_([f.path for f in self._file_paths]),
ParseImportError.bundle_name.notin_([f.bundle_name for f in self._file_paths]),
)
)

session.execute(query.execution_options(synchronize_session="fetch"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
add bundle_name to ParseImportError.
Revision ID: 03de77aaa4ec
Revises: e39a26ac59f6
Create Date: 2025-01-08 10:38:02.108760
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "03de77aaa4ec"
down_revision = "e39a26ac59f6"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"


def upgrade():
"""Apply add bundle_name to ParseImportError."""
with op.batch_alter_table("import_error", schema=None) as batch_op:
batch_op.add_column(sa.Column("bundle_name", sa.String(length=250), nullable=True))


def downgrade():
"""Unapply add bundle_name to ParseImportError."""
with op.batch_alter_table("import_error", schema=None) as batch_op:
batch_op.drop_column("bundle_name")
3 changes: 2 additions & 1 deletion airflow/models/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from sqlalchemy import Column, Integer, String, Text

from airflow.models.base import Base
from airflow.models.base import Base, StringID
from airflow.utils.sqlalchemy import UtcDateTime


Expand All @@ -30,4 +30,5 @@ class ParseImportError(Base):
id = Column(Integer, primary_key=True)
timestamp = Column(UtcDateTime)
filename = Column(String(1024))
bundle_name = Column(StringID())
stacktrace = Column(Text)
2 changes: 1 addition & 1 deletion airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class MappedClassProtocol(Protocol):
"2.9.2": "686269002441",
"2.10.0": "22ed7efa9da2",
"2.10.3": "5f2621c13b39",
"3.0.0": "e39a26ac59f6",
"3.0.0": "03de77aaa4ec",
}


Expand Down
11 changes: 8 additions & 3 deletions airflow/www/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,19 @@ def encode_dag_run(
return encoded_dag_run, None


def check_import_errors(fileloc, session):
def check_import_errors(fileloc, bundle_name, session):
# Check dag import errors
import_errors = session.scalars(
select(ParseImportError).where(ParseImportError.filename == fileloc)
select(ParseImportError).where(
ParseImportError.filename == fileloc, ParseImportError.bundle_name == bundle_name
)
).all()
if import_errors:
for import_error in import_errors:
flash(f"Broken DAG: [{import_error.filename}] {import_error.stacktrace}", "dag_import_error")
flash(
f"Broken DAG: [{import_error.filename}, Bundle name: {bundle_name}] {import_error.stacktrace}",
"dag_import_error",
)


def check_dag_warnings(dag_id, session):
Expand Down
9 changes: 6 additions & 3 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,10 @@ def index(self):
import_errors = import_errors.where(
ParseImportError.filename.in_(
select(DagModel.fileloc).distinct().where(DagModel.dag_id.in_(filter_dag_ids))
)
),
ParseImportError.bundle_name.in_(
select(DagModel.bundle_name).distinct().where(DagModel.dag_id.in_(filter_dag_ids))
),
)

import_errors = session.scalars(import_errors)
Expand Down Expand Up @@ -2876,10 +2879,10 @@ def grid(self, dag_id: str, session: Session = NEW_SESSION):
dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session)
url_serializer = URLSafeSerializer(current_app.config["SECRET_KEY"])
dag_model = DagModel.get_dagmodel(dag_id, session=session)
if not dag:
if not dag or not dag_model:
flash(f'DAG "{dag_id}" seems to be missing from DagBag.', "error")
return redirect(url_for("Airflow.index"))
wwwutils.check_import_errors(dag.fileloc, session)
wwwutils.check_import_errors(dag.fileloc, dag_model.bundle_name, session)
wwwutils.check_dag_warnings(dag.dag_id, session)

included_events_raw = conf.get("webserver", "audit_view_included_events", fallback="")
Expand Down
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ca59d711e6304f8bfdb25f49339d455602430dd6b880e420869fc892faef0596
79449705d667d8fe382b4b53a0c59bf55bf31eeafc34941c59e6dceccc68d7a7
Loading

0 comments on commit b1a8ced

Please sign in to comment.