From 721fa2746d06fa778dffc6b13a723259e6b210ac Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 17 Jan 2025 08:26:21 -0800 Subject: [PATCH 1/8] Move config item dag_bundle_storage_path to dag_bundles section (#45697) --- airflow/config_templates/config.yml | 20 ++++++++++--------- airflow/dag_processing/bundles/base.py | 2 +- airflow/dag_processing/bundles/manager.py | 2 +- providers/tests/fab/auth_manager/conftest.py | 2 +- .../tests/execution_time/test_supervisor.py | 2 +- .../tests/execution_time/test_task_runner.py | 4 ++-- tests/conftest.py | 2 +- .../bundles/test_dag_bundle_manager.py | 14 ++++++------- tests/dag_processing/test_dag_bundles.py | 6 +++--- tests/dag_processing/test_manager.py | 4 ++-- 10 files changed, 30 insertions(+), 28 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index ba6af6ca11e138..5b99c94a4f33db 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -28,14 +28,6 @@ core: type: string example: ~ default: "{AIRFLOW_HOME}/dags" - dag_bundle_storage_path: - description: | - The folder where Airflow bundles can store files locally (if required). - By default, this is ``tempfile.gettempdir()/airflow``. This path must be absolute. - version_added: 3.0.0 - type: string - example: "`tempfile.gettempdir()/dag_bundles" - default: ~ hostname_callable: description: | Hostname by providing a path to a callable, which will resolve the hostname. @@ -2670,7 +2662,17 @@ dag_bundles: Configuration for the DAG bundles. This allows Airflow to load DAGs from different sources. options: - backends: + dag_bundle_storage_path: + description: | + String path to folder where Airflow bundles can store files locally. Not templated. + If no path is provided, Airflow will use ``Path(tempfile.gettempdir()) / "airflow"``. + This path must be absolute. + version_added: 3.0.0 + type: string + example: "/tmp/some-place" + default: ~ + + config_list: description: | List of backend configs. Must supply name, classpath, and kwargs for each backend. diff --git a/airflow/dag_processing/bundles/base.py b/airflow/dag_processing/bundles/base.py index da60f77cf4a961..9b55c0d4f0ecf4 100644 --- a/airflow/dag_processing/bundles/base.py +++ b/airflow/dag_processing/bundles/base.py @@ -74,7 +74,7 @@ def _dag_bundle_root_storage_path(self) -> Path: This is the root path, shared by various bundles. Each bundle should have its own subdirectory. 
""" - if configured_location := conf.get("core", "dag_bundle_storage_path", fallback=None): + if configured_location := conf.get("dag_bundles", "dag_bundle_storage_path", fallback=None): return Path(configured_location) return Path(tempfile.gettempdir(), "airflow", "dag_bundles") diff --git a/airflow/dag_processing/bundles/manager.py b/airflow/dag_processing/bundles/manager.py index c5a2115b24f758..2aa8cf2303ddd5 100644 --- a/airflow/dag_processing/bundles/manager.py +++ b/airflow/dag_processing/bundles/manager.py @@ -54,7 +54,7 @@ def parse_config(self) -> None: if self._bundle_config: return - backends = conf.getjson("dag_bundles", "backends") + backends = conf.getjson("dag_bundles", "config_list") if not backends: return diff --git a/providers/tests/fab/auth_manager/conftest.py b/providers/tests/fab/auth_manager/conftest.py index 9c61f7ab2dccc0..4cb4b84a24cde3 100644 --- a/providers/tests/fab/auth_manager/conftest.py +++ b/providers/tests/fab/auth_manager/conftest.py @@ -95,7 +95,7 @@ def _config_bundle(path_to_parse: Path | str): "kwargs": {"path": str(path_to_parse), "refresh_interval": 0}, } ] - with conf_vars({("dag_bundles", "backends"): json.dumps(bundle_config)}): + with conf_vars({("dag_bundles", "config_list"): json.dumps(bundle_config)}): yield return _config_bundle diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index fb84713216625e..59afa26dc2aa5a 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -74,7 +74,7 @@ def lineno(): def local_dag_bundle_cfg(path, name="my-bundle"): return { - "AIRFLOW__DAG_BUNDLES__BACKENDS": json.dumps( + "AIRFLOW__DAG_BUNDLES__CONFIG_LIST": json.dumps( [ { "name": name, diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 0e3698050b88e6..60b39da455c69f 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -130,7 +130,7 @@ def test_parse(test_dags_dir: Path, make_ti_context): with patch.dict( os.environ, { - "AIRFLOW__DAG_BUNDLES__BACKENDS": json.dumps( + "AIRFLOW__DAG_BUNDLES__CONFIG_LIST": json.dumps( [ { "name": "my-bundle", @@ -574,7 +574,7 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch ] ) - monkeypatch.setenv("AIRFLOW__DAG_BUNDLES__BACKENDS", dag_bundle_val) + monkeypatch.setenv("AIRFLOW__DAG_BUNDLES__CONFIG_LIST", dag_bundle_val) ti, _ = startup() # Presence of `conditional_task` below means DAG ID is properly set in the parsing context! 
diff --git a/tests/conftest.py b/tests/conftest.py index fca82aee34b871..8082238808dd47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -111,7 +111,7 @@ def _config_bundle(path_to_parse: Path | str): "kwargs": {"path": str(path_to_parse), "refresh_interval": 0}, } ] - with conf_vars({("dag_bundles", "backends"): json.dumps(bundle_config)}): + with conf_vars({("dag_bundles", "config_list"): json.dumps(bundle_config)}): yield return _config_bundle diff --git a/tests/dag_processing/bundles/test_dag_bundle_manager.py b/tests/dag_processing/bundles/test_dag_bundle_manager.py index b0baa21c6f84c0..26f8b045837eba 100644 --- a/tests/dag_processing/bundles/test_dag_bundle_manager.py +++ b/tests/dag_processing/bundles/test_dag_bundle_manager.py @@ -70,7 +70,7 @@ def test_parse_bundle_config(value, expected): """Test that bundle_configs are read from configuration.""" envs = {"AIRFLOW__CORE__LOAD_EXAMPLES": "False"} if value: - envs["AIRFLOW__DAG_BUNDLES__BACKENDS"] = value + envs["AIRFLOW__DAG_BUNDLES__CONFIG_LIST"] = value cm = nullcontext() exp_fail = False if isinstance(expected, str): @@ -108,7 +108,7 @@ def path(self): def test_get_bundle(): """Test that get_bundle builds and returns a bundle.""" - with patch.dict(os.environ, {"AIRFLOW__DAG_BUNDLES__BACKENDS": json.dumps(BASIC_BUNDLE_CONFIG)}): + with patch.dict(os.environ, {"AIRFLOW__DAG_BUNDLES__CONFIG_LIST": json.dumps(BASIC_BUNDLE_CONFIG)}): bundle_manager = DagBundlesManager() with pytest.raises(ValueError, match="'bundle-that-doesn't-exist' is not configured"): @@ -120,7 +120,7 @@ def test_get_bundle(): assert bundle.refresh_interval == 1 # And none for version also works! - with patch.dict(os.environ, {"AIRFLOW__DAG_BUNDLES__BACKENDS": json.dumps(BASIC_BUNDLE_CONFIG)}): + with patch.dict(os.environ, {"AIRFLOW__DAG_BUNDLES__CONFIG_LIST": json.dumps(BASIC_BUNDLE_CONFIG)}): bundle = bundle_manager.get_bundle(name="my-test-bundle") assert isinstance(bundle, BasicBundle) assert bundle.name == "my-test-bundle" @@ -144,7 +144,7 @@ def _get_bundle_names_and_active(): ) # Initial add - with patch.dict(os.environ, {"AIRFLOW__DAG_BUNDLES__BACKENDS": json.dumps(BASIC_BUNDLE_CONFIG)}): + with patch.dict(os.environ, {"AIRFLOW__DAG_BUNDLES__CONFIG_LIST": json.dumps(BASIC_BUNDLE_CONFIG)}): manager = DagBundlesManager() manager.sync_bundles_to_db() assert _get_bundle_names_and_active() == [("my-test-bundle", True)] @@ -156,13 +156,13 @@ def _get_bundle_names_and_active(): assert _get_bundle_names_and_active() == [("dags-folder", True), ("my-test-bundle", False)] # Re-enable one that reappears in config - with patch.dict(os.environ, {"AIRFLOW__DAG_BUNDLES__BACKENDS": json.dumps(BASIC_BUNDLE_CONFIG)}): + with patch.dict(os.environ, {"AIRFLOW__DAG_BUNDLES__CONFIG_LIST": json.dumps(BASIC_BUNDLE_CONFIG)}): manager = DagBundlesManager() manager.sync_bundles_to_db() assert _get_bundle_names_and_active() == [("dags-folder", False), ("my-test-bundle", True)] -@conf_vars({("dag_bundles", "backends"): json.dumps(BASIC_BUNDLE_CONFIG)}) +@conf_vars({("dag_bundles", "config_list"): json.dumps(BASIC_BUNDLE_CONFIG)}) @pytest.mark.parametrize("version", [None, "hello"]) def test_view_url(version): """Test that view_url calls the bundle's view_url method.""" @@ -185,6 +185,6 @@ def test_example_dags_bundle_added(): def test_example_dags_name_is_reserved(): reserved_name_config = [{"name": "example_dags"}] - with conf_vars({("dag_bundles", "backends"): json.dumps(reserved_name_config)}): + with conf_vars({("dag_bundles", "config_list"): 
json.dumps(reserved_name_config)}): with pytest.raises(AirflowConfigException, match="Bundle name 'example_dags' is a reserved name."): DagBundlesManager().parse_config() diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index 32b2277b68c54e..6f6fb2c80f044e 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -39,12 +39,12 @@ @pytest.fixture(autouse=True) def bundle_temp_dir(tmp_path): - with conf_vars({("core", "dag_bundle_storage_path"): str(tmp_path)}): + with conf_vars({("dag_bundles", "dag_bundle_storage_path"): str(tmp_path)}): yield tmp_path def test_default_dag_storage_path(): - with conf_vars({("core", "dag_bundle_storage_path"): ""}): + with conf_vars({("dag_bundles", "dag_bundle_storage_path"): ""}): bundle = LocalDagBundle(name="test", path="/hello") assert bundle._dag_bundle_root_storage_path == Path(tempfile.gettempdir(), "airflow", "dag_bundles") @@ -60,7 +60,7 @@ def get_current_version(self): def path(self): pass - with conf_vars({("core", "dag_bundle_storage_path"): None}): + with conf_vars({("dag_bundles", "dag_bundle_storage_path"): None}): bundle = BasicBundle(name="test") assert bundle._dag_bundle_root_storage_path == Path(tempfile.gettempdir(), "airflow", "dag_bundles") diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py index 4ab55c24eefc04..68740c4601ba03 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_manager.py @@ -857,7 +857,7 @@ def test_bundles_are_refreshed(self): bundletwo.refresh_interval = 300 bundletwo.get_current_version.return_value = None - with conf_vars({("dag_bundles", "backends"): json.dumps(config)}): + with conf_vars({("dag_bundles", "config_list"): json.dumps(config)}): DagBundlesManager().sync_bundles_to_db() with mock.patch( "airflow.dag_processing.bundles.manager.DagBundlesManager" @@ -910,7 +910,7 @@ def test_bundles_versions_are_stored(self): mybundle.supports_versioning = True mybundle.get_current_version.return_value = "123" - with conf_vars({("dag_bundles", "backends"): json.dumps(config)}): + with conf_vars({("dag_bundles", "config_list"): json.dumps(config)}): DagBundlesManager().sync_bundles_to_db() with mock.patch( "airflow.dag_processing.bundles.manager.DagBundlesManager" From 3da6796728df04419d153a36557b38b77c40de14 Mon Sep 17 00:00:00 2001 From: GPK Date: Fri, 17 Jan 2025 17:08:19 +0000 Subject: [PATCH 2/8] Bump UV to 0.5.20 (#45750) --- .github/actions/install-pre-commit/action.yml | 2 +- Dockerfile | 2 +- Dockerfile.ci | 2 +- dev/breeze/doc/ci/02_images.md | 2 +- .../src/airflow_breeze/commands/release_management_commands.py | 2 +- dev/breeze/src/airflow_breeze/global_constants.py | 2 +- scripts/ci/install_breeze.sh | 2 +- scripts/tools/setup_breeze | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/actions/install-pre-commit/action.yml b/.github/actions/install-pre-commit/action.yml index 8ac0440ceae7fc..30a3367710a923 100644 --- a/.github/actions/install-pre-commit/action.yml +++ b/.github/actions/install-pre-commit/action.yml @@ -24,7 +24,7 @@ inputs: default: "3.9" uv-version: description: 'uv version to use' - default: "0.5.14" # Keep this comment to allow automatic replacement of uv version + default: "0.5.20" # Keep this comment to allow automatic replacement of uv version pre-commit-version: description: 'pre-commit version to use' default: "4.0.1" # Keep this comment to allow automatic replacement of pre-commit version diff 
--git a/Dockerfile b/Dockerfile index fb82c882048c03..9b7e8a4391f3eb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -55,7 +55,7 @@ ARG PYTHON_BASE_IMAGE="python:3.9-slim-bookworm" # Also use `force pip` label on your PR to swap all places we use `uv` to `pip` ARG AIRFLOW_PIP_VERSION=24.3.1 # ARG AIRFLOW_PIP_VERSION="git+https://github.com/pypa/pip.git@main" -ARG AIRFLOW_UV_VERSION=0.5.14 +ARG AIRFLOW_UV_VERSION=0.5.20 ARG AIRFLOW_USE_UV="false" ARG UV_HTTP_TIMEOUT="300" ARG AIRFLOW_IMAGE_REPOSITORY="https://github.com/apache/airflow" diff --git a/Dockerfile.ci b/Dockerfile.ci index 5996ebe1ccb219..4e80ff1050abd1 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -1268,7 +1268,7 @@ COPY --from=scripts common.sh install_packaging_tools.sh install_additional_depe # Also use `force pip` label on your PR to swap all places we use `uv` to `pip` ARG AIRFLOW_PIP_VERSION=24.3.1 # ARG AIRFLOW_PIP_VERSION="git+https://github.com/pypa/pip.git@main" -ARG AIRFLOW_UV_VERSION=0.5.14 +ARG AIRFLOW_UV_VERSION=0.5.20 # TODO(potiuk): automate with upgrade check (possibly) ARG AIRFLOW_PRE_COMMIT_VERSION="4.0.1" ARG AIRFLOW_PRE_COMMIT_UV_VERSION="4.1.4" diff --git a/dev/breeze/doc/ci/02_images.md b/dev/breeze/doc/ci/02_images.md index 3d1d7d8b53eb7e..84f71f34c3f1d7 100644 --- a/dev/breeze/doc/ci/02_images.md +++ b/dev/breeze/doc/ci/02_images.md @@ -443,7 +443,7 @@ can be used for CI images: | `ADDITIONAL_DEV_APT_DEPS` | | Additional apt dev dependencies installed in the first part of the image | | `ADDITIONAL_DEV_APT_ENV` | | Additional env variables defined when installing dev deps | | `AIRFLOW_PIP_VERSION` | `24.3.1` | `pip` version used. | -| `AIRFLOW_UV_VERSION` | `0.5.14` | `uv` version used. | +| `AIRFLOW_UV_VERSION` | `0.5.20` | `uv` version used. | | `AIRFLOW_PRE_COMMIT_VERSION` | `4.0.1` | `pre-commit` version used. | | `AIRFLOW_PRE_COMMIT_UV_VERSION` | `4.1.4` | `pre-commit-uv` version used. | | `AIRFLOW_USE_UV` | `true` | Whether to use UV for installation. 
| diff --git a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py index 1dad40f7e5b9f0..99e619240702bc 100644 --- a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py @@ -234,7 +234,7 @@ class VersionedFile(NamedTuple): AIRFLOW_PIP_VERSION = "24.3.1" -AIRFLOW_UV_VERSION = "0.5.14" +AIRFLOW_UV_VERSION = "0.5.20" AIRFLOW_USE_UV = False # TODO: automate these as well WHEEL_VERSION = "0.44.0" diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index b18f8c4da32273..287611732567ba 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -189,7 +189,7 @@ ALLOWED_INSTALL_MYSQL_CLIENT_TYPES = ["mariadb", "mysql"] PIP_VERSION = "24.3.1" -UV_VERSION = "0.5.14" +UV_VERSION = "0.5.20" DEFAULT_UV_HTTP_TIMEOUT = 300 DEFAULT_WSL2_HTTP_TIMEOUT = 900 diff --git a/scripts/ci/install_breeze.sh b/scripts/ci/install_breeze.sh index a9a830ca319142..98c29d44d07c2a 100755 --- a/scripts/ci/install_breeze.sh +++ b/scripts/ci/install_breeze.sh @@ -22,7 +22,7 @@ cd "$( dirname "${BASH_SOURCE[0]}" )/../../" PYTHON_ARG="" PIP_VERSION="24.3.1" -UV_VERSION="0.5.14" +UV_VERSION="0.5.20" if [[ ${PYTHON_VERSION=} != "" ]]; then PYTHON_ARG="--python=$(which python"${PYTHON_VERSION}") " fi diff --git a/scripts/tools/setup_breeze b/scripts/tools/setup_breeze index 8b3932c982008e..272d56c89d64b5 100755 --- a/scripts/tools/setup_breeze +++ b/scripts/tools/setup_breeze @@ -27,7 +27,7 @@ COLOR_YELLOW=$'\e[33m' COLOR_BLUE=$'\e[34m' COLOR_RESET=$'\e[0m' -UV_VERSION="0.5.14" +UV_VERSION="0.5.20" function manual_instructions() { echo From 4521e8df69ec2b7ab6327a56c13276ed5717f010 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Fri, 17 Jan 2025 18:08:48 +0100 Subject: [PATCH 3/8] Skip serialization tests when latest botocore is installed (#45755) --- tests/serialization/test_dag_serialization.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index a67864fb1ca6bf..2b5d4cce4c7bb5 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -415,6 +415,13 @@ def setup_test_cases(self): ) ) + # Skip that test if latest botocore is used - it reads all example dags and in case latest botocore + # is upgraded to latest, usually aiobotocore can't be installed and some of the system tests will fail with + # import errors. 
+ @pytest.mark.skipif( + os.environ.get("UPGRADE_BOTO", "") == "true", + reason="This test is skipped when latest botocore is installed", + ) @pytest.mark.db_test def test_serialization(self): """Serialization and deserialization should work for every DAG and Operator.""" From 9984dcdd7c9cdffc89bc35c8cb2077f96215fe44 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Sat, 18 Jan 2025 02:16:05 +0800 Subject: [PATCH 4/8] =?UTF-8?q?Add=20newsfragment=20and=20migration=20rule?= =?UTF-8?q?s=20for=20`scheduler.dag=5Fdir=5Flist=5Finterval`=20=E2=86=92?= =?UTF-8?q?=20`dag=5Fbundles.refresh=5Finterval`=20configuration=20change?= =?UTF-8?q?=20(#45737)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../commands/remote_commands/config_command.py | 4 ++++ newsfragments/45722.significant.rst | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 newsfragments/45722.significant.rst diff --git a/airflow/cli/commands/remote_commands/config_command.py b/airflow/cli/commands/remote_commands/config_command.py index 1e0bcd17c9b78c..5fc0148a2b7900 100644 --- a/airflow/cli/commands/remote_commands/config_command.py +++ b/airflow/cli/commands/remote_commands/config_command.py @@ -327,6 +327,10 @@ def message(self) -> str: config=ConfigParameter("scheduler", "statsd_custom_client_path"), renamed_to=ConfigParameter("metrics", "statsd_custom_client_path"), ), + ConfigChange( + config=ConfigParameter("scheduler", "dag_dir_list_interval"), + renamed_to=ConfigParameter("dag_bundles", "refresh_interval"), + ), # celery ConfigChange( config=ConfigParameter("celery", "stalled_task_timeout"), diff --git a/newsfragments/45722.significant.rst b/newsfragments/45722.significant.rst new file mode 100644 index 00000000000000..3e9068a1ac13dc --- /dev/null +++ b/newsfragments/45722.significant.rst @@ -0,0 +1,18 @@ +Move airflow config ``scheduler.dag_dir_list_interval`` to ``dag_bundles.refresh_interval`` + +* Types of change + + * [ ] DAG changes + * [x] Config changes + * [ ] API changes + * [ ] CLI changes + * [ ] Behaviour changes + * [ ] Plugin changes + * [ ] Dependency change + * [ ] Code interface change + +* Migration rules needed + + * ``airflow config lint`` + + * [x] ``scheduler.dag_dir_list_interval`` → ``dag_bundles.refresh_interval`` From 418b701bbdcad58bb0b8bc6d9bd4a0fedca937f1 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 18 Jan 2025 01:51:29 +0530 Subject: [PATCH 5/8] AIP-72: Add support for `outlet_events` in Task Context (#45727) part of https://github.com/apache/airflow/issues/45717 and https://github.com/apache/airflow/issues/45752 This PR adds support for `outlet_events` in Context dict within the Task SDK by adding an endpoint on the API Server which is fetched when outlet_events is accessed. 
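For reviewers, a sketch of the task-author surface this enables (not part of the diff; the `airflow.sdk` decorator import path and all names are assumptions): with `outlet_events` back in the context, a task keyed by a plain asset can record extra event metadata, while an asset-alias key resolves to a concrete asset via `add()`, matching the `OutletEventAccessor` semantics in this patch.

    from airflow.sdk import task  # assumed Task SDK import path
    from airflow.sdk.definitions.asset import Asset, AssetAlias

    out = Asset(name="my-asset", uri="s3://bucket/obj")  # illustrative
    alias = AssetAlias("my-alias")

    @task(outlets=[out, alias])
    def produce(*, outlet_events):
        # Plain asset key: attach extra metadata to the emitted asset event.
        outlet_events[out].extra = {"row_count": 42}
        # Alias key: attach an event for a concrete asset to the alias.
        outlet_events[alias].add(out, extra={"source": "produce"})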
--- .../execution_api/datamodels/asset.py | 36 +++++ .../execution_api/routes/__init__.py | 10 +- .../execution_api/routes/assets.py | 71 ++++++++++ airflow/serialization/serialized_objects.py | 3 +- airflow/utils/context.py | 104 ++------------- task_sdk/src/airflow/sdk/api/client.py | 25 ++++ .../airflow/sdk/api/datamodels/_generated.py | 20 +++ .../airflow/sdk/definitions/asset/__init__.py | 4 +- .../src/airflow/sdk/execution_time/comms.py | 34 ++++- .../src/airflow/sdk/execution_time/context.py | 123 +++++++++++++++++- .../airflow/sdk/execution_time/supervisor.py | 11 ++ .../airflow/sdk/execution_time/task_runner.py | 4 +- task_sdk/tests/execution_time/test_context.py | 103 ++++++++++++++- .../tests/execution_time/test_supervisor.py | 39 +++++- .../tests/execution_time/test_task_runner.py | 9 +- .../execution_api/routes/test_assets.py | 110 ++++++++++++++++ .../serialization/test_serialized_objects.py | 3 +- tests/utils/test_context.py | 102 --------------- 18 files changed, 603 insertions(+), 208 deletions(-) create mode 100644 airflow/api_fastapi/execution_api/datamodels/asset.py create mode 100644 airflow/api_fastapi/execution_api/routes/assets.py create mode 100644 tests/api_fastapi/execution_api/routes/test_assets.py delete mode 100644 tests/utils/test_context.py diff --git a/airflow/api_fastapi/execution_api/datamodels/asset.py b/airflow/api_fastapi/execution_api/datamodels/asset.py new file mode 100644 index 00000000000000..6d3a53c3e4ca85 --- /dev/null +++ b/airflow/api_fastapi/execution_api/datamodels/asset.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from airflow.api_fastapi.core_api.base import BaseModel + + +class AssetResponse(BaseModel): + """Asset schema for responses with fields that are needed for Runtime.""" + + name: str + uri: str + group: str + extra: dict | None = None + + +class AssetAliasResponse(BaseModel): + """Asset alias schema with fields that are needed for Runtime.""" + + name: str + group: str diff --git a/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow/api_fastapi/execution_api/routes/__init__.py index 0383503f18b874..793cd8fe084944 100644 --- a/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow/api_fastapi/execution_api/routes/__init__.py @@ -17,9 +17,17 @@ from __future__ import annotations from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api.routes import connections, health, task_instances, variables, xcoms +from airflow.api_fastapi.execution_api.routes import ( + assets, + connections, + health, + task_instances, + variables, + xcoms, +) execution_api_router = AirflowRouter() +execution_api_router.include_router(assets.router, prefix="/assets", tags=["Assets"]) execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) execution_api_router.include_router(health.router, tags=["Health"]) execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) diff --git a/airflow/api_fastapi/execution_api/routes/assets.py b/airflow/api_fastapi/execution_api/routes/assets.py new file mode 100644 index 00000000000000..213c599befb3e3 --- /dev/null +++ b/airflow/api_fastapi/execution_api/routes/assets.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from typing import Annotated + +from fastapi import HTTPException, Query, status +from sqlalchemy import select + +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse +from airflow.models.asset import AssetModel + +# TODO: Add dependency on JWT token +router = AirflowRouter( + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Asset not found"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, +) + + +@router.get("/by-name") +def get_asset_by_name( + name: Annotated[str, Query(description="The name of the Asset")], + session: SessionDep, +) -> AssetResponse: + """Get an Airflow Asset by `name`.""" + asset = session.scalar(select(AssetModel).where(AssetModel.name == name, AssetModel.active.has())) + _raise_if_not_found(asset, f"Asset with name {name} not found") + + return AssetResponse.model_validate(asset) + + +@router.get("/by-uri") +def get_asset_by_uri( + uri: Annotated[str, Query(description="The URI of the Asset")], + session: SessionDep, +) -> AssetResponse: + """Get an Airflow Asset by `uri`.""" + asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has())) + _raise_if_not_found(asset, f"Asset with URI {uri} not found") + + return AssetResponse.model_validate(asset) + + +def _raise_if_not_found(asset, msg): + if asset is None: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": msg, + }, + ) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 11c293b531fa6d..d828a9a5b6b241 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -64,6 +64,7 @@ BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field @@ -77,10 +78,8 @@ from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.utils.code_utils import get_python_source from airflow.utils.context import ( - AssetAliasEvent, ConnectionAccessor, Context, - OutletEventAccessor, OutletEventAccessors, VariableAccessor, ) diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 1f453457e43235..168243290fabc4 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -19,7 +19,6 @@ from __future__ import annotations -import contextlib from collections.abc import ( Container, Iterator, @@ -51,9 +50,9 @@ AssetRef, AssetUniqueKey, AssetUriRef, - BaseAssetUniqueKey, ) from airflow.sdk.definitions.context import Context +from airflow.sdk.execution_time.context import OutletEventAccessors as OutletEventAccessorsSDK from airflow.utils.db import LazySelectSequence from airflow.utils.session import create_session from airflow.utils.types import NOTSET @@ -156,104 +155,29 @@ def get(self, key: str, default_conn: Any = None) -> Any: return default_conn -@attrs.define() -class AssetAliasEvent: - """ - Represeation of asset event to be triggered by an asset alias. 
- - :meta private: - """ - - source_alias_name: str - dest_asset_key: AssetUniqueKey - extra: dict[str, Any] - - -@attrs.define() -class OutletEventAccessor: - """ - Wrapper to access an outlet asset event in template. - - :meta private: - """ - - key: BaseAssetUniqueKey - extra: dict[str, Any] = attrs.Factory(dict) - asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) - - def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: - """Add an AssetEvent to an existing Asset.""" - if not isinstance(self.key, AssetAliasUniqueKey): - return - - asset_alias_name = self.key.name - event = AssetAliasEvent( - source_alias_name=asset_alias_name, - dest_asset_key=AssetUniqueKey.from_asset(asset), - extra=extra or {}, - ) - self.asset_alias_events.append(event) - - -class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], OutletEventAccessor]): +class OutletEventAccessors(OutletEventAccessorsSDK): """ Lazy mapping of outlet asset event accessors. :meta private: """ - _asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {} - - def __init__(self) -> None: - self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {} - - def __str__(self) -> str: - return f"OutletEventAccessors(_dict={self._dict})" - - def __iter__(self) -> Iterator[Asset | AssetAlias]: - return ( - key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict - ) - - def __len__(self) -> int: - return len(self._dict) - - def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: - hashable_key: BaseAssetUniqueKey - if isinstance(key, Asset): - hashable_key = AssetUniqueKey.from_asset(key) - elif isinstance(key, AssetAlias): - hashable_key = AssetAliasUniqueKey.from_asset_alias(key) - elif isinstance(key, AssetRef): - hashable_key = self._resolve_asset_ref(key) - else: - raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}") - - if hashable_key not in self._dict: - self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) - return self._dict[hashable_key] - - def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: - with contextlib.suppress(KeyError): - return self._asset_ref_cache[ref] - - refs_to_cache: list[AssetRef] - with create_session() as session: - if isinstance(ref, AssetNameRef): + @staticmethod + def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset: + if name: + with create_session() as session: asset = session.scalar( - select(AssetModel).where(AssetModel.name == ref.name, AssetModel.active.has()) + select(AssetModel).where(AssetModel.name == name, AssetModel.active.has()) ) - refs_to_cache = [ref, AssetUriRef(asset.uri)] - elif isinstance(ref, AssetUriRef): + elif uri: + with create_session() as session: asset = session.scalar( - select(AssetModel).where(AssetModel.uri == ref.uri, AssetModel.active.has()) + select(AssetModel).where(AssetModel.uri == uri, AssetModel.active.has()) ) - refs_to_cache = [ref, AssetNameRef(asset.name)] - else: - raise TypeError(f"Unimplemented asset ref: {type(ref)}") - for ref in refs_to_cache: - self._asset_ref_cache[ref] = unique_key = AssetUniqueKey.from_asset(asset) - return unique_key + else: + raise ValueError("Either name or uri must be provided") + + return asset.to_public() class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]): diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 5ee270591481e3..e73e5aebea64b5 100644 --- a/task_sdk/src/airflow/sdk/api/client.py 
+++ b/task_sdk/src/airflow/sdk/api/client.py @@ -34,6 +34,7 @@ from airflow.sdk import __version__ from airflow.sdk.api.datamodels._generated import ( + AssetResponse, ConnectionResponse, DagRunType, TerminalTIState, @@ -267,6 +268,24 @@ def set( return {"ok": True} +class AssetOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse: + """Get Asset value from the API server.""" + if name: + resp = self.client.get("assets/by-name", params={"name": name}) + elif uri: + resp = self.client.get("assets/by-uri", params={"uri": uri}) + else: + raise ValueError("Either `name` or `uri` must be provided") + + return AssetResponse.model_validate_json(resp.read()) + + class BearerAuth(httpx.Auth): def __init__(self, token: str): self.token: str = token @@ -374,6 +393,12 @@ def xcoms(self) -> XComOperations: """Operations related to XComs.""" return XComOperations(self) + @lru_cache() # type: ignore[misc] + @property + def assets(self) -> AssetOperations: + """Operations related to XComs.""" + return AssetOperations(self) + # This is only used for parsing. ServerResponseError is raised instead class _ErrorBody(BaseModel): diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index a8b478d07f029a..f0a04da21c8942 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -29,6 +29,15 @@ from pydantic import BaseModel, ConfigDict, Field +class AssetAliasResponse(BaseModel): + """ + Asset alias schema with fields that are needed for Runtime. + """ + + name: Annotated[str, Field(title="Name")] + group: Annotated[str, Field(title="Group")] + + class ConnectionResponse(BaseModel): """ Connection schema for responses with fields that are needed for Runtime. @@ -187,6 +196,17 @@ class TaskInstance(BaseModel): hostname: Annotated[str | None, Field(title="Hostname")] = None +class AssetResponse(BaseModel): + """ + Asset schema for responses with fields that are needed for Runtime. + """ + + name: Annotated[str, Field(title="Name")] + uri: Annotated[str, Field(title="Uri")] + group: Annotated[str, Field(title="Group")] + extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None + + class DagRun(BaseModel): """ Schema for DagRun model with minimal required fields needed for Runtime. 
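For orientation, how the new client surface is exercised from the supervisor side (a sketch, not part of the diff; the `Client` constructor arguments shown are assumptions for illustration):

    from airflow.sdk.api.client import Client

    client = Client(base_url="http://localhost:9091/execution", token="...")  # illustrative
    asset = client.assets.get(name="my-asset")        # GET assets/by-name?name=my-asset
    asset = client.assets.get(uri="s3://bucket/obj")  # GET assets/by-uri?uri=s3://bucket/obj
    # Passing neither name nor uri raises ValueError, per AssetOperations.get above.
    print(asset.name, asset.uri, asset.group, asset.extra)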
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 5b0cbb4a784d97..ea89f1b681701a 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -488,14 +488,14 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat ) -@attrs.define() +@attrs.define(hash=True) class AssetNameRef(AssetRef): """Name reference to an asset.""" name: str -@attrs.define() +@attrs.define(hash=True) class AssetUriRef(AssetRef): """URI reference to an asset.""" diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index b6874d47f090cd..f8aaab65af4f13 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -50,6 +50,7 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue from airflow.sdk.api.datamodels._generated import ( + AssetResponse, BundleInfo, ConnectionResponse, TaskInstance, @@ -79,6 +80,25 @@ class StartupDetails(BaseModel): type: Literal["StartupDetails"] = "StartupDetails" +class AssetResult(AssetResponse): + """Response to ReadXCom request.""" + + type: Literal["AssetResult"] = "AssetResult" + + @classmethod + def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult: + """ + Get AssetResult from AssetResponse. + + AssetResponse is autogenerated from the API schema, so we need to convert it to AssetResult + for communication between the Supervisor and the task process. + """ + # Exclude defaults to avoid sending unnecessary data + # Pass the type as AssetResult explicitly so we can then call model_dump_json with exclude_unset=True + # to avoid sending unset fields (which are defaults in our case). 
+ return cls(**asset_response.model_dump(exclude_defaults=True), type="AssetResult") + + class XComResult(XComResponse): """Response to ReadXCom request.""" @@ -133,7 +153,7 @@ class ErrorResponse(BaseModel): ToTask = Annotated[ - Union[StartupDetails, XComResult, ConnectionResult, VariableResult, ErrorResponse], + Union[StartupDetails, XComResult, ConnectionResult, VariableResult, ErrorResponse, AssetResult], Field(discriminator="type"), ] @@ -231,12 +251,24 @@ class SetRenderedFields(BaseModel): type: Literal["SetRenderedFields"] = "SetRenderedFields" +class GetAssetByName(BaseModel): + name: str + type: Literal["GetAssetByName"] = "GetAssetByName" + + +class GetAssetByUri(BaseModel): + uri: str + type: Literal["GetAssetByUri"] = "GetAssetByUri" + + ToSupervisor = Annotated[ Union[ TaskState, GetXCom, GetConnection, GetVariable, + GetAssetByName, + GetAssetByUri, DeferTask, PutVariable, SetXCom, diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index cdb3880bb36b33..918526c3004c2e 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -17,20 +17,31 @@ from __future__ import annotations import contextlib -from collections.abc import Generator -from typing import TYPE_CHECKING, Any +from collections.abc import Generator, Iterator, Mapping +from typing import TYPE_CHECKING, Any, Union +import attrs import structlog from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT from airflow.sdk.definitions._internal.types import NOTSET +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAliasUniqueKey, + AssetNameRef, + AssetRef, + AssetUniqueKey, + AssetUriRef, + BaseAssetUniqueKey, +) from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType if TYPE_CHECKING: from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.variable import Variable - from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult + from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, VariableResult log = structlog.get_logger(logger_name="task") @@ -163,6 +174,112 @@ def __eq__(self, other: object) -> bool: return True +@attrs.define +class AssetAliasEvent: + """Representation of asset event to be triggered by an asset alias.""" + + source_alias_name: str + dest_asset_key: AssetUniqueKey + extra: dict[str, Any] + + +@attrs.define +class OutletEventAccessor: + """Wrapper to access an outlet asset event in template.""" + + key: BaseAssetUniqueKey + extra: dict[str, Any] = attrs.Factory(dict) + asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) + + def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: + """Add an AssetEvent to an existing Asset.""" + if not isinstance(self.key, AssetAliasUniqueKey): + return + + asset_alias_name = self.key.name + event = AssetAliasEvent( + source_alias_name=asset_alias_name, + dest_asset_key=AssetUniqueKey.from_asset(asset), + extra=extra or {}, + ) + self.asset_alias_events.append(event) + + +class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], OutletEventAccessor]): + """Lazy mapping of outlet asset event accessors.""" + + _asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {} + + def __init__(self) -> None: + self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {} + + def __str__(self) -> str: + return 
f"OutletEventAccessors(_dict={self._dict})" + + def __iter__(self) -> Iterator[Asset | AssetAlias]: + return ( + key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict + ) + + def __len__(self) -> int: + return len(self._dict) + + def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: + hashable_key: BaseAssetUniqueKey + if isinstance(key, Asset): + hashable_key = AssetUniqueKey.from_asset(key) + elif isinstance(key, AssetAlias): + hashable_key = AssetAliasUniqueKey.from_asset_alias(key) + elif isinstance(key, AssetRef): + hashable_key = self._resolve_asset_ref(key) + else: + raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}") + + if hashable_key not in self._dict: + self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) + return self._dict[hashable_key] + + def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: + with contextlib.suppress(KeyError): + return self._asset_ref_cache[ref] + + refs_to_cache: list[AssetRef] + if isinstance(ref, AssetNameRef): + asset = self._get_asset_from_db(name=ref.name) + refs_to_cache = [ref, AssetUriRef(asset.uri)] + elif isinstance(ref, AssetUriRef): + asset = self._get_asset_from_db(uri=ref.uri) + refs_to_cache = [ref, AssetNameRef(asset.name)] + else: + raise TypeError(f"Unimplemented asset ref: {type(ref)}") + unique_key = AssetUniqueKey.from_asset(asset) + for ref in refs_to_cache: + self._asset_ref_cache[ref] = unique_key + return unique_key + + # TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py + @staticmethod + def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset: + from airflow.sdk.definitions.asset import Asset + from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName, GetAssetByUri + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + if name: + SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByName(name=name)) + elif uri: + SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri)) + else: + raise ValueError("Either name or uri must be provided") + + msg = SUPERVISOR_COMMS.get_message() + if isinstance(msg, ErrorResponse): + raise AirflowRuntimeError(msg) + + if TYPE_CHECKING: + assert isinstance(msg, AssetResult) + return Asset(**msg.model_dump(exclude={"type"})) + + @contextlib.contextmanager def set_current_context(context: Context) -> Generator[Context, None, None]: """ diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 32895d36524d84..bd50ee5126b94a 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -61,8 +61,11 @@ VariableResponse, ) from airflow.sdk.execution_time.comms import ( + AssetResult, ConnectionResult, DeferTask, + GetAssetByName, + GetAssetByUri, GetConnection, GetVariable, GetXCom, @@ -787,6 +790,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): self.client.variables.set(msg.key, msg.value, msg.description) elif isinstance(msg, SetRenderedFields): self.client.task_instances.set_rtif(self.id, msg.rendered_fields) + elif isinstance(msg, GetAssetByName): + asset_resp = self.client.assets.get(name=msg.name) + asset_result = AssetResult.from_asset_response(asset_resp) + resp = asset_result.model_dump_json(exclude_unset=True).encode() + elif isinstance(msg, GetAssetByUri): + asset_resp = 
self.client.assets.get(uri=msg.uri) + asset_result = AssetResult.from_asset_response(asset_resp) + resp = asset_result.model_dump_json(exclude_unset=True).encode() else: log.error("Unhandled request", msg=msg) return diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 186faac878a0a7..d252c24be180c0 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -50,6 +50,7 @@ from airflow.sdk.execution_time.context import ( ConnectionAccessor, MacrosAccessor, + OutletEventAccessors, VariableAccessor, set_current_context, ) @@ -92,12 +93,13 @@ def get_template_context(self) -> Context: # TODO: Ensure that ti.log_url and such are available to use in context # especially after removal of `conf` from Context. "ti": self, - # "outlet_events": OutletEventAccessors(), + "outlet_events": OutletEventAccessors(), # "expanded_ti_count": expanded_ti_count, "expanded_ti_count": None, # TODO: Implement this # "inlet_events": InletEventsAccessors(task.inlets, session=session), "macros": MacrosAccessor(), # "params": validated_params, + # TODO: Make this go through Public API longer term. # "prev_data_interval_start_success": get_prev_data_interval_start_success(), # "prev_data_interval_end_success": get_prev_data_interval_end_success(), # "prev_start_date_success": get_prev_start_date_success(), diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index 6527d517e375f4..e3ef15dc934cf6 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -22,12 +22,16 @@ import pytest from airflow.sdk import get_current_context +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.variable import Variable from airflow.sdk.exceptions import ErrorType -from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, VariableResult +from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, ErrorResponse, VariableResult from airflow.sdk.execution_time.context import ( + AssetAliasEvent, ConnectionAccessor, + OutletEventAccessor, + OutletEventAccessors, VariableAccessor, _convert_connection_result_conn, _convert_variable_result_to_variable, @@ -248,3 +252,100 @@ def test_nested_context(self): assert ctx["ContextId"] == i # End of with statement ctx_list[i].__exit__(None, None, None) + + +class TestOutletEventAccessor: + @pytest.mark.parametrize( + "key, asset_alias_events", + ( + (AssetUniqueKey.from_asset(Asset("test_uri")), []), + ( + AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), + [ + AssetAliasEvent( + source_alias_name="test_alias", + dest_asset_key=AssetUniqueKey(uri="test_uri", name="test_uri"), + extra={}, + ) + ], + ), + ), + ) + def test_add(self, key, asset_alias_events, mock_supervisor_comms): + asset = Asset("test_uri") + mock_supervisor_comms.get_message.return_value = asset + + outlet_event_accessor = OutletEventAccessor(key=key, extra={}) + outlet_event_accessor.add(asset) + assert outlet_event_accessor.asset_alias_events == asset_alias_events + + @pytest.mark.parametrize( + "key, asset_alias_events", + ( + (AssetUniqueKey.from_asset(Asset("test_uri")), []), + ( + AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), + [ + AssetAliasEvent( + 
source_alias_name="test_alias", + dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"), + extra={}, + ) + ], + ), + ), + ) + def test_add_with_db(self, key, asset_alias_events, mock_supervisor_comms): + asset = Asset(uri="test://asset-uri", name="test-asset") + mock_supervisor_comms.get_message.return_value = asset + + outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) + outlet_event_accessor.add(asset, extra={}) + assert outlet_event_accessor.asset_alias_events == asset_alias_events + + +class TestOutletEventAccessors: + @pytest.mark.parametrize( + "access_key, internal_key", + ( + (Asset("test"), AssetUniqueKey.from_asset(Asset("test"))), + ( + Asset(name="test", uri="test://asset"), + AssetUniqueKey.from_asset(Asset(name="test", uri="test://asset")), + ), + (AssetAlias("test_alias"), AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias"))), + ), + ) + def test__get_item__dict_key_not_exists(self, access_key, internal_key): + outlet_event_accessors = OutletEventAccessors() + assert len(outlet_event_accessors) == 0 + outlet_event_accessor = outlet_event_accessors[access_key] + assert len(outlet_event_accessors) == 1 + assert outlet_event_accessor.key == internal_key + assert outlet_event_accessor.extra == {} + + @pytest.mark.parametrize( + ["access_key", "asset"], + ( + (Asset.ref(name="test"), Asset(name="test")), + (Asset.ref(name="test1"), Asset(name="test1", uri="test://asset-uri")), + (Asset.ref(uri="test://asset-uri"), Asset(uri="test://asset-uri")), + ), + ) + def test__get_item__asset_ref(self, access_key, asset, mock_supervisor_comms): + """Test accessing OutletEventAccessors with AssetRef resolves to correct Asset.""" + internal_key = AssetUniqueKey.from_asset(asset) + outlet_event_accessors = OutletEventAccessors() + assert len(outlet_event_accessors) == 0 + + # Asset from the API Server via the supervisor + mock_supervisor_comms.get_message.return_value = AssetResult( + name=asset.name, + uri=asset.uri, + group=asset.group, + ) + + outlet_event_accessor = outlet_event_accessors[access_key] + assert len(outlet_event_accessors) == 1 + assert outlet_event_accessor.key == internal_key + assert outlet_event_accessor.extra == {} diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 59afa26dc2aa5a..5455d0f70cdef0 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -41,8 +41,11 @@ from airflow.sdk.api.client import ServerResponseError from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import ( + AssetResult, ConnectionResult, DeferTask, + GetAssetByName, + GetAssetByUri, GetConnection, GetVariable, GetXCom, @@ -805,13 +808,14 @@ def watched_subprocess(self, mocker): ) @pytest.mark.parametrize( - ["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response"], + ["message", "expected_buffer", "client_attr_path", "method_arg", "method_kwarg", "mock_response"], [ pytest.param( GetConnection(conn_id="test_conn"), b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n', "connections.get", ("test_conn",), + {}, ConnectionResult(conn_id="test_conn", conn_type="mysql"), id="get_connection", ), @@ -820,6 +824,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":"test_value","type":"VariableResult"}\n', "variables.get", ("test_key",), + {}, VariableResult(key="test_key", value="test_value"), 
id="get_variable", ), @@ -828,6 +833,7 @@ def watched_subprocess(self, mocker): b"", "variables.set", ("test_key", "test_value", "test_description"), + {}, {"ok": True}, id="set_variable", ), @@ -836,6 +842,7 @@ def watched_subprocess(self, mocker): b"", "task_instances.defer", (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), + {}, "", id="patch_task_instance_to_deferred", ), @@ -853,6 +860,7 @@ def watched_subprocess(self, mocker): end_date=timezone.parse("2024-10-31T12:00:00Z"), ), ), + {}, "", id="patch_task_instance_to_up_for_reschedule", ), @@ -861,6 +869,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None), + {}, XComResult(key="test_key", value="test_value"), id="get_xcom", ), @@ -871,6 +880,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", 2), + {}, XComResult(key="test_key", value="test_value"), id="get_xcom_map_index", ), @@ -879,6 +889,7 @@ def watched_subprocess(self, mocker): b'{"key":"test_key","value":null,"type":"XComResult"}\n', "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None), + {}, XComResult(key="test_key", value=None, type="XComResult"), id="get_xcom_not_found", ), @@ -900,6 +911,7 @@ def watched_subprocess(self, mocker): '{"key": "test_key", "value": {"key2": "value2"}}', None, ), + {}, {"ok": True}, id="set_xcom", ), @@ -922,6 +934,7 @@ def watched_subprocess(self, mocker): '{"key": "test_key", "value": {"key2": "value2"}}', 2, ), + {}, {"ok": True}, id="set_xcom_with_map_index", ), @@ -932,6 +945,7 @@ def watched_subprocess(self, mocker): b"", "", (), + {}, "", id="patch_task_instance_to_skipped", ), @@ -940,9 +954,28 @@ def watched_subprocess(self, mocker): b"", "task_instances.set_rtif", (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), + {}, {"ok": True}, id="set_rtif", ), + pytest.param( + GetAssetByName(name="asset"), + b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + "assets.get", + [], + {"name": "asset"}, + AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), + id="get_asset_by_name", + ), + pytest.param( + GetAssetByUri(uri="s3://bucket/obj"), + b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + "assets.get", + [], + {"uri": "s3://bucket/obj"}, + AssetResult(name="asset", uri="s3://bucket/obj", group="asset"), + id="get_asset_by_uri", + ), ], ) def test_handle_requests( @@ -953,8 +986,8 @@ def test_handle_requests( expected_buffer, client_attr_path, method_arg, + method_kwarg, mock_response, - time_machine, ): """ Test handling of different messages to the subprocess. 
For any new message type, add a @@ -980,7 +1013,7 @@ def test_handle_requests( # Verify the correct client method was called if client_attr_path: - mock_client_method.assert_called_once_with(*method_arg) + mock_client_method.assert_called_once_with(*method_arg, **method_kwarg) # Verify the response was added to the buffer val = watched_subprocess.stdin.getvalue() diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 60b39da455c69f..f7734279b3ffb6 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -52,7 +52,12 @@ VariableResult, XComResult, ) -from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor +from airflow.sdk.execution_time.context import ( + ConnectionAccessor, + MacrosAccessor, + OutletEventAccessors, + VariableAccessor, +) from airflow.sdk.execution_time.task_runner import ( CommsDecoder, RuntimeTaskInstance, @@ -613,6 +618,7 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_ "inlets": task.inlets, "macros": MacrosAccessor(), "map_index_template": task.map_index_template, + "outlet_events": OutletEventAccessors(), "outlets": task.outlets, "run_id": "test_run", "task": task, @@ -645,6 +651,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti): "inlets": task.inlets, "macros": MacrosAccessor(), "map_index_template": task.map_index_template, + "outlet_events": OutletEventAccessors(), "outlets": task.outlets, "run_id": "test_run", "task": task, diff --git a/tests/api_fastapi/execution_api/routes/test_assets.py b/tests/api_fastapi/execution_api/routes/test_assets.py new file mode 100644 index 00000000000000..2cf34f8dd7bc72 --- /dev/null +++ b/tests/api_fastapi/execution_api/routes/test_assets.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import pytest + +from airflow.models.asset import AssetActive, AssetModel +from airflow.utils import timezone + +DEFAULT_DATE = timezone.parse("2021-01-01T00:00:00") + +pytestmark = pytest.mark.db_test + + +class TestGetAssetByName: + def test_get_asset_by_name(self, client, session): + asset = AssetModel( + id=1, + name="test_get_asset_by_name", + uri="s3://bucket/key", + group="asset", + extra={"foo": "bar"}, + created_at=DEFAULT_DATE, + updated_at=DEFAULT_DATE, + ) + + asset_active = AssetActive.for_asset(asset) + + session.add_all([asset, asset_active]) + session.commit() + + response = client.get("/execution/assets/by-name", params={"name": "test_get_asset_by_name"}) + + assert response.status_code == 200 + assert response.json() == { + "name": "test_get_asset_by_name", + "uri": "s3://bucket/key", + "group": "asset", + "extra": {"foo": "bar"}, + } + + session.delete(asset) + session.delete(asset_active) + session.commit() + + def test_asset_name_not_found(self, client): + response = client.get("/execution/assets/by-name", params={"name": "non_existent"}) + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "Asset with name non_existent not found", + "reason": "not_found", + } + } + + +class TestGetAssetByUri: + def test_get_asset_by_uri(self, client, session): + asset = AssetModel( + name="test_get_asset_by_uri", + uri="s3://bucket/key", + group="asset", + extra={"foo": "bar"}, + ) + + asset_active = AssetActive.for_asset(asset) + + session.add_all([asset, asset_active]) + session.commit() + + response = client.get("/execution/assets/by-uri", params={"uri": "s3://bucket/key"}) + + assert response.status_code == 200 + assert response.json() == { + "name": "test_get_asset_by_uri", + "uri": "s3://bucket/key", + "group": "asset", + "extra": {"foo": "bar"}, + } + + session.delete(asset) + session.delete(asset_active) + session.commit() + + def test_asset_uri_not_found(self, client): + response = client.get("/execution/assets/by-uri", params={"uri": "non_existent"}) + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "Asset with URI non_existent not found", + "reason": "not_found", + } + } diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 0faeed038e648f..707595b92ffa22 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -43,11 +43,12 @@ from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey +from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import BaseTrigger from airflow.utils import timezone -from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors +from airflow.utils.context import OutletEventAccessors from airflow.utils.db import LazySelectSequence from airflow.utils.operator_resources import Resources from airflow.utils.state import DagRunState, State diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py deleted file mode 100644 index 0046ca33cc4da8..00000000000000 --- a/tests/utils/test_context.py +++ /dev/null @@ -1,102 +0,0 @@ -# -# 
Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -import pytest - -from airflow.models.asset import AssetActive, AssetAliasModel, AssetModel -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey -from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors - - -class TestOutletEventAccessor: - @pytest.mark.parametrize( - "key, asset_alias_events", - ( - (AssetUniqueKey.from_asset(Asset("test_uri")), []), - ( - AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), - [ - AssetAliasEvent( - source_alias_name="test_alias", - dest_asset_key=AssetUniqueKey(uri="test_uri", name="test_uri"), - extra={}, - ) - ], - ), - ), - ) - @pytest.mark.db_test - def test_add(self, key, asset_alias_events, session): - asset = Asset("test_uri") - session.add_all([AssetModel.from_public(asset), AssetActive.for_asset(asset)]) - session.flush() - - outlet_event_accessor = OutletEventAccessor(key=key, extra={}) - outlet_event_accessor.add(asset) - assert outlet_event_accessor.asset_alias_events == asset_alias_events - - @pytest.mark.db_test - @pytest.mark.parametrize( - "key, asset_alias_events", - ( - (AssetUniqueKey.from_asset(Asset("test_uri")), []), - ( - AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), - [ - AssetAliasEvent( - source_alias_name="test_alias", - dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"), - extra={}, - ) - ], - ), - ), - ) - def test_add_with_db(self, key, asset_alias_events, session): - asset = Asset(uri="test://asset-uri", name="test-asset") - asm = AssetModel.from_public(asset) - aam = AssetAliasModel(name="test_alias") - session.add_all([asm, aam, AssetActive.for_asset(asset)]) - session.flush() - - outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) - outlet_event_accessor.add(asset, extra={}) - assert outlet_event_accessor.asset_alias_events == asset_alias_events - - -class TestOutletEventAccessors: - @pytest.mark.parametrize( - "access_key, internal_key", - ( - (Asset("test"), AssetUniqueKey.from_asset(Asset("test"))), - ( - Asset(name="test", uri="test://asset"), - AssetUniqueKey.from_asset(Asset(name="test", uri="test://asset")), - ), - (AssetAlias("test_alias"), AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias"))), - ), - ) - def test___get_item__dict_key_not_exists(self, access_key, internal_key): - outlet_event_accessors = OutletEventAccessors() - assert len(outlet_event_accessors) == 0 - outlet_event_accessor = outlet_event_accessors[access_key] - assert len(outlet_event_accessors) == 1 - assert outlet_event_accessor.key == internal_key - assert outlet_event_accessor.extra == {} From caa401da7948057cff056c67f272307f1f5c7f4f Mon Sep 17 00:00:00 2001 From: 
Nate Robinson <137531405+nrobinson-intelycare@users.noreply.github.com>
Date: Fri, 17 Jan 2025 16:56:06 -0500
Subject: [PATCH 6/8] Add support for timeout to BatchOperator (#45660)

An execution timeout for the `submit_job` API call can now be passed through
the operator to the underlying boto3 call.
---
 .../src/airflow/providers/amazon/aws/operators/batch.py | 6 ++++++
 providers/tests/amazon/aws/operators/test_batch.py      | 7 +++++++
 2 files changed, 13 insertions(+)

diff --git a/providers/src/airflow/providers/amazon/aws/operators/batch.py b/providers/src/airflow/providers/amazon/aws/operators/batch.py
index 3df00fb04c37f1..e69508d89319f3 100644
--- a/providers/src/airflow/providers/amazon/aws/operators/batch.py
+++ b/providers/src/airflow/providers/amazon/aws/operators/batch.py
@@ -95,6 +95,7 @@ class BatchOperator(BaseOperator):
         If it is an array job, only the logs of the first task will be printed.
     :param awslogs_fetch_interval: The interval with which cloudwatch logs are to be fetched, 30 sec.
     :param poll_interval: (Deferrable mode only) Time in seconds to wait between polling.
+    :param submit_job_timeout: Execution timeout in seconds for the submitted Batch job.

     .. note::
         Any custom waiters must return a waiter for these calls:
@@ -184,6 +185,7 @@ def __init__(
         poll_interval: int = 30,
         awslogs_enabled: bool = False,
         awslogs_fetch_interval: timedelta = timedelta(seconds=30),
+        submit_job_timeout: int | None = None,
         **kwargs,
     ) -> None:
         BaseOperator.__init__(self, **kwargs)
@@ -208,6 +210,7 @@ def __init__(
         self.poll_interval = poll_interval
         self.awslogs_enabled = awslogs_enabled
         self.awslogs_fetch_interval = awslogs_fetch_interval
+        self.submit_job_timeout = submit_job_timeout

         # params for hook
         self.max_retries = max_retries
@@ -315,6 +318,9 @@ def submit_job(self, context: Context):
             "schedulingPriorityOverride": self.scheduling_priority_override,
         }

+        if self.submit_job_timeout:
+            args["timeout"] = {"attemptDurationSeconds": self.submit_job_timeout}
+
         try:
             response = self.hook.client.submit_job(**trim_none_values(args))
         except Exception as e:
diff --git a/providers/tests/amazon/aws/operators/test_batch.py b/providers/tests/amazon/aws/operators/test_batch.py
index 0c14c256edba93..c1b1d847b7d916 100644
--- a/providers/tests/amazon/aws/operators/test_batch.py
+++ b/providers/tests/amazon/aws/operators/test_batch.py
@@ -70,6 +70,7 @@ def setup_method(self, _, get_client_type_mock):
             aws_conn_id="airflow_test",
             region_name="eu-west-1",
             tags={},
+            submit_job_timeout=3600,
         )
         self.client_mock = self.get_client_type_mock.return_value
         # We're mocking all actual AWS calls and don't need a connection.
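As a usage illustration of the new parameter (a minimal sketch, not part of the patch; the job name, queue, and definition values are placeholders):

```python
# Sketch: submit_job_timeout maps onto the Batch submit_job kwarg
# timeout={"attemptDurationSeconds": ...}, as the hunk above shows.
# All identifier values below are placeholders, not from the patch.
from airflow.providers.amazon.aws.operators.batch import BatchOperator

submit = BatchOperator(
    task_id="submit_batch_job",
    job_name="example-job",
    job_queue="example-queue",
    job_definition="example-job-def",
    submit_job_timeout=3600,  # seconds; None (the default) sends no timeout key
)
```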
This @@ -109,6 +110,7 @@ def test_init(self): assert self.batch.hook.client == self.client_mock assert self.batch.tags == {} assert self.batch.wait_for_completion is True + assert self.batch.submit_job_timeout == 3600 self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1") @@ -141,6 +143,7 @@ def test_init_defaults(self): assert issubclass(type(batch_job.hook.client), botocore.client.BaseClient) assert batch_job.tags == {} assert batch_job.wait_for_completion is True + assert batch_job.submit_job_timeout is None def test_template_fields_overrides(self): assert self.batch.template_fields == ( @@ -181,6 +184,7 @@ def test_execute_without_failures(self, check_mock, wait_mock, job_description_m parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) assert self.batch.job_id == JOB_ID @@ -205,6 +209,7 @@ def test_execute_with_failures(self): parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) @mock.patch.object(BatchClientHook, "get_job_description") @@ -261,6 +266,7 @@ def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) @mock.patch.object(BatchClientHook, "get_job_description") @@ -359,6 +365,7 @@ def test_execute_with_eks_overrides(self, check_mock, wait_mock, job_description parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) @mock.patch.object(BatchClientHook, "check_job_success") From 3af9ddda3e342e03a5dfd25b8c2c3bdcd3987d4f Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 18 Jan 2025 03:54:40 +0530 Subject: [PATCH 7/8] Handle DagRun `conf` errors in FAB's list view (#45763) Before: ``` root@6080aa107d9c:/opt/airflow# airflow webserver ____________ _____________ ____ |__( )_________ __/__ /________ __ ____ /| |_ /__ ___/_ /_ __ /_ __ \_ | /| / / ___ ___ | / _ / _ __/ _ / / /_/ /_ |/ |/ / _/_/ |_/_/ /_/ /_/ /_/ \____/____/|__/ Running the Gunicorn Server with: Workers: 4 sync Host: 0.0.0.0:8080 Timeout: 120 Logfiles: - - Access Logformat: ================================================================= [2025-01-17T21:39:45.349+0000] {override.py:949} WARNING - No user yet created, use flask fab command to do it. [2025-01-17T21:39:45.637+0000] {forms.py:107} ERROR - Column conf Type not supported [2025-01-17T21:39:45.637+0000] {forms.py:107} ERROR - Column conf Type not supported ^C[2025-01-17T21:39:46.187+0000] {webserver_command.py:430} INFO - Received signal: 2. Closing gunicorn. ``` After: ``` root@6080aa107d9c:/opt/airflow# airflow webserver ____________ _____________ ____ |__( )_________ __/__ /________ __ ____ /| |_ /__ ___/_ /_ __ /_ __ \_ | /| / / ___ ___ | / _ / _ __/ _ / / /_/ /_ |/ |/ / _/_/ |_/_/ /_/ /_/ /_/ \____/____/|__/ Running the Gunicorn Server with: Workers: 4 sync Host: 0.0.0.0:8080 Timeout: 120 Logfiles: - - Access Logformat: ================================================================= [2025-01-17T21:40:16.365+0000] {override.py:949} WARNING - No user yet created, use flask fab command to do it. 
[2025-01-17 21:40:18 +0000] [18864] [INFO] Starting gunicorn 23.0.0 [2025-01-17 21:40:29 +0000] [18864] [INFO] Listening at: http://0.0.0.0:8080 (18864) [2025-01-17 21:40:29 +0000] [18864] [INFO] Using worker: sync ``` --- airflow/www/views.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index e021c441f5d225..ded98f1e1d860d 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -4768,6 +4768,8 @@ class DagRunModelView(AirflowModelView): permissions.ACTION_CAN_ACCESS_MENU, ] + add_exclude_columns = ["conf"] + list_columns = [ "state", "dag_id", @@ -4803,7 +4805,6 @@ class DagRunModelView(AirflowModelView): "start_date", "end_date", "run_id", - "conf", "note", ] From 262537b9325640dc4a62bab9177cc8a31ebd8fe3 Mon Sep 17 00:00:00 2001 From: Shubham Raj <48172486+shubhamraj-git@users.noreply.github.com> Date: Sat, 18 Jan 2025 04:12:02 +0530 Subject: [PATCH 8/8] Use bulk API for importing variables (#45744) * initial refactor * colour * remove import api * remove import tests --- .../core_api/openapi/v1-generated.yaml | 92 --------------- .../core_api/routes/public/variables.py | 59 +--------- airflow/ui/openapi-gen/queries/common.ts | 3 - airflow/ui/openapi-gen/queries/queries.ts | 41 ------- .../ui/openapi-gen/requests/schemas.gen.ts | 37 ------ .../ui/openapi-gen/requests/services.gen.ts | 30 ----- airflow/ui/openapi-gen/requests/types.gen.ts | 51 --------- .../pages/Variables/ImportVariablesForm.tsx | 86 +++++++++++--- airflow/ui/src/queries/useImportVariables.ts | 33 +++--- .../core_api/routes/public/test_variables.py | 107 ------------------ 10 files changed, 92 insertions(+), 447 deletions(-) diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 9c0d70a265ad29..ed0ce173bc55bb 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -5989,68 +5989,6 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' - /public/variables/import: - post: - tags: - - Variable - summary: Import Variables - description: Import variables from a JSON file. 
- operationId: import_variables - parameters: - - name: action_if_exists - in: query - required: false - schema: - enum: - - overwrite - - fail - - skip - type: string - default: fail - title: Action If Exists - requestBody: - required: true - content: - multipart/form-data: - schema: - $ref: '#/components/schemas/Body_import_variables' - responses: - '200': - description: Successful Response - content: - application/json: - schema: - $ref: '#/components/schemas/VariablesImportResponse' - '401': - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPExceptionResponse' - description: Unauthorized - '403': - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPExceptionResponse' - description: Forbidden - '400': - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPExceptionResponse' - description: Bad Request - '409': - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPExceptionResponse' - description: Conflict - '422': - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPExceptionResponse' - description: Unprocessable Entity /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/logs/{try_number}: get: tags: @@ -6629,16 +6567,6 @@ components: - status title: BaseInfoResponse description: Base info serializer for responses. - Body_import_variables: - properties: - file: - type: string - format: binary - title: File - type: object - required: - - file - title: Body_import_variables ClearTaskInstancesBody: properties: dry_run: @@ -10066,26 +9994,6 @@ components: - is_encrypted title: VariableResponse description: Variable serializer for responses. - VariablesImportResponse: - properties: - created_variable_keys: - items: - type: string - type: array - title: Created Variable Keys - import_count: - type: integer - title: Import Count - created_count: - type: integer - title: Created Count - type: object - required: - - created_variable_keys - - import_count - - created_count - title: VariablesImportResponse - description: Import Variables serializer for responses. VersionInfo: properties: version: diff --git a/airflow/api_fastapi/core_api/routes/public/variables.py b/airflow/api_fastapi/core_api/routes/public/variables.py index 19d1b24d7eba8d..6bd850d76e609d 100644 --- a/airflow/api_fastapi/core_api/routes/public/variables.py +++ b/airflow/api_fastapi/core_api/routes/public/variables.py @@ -16,10 +16,9 @@ # under the License. 
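With the dedicated import endpoint gone, the same import is expressed through the existing bulk endpoint, `PATCH /public/variables`. A sketch of the equivalent request (the base URL and the use of `requests` are assumptions; the payload shape mirrors the `ImportVariablesForm` change later in this patch):

```python
# Sketch: bulk-create variables instead of POSTing a file to /variables/import.
# Base URL and HTTP client are assumptions; the action shape comes from the
# UI change below.
import requests

payload = {
    "actions": [
        {
            "action": "create",
            "action_if_exists": "skip",  # "fail" and "overwrite" also supported
            "variables": [
                {"key": "new_key1", "value": "new_value1"},
                {"key": "new_key2", "value": "new_value2"},
            ],
        }
    ]
}
requests.patch("http://localhost:8080/public/variables", json=payload)
```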
from __future__ import annotations -import json -from typing import Annotated, Literal +from typing import Annotated -from fastapi import Depends, HTTPException, Query, UploadFile, status +from fastapi import Depends, HTTPException, Query, status from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from sqlalchemy import select @@ -39,7 +38,6 @@ VariableBulkResponse, VariableCollectionResponse, VariableResponse, - VariablesImportResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.api_fastapi.core_api.services.public.variables import ( @@ -192,59 +190,6 @@ def post_variable( return variable -@variables_router.post( - "/import", - status_code=status.HTTP_200_OK, - responses=create_openapi_http_exception_doc( - [status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT, status.HTTP_422_UNPROCESSABLE_ENTITY] - ), -) -def import_variables( - file: UploadFile, - session: SessionDep, - action_if_exists: Literal["overwrite", "fail", "skip"] = "fail", -) -> VariablesImportResponse: - """Import variables from a JSON file.""" - try: - file_content = file.file.read().decode("utf-8") - variables = json.loads(file_content) - - if not isinstance(variables, dict): - raise ValueError("Uploaded JSON must contain key-value pairs.") - except (json.JSONDecodeError, ValueError) as e: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON format: {e}") - - if not variables: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="No variables found in the provided JSON.", - ) - - existing_keys = {variable for variable in session.execute(select(Variable.key)).scalars()} - import_keys = set(variables.keys()) - - matched_keys = existing_keys & import_keys - - if action_if_exists == "fail" and matched_keys: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=f"The variables with these keys: {matched_keys} already exists.", - ) - elif action_if_exists == "skip": - create_keys = import_keys - matched_keys - else: - create_keys = import_keys - - for key in create_keys: - Variable.set(key=key, value=variables[key], session=session) - - return VariablesImportResponse( - created_count=len(create_keys), - import_count=len(import_keys), - created_variable_keys=list(create_keys), - ) - - @variables_router.patch("") def bulk_variables( request: VariableBulkBody, diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index a34f57020cd83a..b6cf77099005b9 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1625,9 +1625,6 @@ export type PoolServicePostPoolMutationResult = Awaited >; -export type VariableServiceImportVariablesMutationResult = Awaited< - ReturnType ->; export type BackfillServicePauseBackfillMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index a87e218adce4c4..a43172d73c7ca3 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -32,7 +32,6 @@ import { } from "../requests/services.gen"; import { BackfillPostBody, - Body_import_variables, ClearTaskInstancesBody, ConnectionBody, ConnectionBulkBody, @@ -3124,46 +3123,6 @@ export const useVariableServicePostVariable = < VariableService.postVariable({ requestBody }) as unknown as Promise, ...options, }); -/** - * Import Variables - * Import variables from a JSON 
file. - * @param data The data for the request. - * @param data.formData - * @param data.actionIfExists - * @returns VariablesImportResponse Successful Response - * @throws ApiError - */ -export const useVariableServiceImportVariables = < - TData = Common.VariableServiceImportVariablesMutationResult, - TError = unknown, - TContext = unknown, ->( - options?: Omit< - UseMutationOptions< - TData, - TError, - { - actionIfExists?: "overwrite" | "fail" | "skip"; - formData: Body_import_variables; - }, - TContext - >, - "mutationFn" - >, -) => - useMutation< - TData, - TError, - { - actionIfExists?: "overwrite" | "fail" | "skip"; - formData: Body_import_variables; - }, - TContext - >({ - mutationFn: ({ actionIfExists, formData }) => - VariableService.importVariables({ actionIfExists, formData }) as unknown as Promise, - ...options, - }); /** * Pause Backfill * @param data The data for the request. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 99341b970cac14..598c487ccbd024 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -482,19 +482,6 @@ export const $BaseInfoResponse = { description: "Base info serializer for responses.", } as const; -export const $Body_import_variables = { - properties: { - file: { - type: "string", - format: "binary", - title: "File", - }, - }, - type: "object", - required: ["file"], - title: "Body_import_variables", -} as const; - export const $ClearTaskInstancesBody = { properties: { dry_run: { @@ -5749,30 +5736,6 @@ export const $VariableResponse = { description: "Variable serializer for responses.", } as const; -export const $VariablesImportResponse = { - properties: { - created_variable_keys: { - items: { - type: "string", - }, - type: "array", - title: "Created Variable Keys", - }, - import_count: { - type: "integer", - title: "Import Count", - }, - created_count: { - type: "integer", - title: "Created Count", - }, - }, - type: "object", - required: ["created_variable_keys", "import_count", "created_count"], - title: "VariablesImportResponse", - description: "Import Variables serializer for responses.", -} as const; - export const $VersionInfo = { properties: { version: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 0ce36911eca13f..7f888aca34da86 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -187,8 +187,6 @@ import type { PostVariableResponse, BulkVariablesData, BulkVariablesResponse, - ImportVariablesData, - ImportVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetHealthResponse, @@ -3095,34 +3093,6 @@ export class VariableService { }, }); } - - /** - * Import Variables - * Import variables from a JSON file. - * @param data The data for the request. 
- * @param data.formData - * @param data.actionIfExists - * @returns VariablesImportResponse Successful Response - * @throws ApiError - */ - public static importVariables(data: ImportVariablesData): CancelablePromise { - return __request(OpenAPI, { - method: "POST", - url: "/public/variables/import", - query: { - action_if_exists: data.actionIfExists, - }, - formData: data.formData, - mediaType: "multipart/form-data", - errors: { - 400: "Bad Request", - 401: "Unauthorized", - 403: "Forbidden", - 409: "Conflict", - 422: "Unprocessable Entity", - }, - }); - } } export class DagParsingService { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 09911de7c97454..81925913fba42e 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -138,10 +138,6 @@ export type BaseInfoResponse = { status: string | null; }; -export type Body_import_variables = { - file: Blob | File; -}; - /** * Request body for Clear Task Instances endpoint. */ @@ -1389,15 +1385,6 @@ export type VariableResponse = { is_encrypted: boolean; }; -/** - * Import Variables serializer for responses. - */ -export type VariablesImportResponse = { - created_variable_keys: Array; - import_count: number; - created_count: number; -}; - /** * Version information serializer for responses. */ @@ -2259,13 +2246,6 @@ export type BulkVariablesData = { export type BulkVariablesResponse = VariableBulkResponse; -export type ImportVariablesData = { - actionIfExists?: "overwrite" | "fail" | "skip"; - formData: Body_import_variables; -}; - -export type ImportVariablesResponse = VariablesImportResponse; - export type ReparseDagFileData = { fileToken: string; }; @@ -4766,37 +4746,6 @@ export type $OpenApiTs = { }; }; }; - "/public/variables/import": { - post: { - req: ImportVariablesData; - res: { - /** - * Successful Response - */ - 200: VariablesImportResponse; - /** - * Bad Request - */ - 400: HTTPExceptionResponse; - /** - * Unauthorized - */ - 401: HTTPExceptionResponse; - /** - * Forbidden - */ - 403: HTTPExceptionResponse; - /** - * Conflict - */ - 409: HTTPExceptionResponse; - /** - * Unprocessable Entity - */ - 422: HTTPExceptionResponse; - }; - }; - }; "/public/parseDagFile/{file_token}": { put: { req: ReparseDagFileData; diff --git a/airflow/ui/src/pages/Variables/ImportVariablesForm.tsx b/airflow/ui/src/pages/Variables/ImportVariablesForm.tsx index 89029a1fd0aef9..5645c91d4baf34 100644 --- a/airflow/ui/src/pages/Variables/ImportVariablesForm.tsx +++ b/airflow/ui/src/pages/Variables/ImportVariablesForm.tsx @@ -55,21 +55,65 @@ const ImportVariablesForm = ({ onClose }: ImportVariablesFormProps) => { onSuccessConfirm: onClose, }); - const [selectedFile, setSelectedFile] = useState(undefined); - const [actionIfExists, setActionIfExists] = useState<"fail" | "overwrite" | "skip" | undefined>("fail"); + const [actionIfExists, setActionIfExists] = useState<"fail" | "overwrite" | "skip">("fail"); + const [isParsing, setIsParsing] = useState(false); + const [fileContent, setFileContent] = useState | undefined>(undefined); + + const onFileChange = (file: File) => { + setIsParsing(true); + const reader = new FileReader(); + + reader.addEventListener("load", (event) => { + try { + const text = event.target?.result as string; + const parsedContent = JSON.parse(text) as unknown; + + if ( + typeof parsedContent === "object" && + parsedContent !== null && + Object.entries(parsedContent).every( + ([key, value]) => typeof key === "string" && typeof 
value === "string", + ) + ) { + const typedContent = parsedContent as Record; + + setFileContent(typedContent); + } else { + throw new Error("Invalid JSON format"); + } + } catch { + setError({ + body: { + detail: + 'Error Parsing JSON File: Upload a JSON file containing variables (e.g., {"key": "value", ...}).', + }, + }); + setFileContent(undefined); + } finally { + setIsParsing(false); + } + }); + + reader.readAsText(file); + }; const onSubmit = () => { setError(undefined); - if (selectedFile) { - const formData = new FormData(); + if (fileContent) { + const formattedPayload = { + actions: [ + { + action: "create" as const, + action_if_exists: actionIfExists, + variables: Object.entries(fileContent).map(([key, value]) => ({ + key, + value, + })), + }, + ], + }; - formData.append("file", selectedFile); - mutate({ - actionIfExists, - formData: { - file: selectedFile, - }, - }); + mutate({ requestBody: formattedPayload }); } }; @@ -82,7 +126,11 @@ const ImportVariablesForm = ({ onClose }: ImportVariablesFormProps) => { mb={6} onFileChange={(files) => { if (files.acceptedFiles.length > 0) { - setSelectedFile(files.acceptedFiles[0]); + setError(undefined); + setFileContent(undefined); + if (files.acceptedFiles[0]) { + onFileChange(files.acceptedFiles[0]); + } } }} required @@ -99,7 +147,8 @@ const ImportVariablesForm = ({ onClose }: ImportVariablesFormProps) => { focusVisibleRing="inside" me="-1" onClick={() => { - setSelectedFile(undefined); + setError(undefined); + setFileContent(undefined); }} pointerEvents="auto" size="xs" @@ -112,6 +161,11 @@ const ImportVariablesForm = ({ onClose }: ImportVariablesFormProps) => { > + {isParsing ? ( +
+            Parsing file...   [enclosing JSX element lost in extraction]
+          ) : undefined}
           {isPending ? (
-            [removed JSX line lost in extraction]
+            [added JSX line lost in extraction]
-            [removed JSX line lost in extraction]
+            [added JSX line lost in extraction]
           ) : undefined}
-          [removed JSX line lost in extraction]
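Two notes on the hunk above: the generic annotations were mangled in extraction (given the all-string validation, the state is evidently typed `Record<string, string>`), and the bulk endpoint replies per action with `success` and `errors` lists. A sketch of consuming that response shape (illustrative values only; the `useImportVariables` hook below does the same in TypeScript):

```python
# Sketch: the response shape ({"create": {"success": [...], "errors": [...]}})
# is what useImportVariables inspects below; the values here are illustrative.
response = {"create": {"success": ["new_key1", "new_key2"], "errors": []}}

create = response.get("create")
if create:
    if create["errors"]:
        print(f"Import failed: {create['errors'][0]['error']}")
    elif create["success"]:
        keys = ", ".join(create["success"])
        print(f"{len(create['success'])} variables created successfully. Keys: {keys}")
```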
diff --git a/airflow/ui/src/queries/useImportVariables.ts b/airflow/ui/src/queries/useImportVariables.ts index b4692e37c535fa..212d9c83e19935 100644 --- a/airflow/ui/src/queries/useImportVariables.ts +++ b/airflow/ui/src/queries/useImportVariables.ts @@ -19,36 +19,43 @@ import { useQueryClient } from "@tanstack/react-query"; import { useState } from "react"; -import { useVariableServiceGetVariablesKey, useVariableServiceImportVariables } from "openapi/queries"; +import { useVariableServiceBulkVariables, useVariableServiceGetVariablesKey } from "openapi/queries"; import { toaster } from "src/components/ui"; export const useImportVariables = ({ onSuccessConfirm }: { onSuccessConfirm: () => void }) => { const queryClient = useQueryClient(); const [error, setError] = useState(undefined); - const onSuccess = async (responseData: { - created_count: number; - created_variable_keys: Array; - import_count: number; - }) => { + const onSuccess = async (responseData: { create?: { errors: Array; success: Array } }) => { await queryClient.invalidateQueries({ queryKey: [useVariableServiceGetVariablesKey], }); - toaster.create({ - description: `${responseData.created_count} of ${responseData.import_count} variables imported successfully. Keys imported are ${responseData.created_variable_keys.join(", ")}`, - title: "Import Variables Request Successful", - type: "success", - }); + if (responseData.create) { + const { errors, success } = responseData.create; + + if (Array.isArray(errors) && errors.length > 0) { + const apiError = errors[0] as { error: string }; - onSuccessConfirm(); + setError({ + body: { detail: apiError.error }, + }); + } else if (Array.isArray(success) && success.length > 0) { + toaster.create({ + description: `${success.length} variables created successfully. 
Keys: ${success.join(", ")}`, + title: "Import Variables Request Successful", + type: "success", + }); + onSuccessConfirm(); + } + } }; const onError = (_error: unknown) => { setError(_error); }; - const { isPending, mutate } = useVariableServiceImportVariables({ + const { isPending, mutate } = useVariableServiceBulkVariables({ onError, onSuccess, }); diff --git a/tests/api_fastapi/core_api/routes/public/test_variables.py b/tests/api_fastapi/core_api/routes/public/test_variables.py index 3cbab24878ac78..fac8b27472449c 100644 --- a/tests/api_fastapi/core_api/routes/public/test_variables.py +++ b/tests/api_fastapi/core_api/routes/public/test_variables.py @@ -433,113 +433,6 @@ def test_post_should_respond_422_when_value_is_null(self, test_client): } -class TestImportVariables(TestVariableEndpoint): - @pytest.mark.enable_redact - @pytest.mark.parametrize( - "variables_data, behavior, expected_status_code, expected_created_count, expected_created_keys, expected_conflict_keys", - [ - ( - {"new_key1": "new_value1", "new_key2": "new_value2"}, - "overwrite", - 200, - 2, - {"new_key1", "new_key2"}, - set(), - ), - ( - {"new_key1": "new_value1", "new_key2": "new_value2"}, - "skip", - 200, - 2, - {"new_key1", "new_key2"}, - set(), - ), - ( - {"test_variable_key": "new_value", "new_key": "new_value"}, - "fail", - 409, - 0, - set(), - {"test_variable_key"}, - ), - ( - {"test_variable_key": "new_value", "new_key": "new_value"}, - "skip", - 200, - 1, - {"new_key"}, - {"test_variable_key"}, - ), - ( - {"test_variable_key": "new_value", "new_key": "new_value"}, - "overwrite", - 200, - 2, - {"test_variable_key", "new_key"}, - set(), - ), - ], - ) - def test_import_variables( - self, - test_client, - variables_data, - behavior, - expected_status_code, - expected_created_count, - expected_created_keys, - expected_conflict_keys, - session, - ): - """Test variable import with different behaviors (overwrite, fail, skip).""" - - self.create_variables() - - file = create_file_upload(variables_data) - response = test_client.post( - "/public/variables/import", - files={"file": ("variables.json", file, "application/json")}, - params={"action_if_exists": behavior}, - ) - - assert response.status_code == expected_status_code - - if expected_status_code == 200: - body = response.json() - assert body["created_count"] == expected_created_count - assert set(body["created_variable_keys"]) == expected_created_keys - - elif expected_status_code == 409: - body = response.json() - assert ( - f"The variables with these keys: {expected_conflict_keys} already exists." == body["detail"] - ) - - def test_import_invalid_json(self, test_client): - """Test invalid JSON import.""" - file = BytesIO(b"import variable test") - response = test_client.post( - "/public/variables/import", - files={"file": ("variables.json", file, "application/json")}, - params={"action_if_exists": "overwrite"}, - ) - - assert response.status_code == 400 - assert "Invalid JSON format" in response.json()["detail"] - - def test_import_empty_file(self, test_client): - """Test empty file import.""" - file = create_file_upload({}) - response = test_client.post( - "/public/variables/import", - files={"file": ("empty_variables.json", file, "application/json")}, - params={"action_if_exists": "overwrite"}, - ) - - assert response.status_code == 422 - assert response.json()["detail"] == "No variables found in the provided JSON." - - class TestBulkVariables(TestVariableEndpoint): @pytest.mark.enable_redact @pytest.mark.parametrize(