diff --git a/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py b/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py index 15595269d26c7..e92cb450e5234 100644 --- a/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py +++ b/airflow-ctl-tests/tests/airflowctl_tests/test_airflowctl_commands.py @@ -93,6 +93,8 @@ def date_param(): "dags update --dag-id=example_bash_operator --no-is-paused", # Dag Run commands "dagrun list --dag-id example_bash_operator --state success --limit=1", + # Task instance commands - need a Dag run with completed tasks + 'taskinstances get --dag-id=example_bash_operator --dag-run-id="manual__{date_param}" --task-id=runme_0', # XCom commands - need a Dag run with completed tasks 'xcom add --dag-id=example_bash_operator --dag-run-id="manual__{date_param}" --task-id=runme_0 --key={xcom_key} --value=\'{{"test": "value"}}\'', 'xcom get --dag-id=example_bash_operator --dag-run-id="manual__{date_param}" --task-id=runme_0 --key={xcom_key}', diff --git a/airflow-ctl/docs/images/command_hashes.txt b/airflow-ctl/docs/images/command_hashes.txt index 8e590f3b820cd..3c8b7523e7e08 100644 --- a/airflow-ctl/docs/images/command_hashes.txt +++ b/airflow-ctl/docs/images/command_hashes.txt @@ -1,4 +1,4 @@ -main:27a22c00dcf32e7a1a4f06672dc8e3c8 +main:0460d9c03248bee26207b20b05aa36b9 assets:70619a2d92bda80930cde2aefcd8e1cd auth:d79e9c7d00c432bdbcbc2a86e2e32053 backfill:74c8737b0a62a86ed3605fa9e6165874 diff --git a/airflow-ctl/docs/images/output_main.svg b/airflow-ctl/docs/images/output_main.svg index f586877bce8eb..9c8cec5269b68 100644 --- a/airflow-ctl/docs/images/output_main.svg +++ b/airflow-ctl/docs/images/output_main.svg @@ -1,4 +1,4 @@ - + - - + + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + + + + + + + - + - + - - Usage:airflowctl [-hGROUP_OR_COMMAND... - -Positional Arguments: -GROUP_OR_COMMAND - -    Groups -assetsPerform Assets operations -authManage authentication for CLI. Either pass token from -environment variable/parameter or pass username and -password. -backfillPerform Backfill operations -configPerform Config operations -connectionsPerform Connections operations -dagrunPerform DagRun operations -dagsPerform Dags operations -jobsPerform Jobs operations -pluginsPerform Plugins operations -poolsPerform Pools operations -providersPerform Providers operations -variablesPerform Variables operations -xcomPerform XCom operations - -    Commands: -versionShow version information - -Options: --h--helpshow this help message and exit + + Usage:airflowctl [-hGROUP_OR_COMMAND... + +Positional Arguments: +GROUP_OR_COMMAND + +    Groups +assetsPerform Assets operations +authManage authentication for CLI. Either pass token from +environment variable/parameter or pass username and +password. +backfillPerform Backfill operations +configPerform Config operations +connectionsPerform Connections operations +dagrunPerform DagRun operations +dagsPerform Dags operations +jobsPerform Jobs operations +pluginsPerform Plugins operations +poolsPerform Pools operations +providersPerform Providers operations +taskinstances +Perform TaskInstances operations +tasksManage Airflow tasks +variablesPerform Variables operations +xcomPerform XCom operations + +    Commands: +versionShow version information + +Options: +-h--helpshow this help message and exit diff --git a/airflow-ctl/src/airflowctl/api/client.py b/airflow-ctl/src/airflowctl/api/client.py index b01200fac1c7f..cd957060cb1db 100644 --- a/airflow-ctl/src/airflowctl/api/client.py +++ b/airflow-ctl/src/airflowctl/api/client.py @@ -57,6 +57,7 @@ PoolsOperations, ProvidersOperations, ServerResponseError, + TaskInstancesOperations, VariablesOperations, VersionOperations, XComOperations, @@ -467,6 +468,12 @@ def xcom(self): """Operations related to XComs.""" return XComOperations(self) + @lru_cache() # type: ignore[prop-decorator] + @property + def task_instances(self): + """Operations related to task instances.""" + return TaskInstancesOperations(self) + @lru_cache() # type: ignore[prop-decorator] @property def plugins(self): diff --git a/airflow-ctl/src/airflowctl/api/operations.py b/airflow-ctl/src/airflowctl/api/operations.py index f52ba055c1c72..ca3774d89022f 100644 --- a/airflow-ctl/src/airflowctl/api/operations.py +++ b/airflow-ctl/src/airflowctl/api/operations.py @@ -68,6 +68,7 @@ ProviderCollectionResponse, QueuedEventCollectionResponse, QueuedEventResponse, + TaskInstanceResponse, TriggerDAGRunPostBody, VariableBody, VariableCollectionResponse, @@ -906,6 +907,32 @@ def delete( raise e +class TaskInstancesOperations(BaseOperations): + """Task instance operations.""" + + def get( + self, + dag_id: str, + dag_run_id: str, + task_id: str, + map_index: int = None, # type: ignore + ) -> TaskInstanceResponse | ServerResponseError: + """ + Get a task instance. + + When ``map_index`` is non-negative, the mapped task instance endpoint is + called; otherwise the standard (unmapped) endpoint is used. + """ + path = f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}" + if map_index is not None and map_index >= 0: + path = f"{path}/{map_index}" + try: + self.response = self.client.get(path) + return TaskInstanceResponse.model_validate_json(self.response.content) + except ServerResponseError as e: + raise e + + class PluginsOperations(BaseOperations): """Plugins operations.""" diff --git a/airflow-ctl/src/airflowctl/ctl/cli_config.py b/airflow-ctl/src/airflowctl/ctl/cli_config.py index aa96304f501fb..1b10be6292743 100755 --- a/airflow-ctl/src/airflowctl/ctl/cli_config.py +++ b/airflow-ctl/src/airflowctl/ctl/cli_config.py @@ -268,6 +268,36 @@ def _load_help_texts_yaml() -> dict[str, dict[str, str]]: help="The Dag ID of the Dag to pause or unpause", ) +# Task Commands Args +ARG_TASK_DAG_ID = Arg( + flags=("--dag-id",), + type=str, + dest="dag_id", + required=True, + help="The Dag ID", +) +ARG_DAG_RUN_ID = Arg( + flags=("--dag-run-id",), + type=str, + dest="dag_run_id", + required=True, + help="The Dag Run ID", +) +ARG_TASK_ID = Arg( + flags=("--task-id",), + type=str, + dest="task_id", + required=True, + help="The Task ID", +) +ARG_MAP_INDEX = Arg( + flags=("--map-index",), + type=int, + dest="map_index", + default=-1, + help="If set, query the mapped task instance with this map index (negative means non-mapped)", +) + ARG_ACTION_ON_EXISTING_KEY = Arg( flags=("-a", "--action-on-existing-key"), type=str, @@ -953,6 +983,21 @@ def merge_commands( ), ) +TASK_COMMANDS = ( + ActionCommand( + name="state", + help="Get the state of a task instance", + func=lazy_load_command("airflowctl.ctl.commands.task_command.task_state"), + args=( + ARG_TASK_DAG_ID, + ARG_DAG_RUN_ID, + ARG_TASK_ID, + ARG_MAP_INDEX, + ARG_OUTPUT, + ), + ), +) + core_commands: list[CLICommand] = [ GroupCommand( name="auth", @@ -995,6 +1040,11 @@ def merge_commands( help="Manage Airflow variables", subcommands=VARIABLE_COMMANDS, ), + GroupCommand( + name="tasks", + help="Manage Airflow tasks", + subcommands=TASK_COMMANDS, + ), ] # Add generated group commands core_commands = merge_commands( diff --git a/airflow-ctl/src/airflowctl/ctl/commands/task_command.py b/airflow-ctl/src/airflowctl/ctl/commands/task_command.py new file mode 100644 index 0000000000000..b5dd1b56b9fac --- /dev/null +++ b/airflow-ctl/src/airflowctl/ctl/commands/task_command.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflowctl.api.client import NEW_API_CLIENT, ClientKind, provide_api_client +from airflowctl.ctl.console_formatting import AirflowConsole + + +@provide_api_client(kind=ClientKind.CLI) +def task_state(args, api_client=NEW_API_CLIENT) -> None: + """Get the state of a task instance.""" + ti = api_client.task_instances.get( + dag_id=args.dag_id, + dag_run_id=args.dag_run_id, + task_id=args.task_id, + map_index=args.map_index, + ) + AirflowConsole().print_as(data=[{"state": ti.state}], output=args.output) diff --git a/airflow-ctl/src/airflowctl/ctl/help_texts.yaml b/airflow-ctl/src/airflowctl/ctl/help_texts.yaml index eb566a96b1fb8..b60082bfacfa3 100644 --- a/airflow-ctl/src/airflowctl/ctl/help_texts.yaml +++ b/airflow-ctl/src/airflowctl/ctl/help_texts.yaml @@ -100,3 +100,6 @@ xcom: plugins: list: "List all installed Airflow plugins" list-import-errors: "List all plugin import errors" + +taskinstances: + get: "Retrieve a task instance by Dag ID, run ID, and task ID" diff --git a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_task_command.py b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_task_command.py new file mode 100644 index 0000000000000..6cccba8dcce30 --- /dev/null +++ b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_task_command.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +import uuid + +import pytest + +from airflowctl.api.client import ClientKind +from airflowctl.api.datamodels.generated import TaskInstanceResponse, TaskInstanceState +from airflowctl.api.operations import ServerResponseError +from airflowctl.ctl import cli_parser +from airflowctl.ctl.commands import task_command + + +class TestTaskCommands: + parser = cli_parser.get_parser() + dag_id = "example_dag" + dag_run_id = "manual__2024-01-01T00:00:00+00:00" + task_id = "my_task" + + task_instance_response = TaskInstanceResponse( + id=uuid.uuid4(), + task_id=task_id, + dag_id=dag_id, + dag_run_id=dag_run_id, + map_index=-1, + run_after=datetime.datetime(2024, 1, 1, 0, 0, 0), + try_number=1, + max_tries=1, + task_display_name=task_id, + dag_display_name=dag_id, + pool="default_pool", + pool_slots=1, + executor_config="{}", + state=TaskInstanceState.SUCCESS, + ) + + def test_task_state(self, api_client_maker): + api_client = api_client_maker( + path=f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_id}", + response_json=self.task_instance_response.model_dump(mode="json"), + expected_http_status_code=200, + kind=ClientKind.CLI, + ) + task_command.task_state( + self.parser.parse_args( + [ + "tasks", + "state", + f"--dag-id={self.dag_id}", + f"--dag-run-id={self.dag_run_id}", + f"--task-id={self.task_id}", + ] + ), + api_client=api_client, + ) + + def test_task_state_not_found(self, api_client_maker): + api_client = api_client_maker( + path=f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_id}", + response_json={"detail": "Task instance not found"}, + expected_http_status_code=404, + kind=ClientKind.CLI, + ) + with pytest.raises(ServerResponseError): + task_command.task_state( + self.parser.parse_args( + [ + "tasks", + "state", + f"--dag-id={self.dag_id}", + f"--dag-run-id={self.dag_run_id}", + f"--task-id={self.task_id}", + ] + ), + api_client=api_client, + ) + + @pytest.mark.parametrize("map_index", [0, 1, 7]) + def test_task_state_mapped(self, api_client_maker, map_index): + mapped_response = self.task_instance_response.model_copy(update={"map_index": map_index}) + api_client = api_client_maker( + path=( + f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}" + f"/taskInstances/{self.task_id}/{map_index}" + ), + response_json=mapped_response.model_dump(mode="json"), + expected_http_status_code=200, + kind=ClientKind.CLI, + ) + task_command.task_state( + self.parser.parse_args( + [ + "tasks", + "state", + f"--dag-id={self.dag_id}", + f"--dag-run-id={self.dag_run_id}", + f"--task-id={self.task_id}", + f"--map-index={map_index}", + ] + ), + api_client=api_client, + )