From 7af45717ea3442ef205901b1e65a0ead37dd0f9e Mon Sep 17 00:00:00 2001 From: Felix Uellendall Date: Thu, 6 Feb 2025 17:39:08 +0100 Subject: [PATCH] Fix task sdk client dry-run mode (#46524) * Fix task sdk client dry-run mode The `run_after` field was missing in the fake dag run response. * Fix trailing whitespaces --- task_sdk/src/airflow/sdk/api/client.py | 1 + task_sdk/tests/api/test_client.py | 33 ++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 821e589ad522f..84992b9ab1945 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -344,6 +344,7 @@ def noop_handler(request: httpx.Request) -> httpx.Response: "logical_date": "2021-01-01T00:00:00Z", "start_date": "2021-01-01T00:00:00Z", "run_type": DagRunType.MANUAL, + "run_after": "2021-01-01T00:00:00Z", }, "max_tries": 0, }, diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 8315a121fc4d7..319a00354bdb7 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -37,6 +37,11 @@ def make_client(transport: httpx.MockTransport) -> Client: return Client(base_url="test://server", token="", transport=transport) +def make_client_w_dry_run() -> Client: + """Get a client with dry_run enabled""" + return Client(base_url=None, dry_run=True, token="") + + def make_client_w_responses(responses: list[httpx.Response]) -> Client: """Helper fixture to create a mock client with custom responses.""" @@ -49,6 +54,34 @@ def handle_request(request: httpx.Request) -> httpx.Response: class TestClient: + @pytest.mark.parametrize( + ["path", "json_response"], + [ + ( + "/task-instances/1/run", + { + "dag_run": { + "dag_id": "test_dag", + "run_id": "test_run", + "logical_date": "2021-01-01T00:00:00Z", + "start_date": "2021-01-01T00:00:00Z", + "run_type": "manual", + "run_after": "2021-01-01T00:00:00Z", + }, + "max_tries": 0, + }, + ), + ], + ) + def test_dry_run(self, path, json_response): + client = make_client_w_dry_run() + assert client.base_url == "dry-run://server" + + resp = client.get(path) + + assert resp.status_code == 200 + assert resp.json() == json_response + def test_error_parsing(self): responses = [ httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]})