diff --git a/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_concurrent_xcom.py b/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_concurrent_xcom.py new file mode 100644 index 0000000000000..90e9116ca8d81 --- /dev/null +++ b/airflow-e2e-tests/tests/airflow_e2e_tests/go_sdk_tests/test_go_sdk_concurrent_xcom.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""E2E test for the Go SDK ``concurrent_xcom_dag`` example. + +``pull_xcoms_concurrently`` (Go) pulls a batch of XComs sequentially then with one +goroutine per item, sharing the injected client, and fails if any result differs. +Asserts the task succeeds and the concurrent pull beats the sequential one. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone + +import pytest + +from airflow_e2e_tests.e2e_test_utils.clients import AirflowClient + +# The Go task seeds and pulls the batch twice; it is fast, but allow room for +# coordinator startup. +_GO_TASK_TIMEOUT = 300 + +_DAG_ID = "concurrent_xcom_dag" +_TASK_ID = "pull_xcoms_concurrently" + + +@dataclass +class _CompletedRun: + """The single ``concurrent_xcom_dag`` run shared across this module's tests.""" + + client: AirflowClient + run_id: str + state: str + ti_states: dict[str, str] + + def xcom(self, task_id: str, key: str = "return_value"): + return self.client.get_xcom_value(dag_id=_DAG_ID, task_id=task_id, run_id=self.run_id, key=key).get( + "value" + ) + + +@pytest.fixture(scope="module") +def completed_run() -> _CompletedRun: + """Trigger ``concurrent_xcom_dag`` once and wait for it to finish.""" + client = AirflowClient() + resp = client.trigger_dag(_DAG_ID, json={"logical_date": datetime.now(timezone.utc).isoformat()}) + run_id = resp["dag_run_id"] + state = client.wait_for_dag_run(dag_id=_DAG_ID, run_id=run_id, timeout=_GO_TASK_TIMEOUT) + ti_resp = client.get_task_instances(dag_id=_DAG_ID, run_id=run_id) + ti_states = {ti["task_id"]: ti.get("state") for ti in ti_resp.get("task_instances", [])} + return _CompletedRun(client=client, run_id=run_id, state=state, ti_states=ti_states) + + +def test_task_succeeded(completed_run: _CompletedRun): + """Run and task succeed -- the task errors on any goroutine mismatch, so this + proves the injected client was used safely.""" + assert completed_run.state == "success", ( + f"expected the run to succeed; got {completed_run.state!r}. task states: {completed_run.ti_states}" + ) + assert completed_run.ti_states.get(_TASK_ID) == "success", completed_run.ti_states + + +def test_concurrent_faster_than_sequential(completed_run: _CompletedRun): + """Concurrent pull-and-process beats the sequential loop.""" + value = completed_run.xcom(_TASK_ID) + assert isinstance(value, dict), ( + f"Expected the task's XCom to be a mapping, got {value!r} ({type(value).__name__})" + ) + + sequential = value.get("sequential_ms") + concurrent = value.get("concurrent_ms") + assert isinstance(sequential, int), f"bad sequential_ms: {sequential!r}" + assert isinstance(concurrent, int), f"bad concurrent_ms: {concurrent!r}" + assert sequential > 0, f"bad sequential_ms: {sequential!r}" + assert concurrent > 0, f"bad concurrent_ms: {concurrent!r}" + assert concurrent < sequential, ( + f"expected concurrent ({concurrent} ms) to beat sequential ({sequential} ms)" + ) + + +def test_num_xcoms(completed_run: _CompletedRun): + """The task reports the batch size it pulled.""" + assert completed_run.xcom(_TASK_ID).get("num_xcoms") == 10 diff --git a/go-sdk/cmd/airflow-go-pack/pack_integration_test.go b/go-sdk/cmd/airflow-go-pack/pack_integration_test.go index edb51824c4cda..77725b0ac4efc 100644 --- a/go-sdk/cmd/airflow-go-pack/pack_integration_test.go +++ b/go-sdk/cmd/airflow-go-pack/pack_integration_test.go @@ -145,6 +145,9 @@ sdk: supervisor_schema_version: "2026-06-16" source: "main.go" dags: + concurrent_xcom_dag: + tasks: + - "pull_xcoms_concurrently" simple_dag: tasks: - "extract" diff --git a/go-sdk/dags/go_examples.py b/go-sdk/dags/go_examples.py index 6c9ecf7b4558b..23e02dd5e49dc 100644 --- a/go-sdk/dags/go_examples.py +++ b/go-sdk/dags/go_examples.py @@ -15,10 +15,14 @@ # specific language governing permissions and limitations # under the License. """ -Python stub Dag mirroring the Go SDK example bundle (``go-sdk/example/bundle``). +Python stub Dags mirroring the Go SDK example bundle (``go-sdk/example/bundle``). -The graph sandwiches the Go tasks between two native Python tasks so the run -exercises XCom across the language boundary, the same way +Two Dags, both backed by the same Go bundle: ``simple_dag`` (extract/transform/ +load, below) and ``concurrent_xcom_dag`` (one ``pull_xcoms_concurrently`` task +timing sequential vs goroutine XCom pulls). + +``simple_dag`` sandwiches the Go tasks between two native Python tasks so the +run exercises XCom across the language boundary, the same way ``java-sdk/dags/java_examples.py`` does for the Java SDK:: python_task_1 >> extract >> transform >> [load, python_task_2] @@ -91,3 +95,15 @@ def simple_dag(): simple_dag() + + +@task.stub(queue="golang") +def pull_xcoms_concurrently(): ... + + +@dag(dag_id="concurrent_xcom_dag") +def concurrent_xcom_dag(): + pull_xcoms_concurrently() + + +concurrent_xcom_dag() diff --git a/go-sdk/example/bundle/concurrentxcom/concurrentxcom.go b/go-sdk/example/bundle/concurrentxcom/concurrentxcom.go new file mode 100644 index 0000000000000..0bc5fd233ec2a --- /dev/null +++ b/go-sdk/example/bundle/concurrentxcom/concurrentxcom.go @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Package concurrentxcom holds the pull_xcoms_concurrently task in its own +// package, so main.go can register tasks defined across packages with one +// RegisterDags. +package concurrentxcom + +import ( + "errors" + "fmt" + "log/slog" + "reflect" + "sync" + "time" + + "github.com/apache/airflow/go-sdk/pkg/api" + "github.com/apache/airflow/go-sdk/sdk" +) + +const ( + numXComs = 10 + // perItemWork is the per-item work the goroutines overlap. + perItemWork = 150 * time.Millisecond +) + +// PullXComsConcurrently pulls a batch of XComs sequentially then concurrently +// (one goroutine per item), exercising concurrent reads of the injected +// sdk.Client, and returns both timings. +func PullXComsConcurrently(ctx sdk.TIRunContext, client sdk.Client, log *slog.Logger) (any, error) { + ti := ctx.TaskInstance() + // PushXCom needs only the ids off the TaskInstance, not the UUID. + apiTI := api.TaskInstance{ + DagId: ti.DagID, + RunId: ti.RunID, + TaskId: ti.TaskID, + MapIndex: ti.MapIndex, + } + + keys := make([]string, numXComs) + for i := range keys { + keys[i] = fmt.Sprintf("item_%d", i) + if err := client.PushXCom(ctx, apiTI, keys[i], i); err != nil { + return nil, fmt.Errorf("seeding xcom %s: %w", keys[i], err) + } + } + + pull := func(key string) (any, error) { + v, err := client.GetXCom(ctx, ti.DagID, ti.RunID, ti.TaskID, nil, key, nil) + if err != nil { + return nil, err + } + time.Sleep(perItemWork) + return v, nil + } + + seqResults := make([]any, numXComs) + seqStart := time.Now() + for i, key := range keys { + v, err := pull(key) + if err != nil { + return nil, fmt.Errorf("sequential pull %s: %w", key, err) + } + seqResults[i] = v + } + sequential := time.Since(seqStart) + + concResults := make([]any, numXComs) + errs := make([]error, numXComs) + concStart := time.Now() + var wg sync.WaitGroup + for i, key := range keys { + wg.Add(1) + go func() { + defer wg.Done() + concResults[i], errs[i] = pull(key) + }() + } + wg.Wait() + concurrent := time.Since(concStart) + if err := errors.Join(errs...); err != nil { + return nil, fmt.Errorf("concurrent pulls failed: %w", err) + } + + for i := range concResults { + if !reflect.DeepEqual(concResults[i], seqResults[i]) { + return nil, fmt.Errorf( + "concurrent result %d = %v, want %v", + i, + concResults[i], + seqResults[i], + ) + } + } + + log.InfoContext(ctx, "pulled xcoms concurrently", + "num_xcoms", numXComs, + "sequential_ms", sequential.Milliseconds(), + "concurrent_ms", concurrent.Milliseconds(), + ) + return map[string]any{ + "num_xcoms": numXComs, + "sequential_ms": sequential.Milliseconds(), + "concurrent_ms": concurrent.Milliseconds(), + }, nil +} diff --git a/go-sdk/example/bundle/concurrentxcom/concurrentxcom_test.go b/go-sdk/example/bundle/concurrentxcom/concurrentxcom_test.go new file mode 100644 index 0000000000000..eab51fa23bbed --- /dev/null +++ b/go-sdk/example/bundle/concurrentxcom/concurrentxcom_test.go @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package concurrentxcom + +import ( + "context" + "log/slog" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/apache/airflow/go-sdk/pkg/api" + "github.com/apache/airflow/go-sdk/sdk" +) + +// mockXComClient is a mutex-guarded in-memory sdk.Client so the concurrent +// GetXCom goroutines are race-free under `go test -race`. +type mockXComClient struct { + mu sync.RWMutex + values map[string]any +} + +func newMockXComClient() *mockXComClient { + return &mockXComClient{values: make(map[string]any)} +} + +func (m *mockXComClient) PushXCom( + ctx context.Context, + ti api.TaskInstance, + key string, + value any, +) error { + m.mu.Lock() + defer m.mu.Unlock() + m.values[key] = value + return nil +} + +func (m *mockXComClient) GetXCom( + ctx context.Context, + dagId, runId, taskId string, + mapIndex *int, + key string, + value any, +) (any, error) { + m.mu.RLock() + defer m.mu.RUnlock() + v, ok := m.values[key] + if !ok { + return nil, sdk.XComNotFound + } + return v, nil +} + +func (m *mockXComClient) GetVariable(ctx context.Context, key string) (string, error) { + panic("unimplemented") +} + +func (m *mockXComClient) UnmarshalJSONVariable(ctx context.Context, key string, pointer any) error { + panic("unimplemented") +} + +func (m *mockXComClient) GetConnection(ctx context.Context, connID string) (sdk.Connection, error) { + panic("unimplemented") +} + +var _ sdk.Client = (*mockXComClient)(nil) + +func Test_PullXComsConcurrently(t *testing.T) { + ctx := sdk.NewTIRunContext( + context.Background(), + sdk.TaskInstance{ + DagID: "concurrent_xcom_dag", + RunID: "run", + TaskID: "pull_xcoms_concurrently", + }, + sdk.DagRun{}, + ) + + result, err := PullXComsConcurrently(ctx, newMockXComClient(), slog.Default()) + assert.NoError(t, err) + + m, ok := result.(map[string]any) + assert.True(t, ok) + assert.Equal(t, numXComs, m["num_xcoms"]) + + sequential := m["sequential_ms"].(int64) + concurrent := m["concurrent_ms"].(int64) + assert.Greater(t, sequential, int64(0)) + assert.Greater(t, concurrent, int64(0)) + // The per-item work overlaps across goroutines, so concurrent beats sequential. + assert.Less(t, concurrent, sequential) +} diff --git a/go-sdk/example/bundle/main.go b/go-sdk/example/bundle/main.go index 7f4f1d22dcf09..23e60bd1dd46e 100644 --- a/go-sdk/example/bundle/main.go +++ b/go-sdk/example/bundle/main.go @@ -26,6 +26,7 @@ import ( v1 "github.com/apache/airflow/go-sdk/bundle/bundlev1" "github.com/apache/airflow/go-sdk/bundle/bundlev1/bundlev1server" + "github.com/apache/airflow/go-sdk/example/bundle/concurrentxcom" "github.com/apache/airflow/go-sdk/sdk" ) @@ -50,6 +51,10 @@ func (m *myBundle) RegisterDags(dagbag v1.Registry) error { simpleDag.AddTask(transform) simpleDag.AddTask(load) + // Tasks defined in other packages register through the same dagbag. + concurrentDag := dagbag.AddDag("concurrent_xcom_dag") + concurrentDag.AddTaskWithName("pull_xcoms_concurrently", concurrentxcom.PullXComsConcurrently) + return nil }