Skip to content

Commit b80f7ee

Browse files
committed
Add Go-SDK pull XComs with goroutines in e2e test
1 parent 6877aea commit b80f7ee

5 files changed

Lines changed: 350 additions & 3 deletions

File tree

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""E2E test for the Go SDK ``concurrent_xcom_dag`` example.
18+
19+
``pull_xcoms_concurrently`` (Go) pulls a batch of XComs sequentially then with one
20+
goroutine per item, sharing the injected client, and fails if any result differs.
21+
Asserts the task succeeds and the concurrent pull beats the sequential one.
22+
"""
23+
24+
from __future__ import annotations
25+
26+
from dataclasses import dataclass
27+
from datetime import datetime, timezone
28+
29+
import pytest
30+
31+
from airflow_e2e_tests.e2e_test_utils.clients import AirflowClient
32+
33+
# The Go task seeds and pulls the batch twice; it is fast, but allow room for
34+
# coordinator startup.
35+
_GO_TASK_TIMEOUT = 300
36+
37+
_DAG_ID = "concurrent_xcom_dag"
38+
_TASK_ID = "pull_xcoms_concurrently"
39+
40+
41+
@dataclass
42+
class _CompletedRun:
43+
"""The single ``concurrent_xcom_dag`` run shared across this module's tests."""
44+
45+
client: AirflowClient
46+
run_id: str
47+
state: str
48+
ti_states: dict[str, str]
49+
50+
def xcom(self, task_id: str, key: str = "return_value"):
51+
return self.client.get_xcom_value(dag_id=_DAG_ID, task_id=task_id, run_id=self.run_id, key=key).get(
52+
"value"
53+
)
54+
55+
56+
@pytest.fixture(scope="module")
57+
def completed_run() -> _CompletedRun:
58+
"""Trigger ``concurrent_xcom_dag`` once and wait for it to finish."""
59+
client = AirflowClient()
60+
resp = client.trigger_dag(_DAG_ID, json={"logical_date": datetime.now(timezone.utc).isoformat()})
61+
run_id = resp["dag_run_id"]
62+
state = client.wait_for_dag_run(dag_id=_DAG_ID, run_id=run_id, timeout=_GO_TASK_TIMEOUT)
63+
ti_resp = client.get_task_instances(dag_id=_DAG_ID, run_id=run_id)
64+
ti_states = {ti["task_id"]: ti.get("state") for ti in ti_resp.get("task_instances", [])}
65+
return _CompletedRun(client=client, run_id=run_id, state=state, ti_states=ti_states)
66+
67+
68+
def test_task_succeeded(completed_run: _CompletedRun):
69+
"""Run and task succeed -- the task errors on any goroutine mismatch, so this
70+
proves the injected client was used safely."""
71+
assert completed_run.state == "success", (
72+
f"expected the run to succeed; got {completed_run.state!r}. task states: {completed_run.ti_states}"
73+
)
74+
assert completed_run.ti_states.get(_TASK_ID) == "success", completed_run.ti_states
75+
76+
77+
def test_concurrent_faster_than_sequential(completed_run: _CompletedRun):
78+
"""Concurrent pull-and-process beats the sequential loop."""
79+
value = completed_run.xcom(_TASK_ID)
80+
assert isinstance(value, dict), (
81+
f"Expected the task's XCom to be a mapping, got {value!r} ({type(value).__name__})"
82+
)
83+
84+
sequential = value.get("sequential_ms")
85+
concurrent = value.get("concurrent_ms")
86+
assert isinstance(sequential, int), f"bad sequential_ms: {sequential!r}"
87+
assert isinstance(concurrent, int), f"bad concurrent_ms: {concurrent!r}"
88+
assert sequential > 0, f"bad sequential_ms: {sequential!r}"
89+
assert concurrent > 0, f"bad concurrent_ms: {concurrent!r}"
90+
assert concurrent < sequential, (
91+
f"expected concurrent ({concurrent} ms) to beat sequential ({sequential} ms)"
92+
)
93+
94+
95+
def test_num_xcoms(completed_run: _CompletedRun):
96+
"""The task reports the batch size it pulled."""
97+
assert completed_run.xcom(_TASK_ID).get("num_xcoms") == 10

go-sdk/dags/go_examples.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""
18-
Python stub Dag mirroring the Go SDK example bundle (``go-sdk/example/bundle``).
18+
Python stub Dags mirroring the Go SDK example bundle (``go-sdk/example/bundle``).
1919
20-
The graph sandwiches the Go tasks between two native Python tasks so the run
21-
exercises XCom across the language boundary, the same way
20+
Two Dags, both backed by the same Go bundle: ``simple_dag`` (extract/transform/
21+
load, below) and ``concurrent_xcom_dag`` (one ``pull_xcoms_concurrently`` task
22+
timing sequential vs goroutine XCom pulls).
23+
24+
``simple_dag`` sandwiches the Go tasks between two native Python tasks so the
25+
run exercises XCom across the language boundary, the same way
2226
``java-sdk/dags/java_examples.py`` does for the Java SDK::
2327
2428
python_task_1 >> extract >> transform >> [load, python_task_2]
@@ -91,3 +95,15 @@ def simple_dag():
9195

9296

9397
simple_dag()
98+
99+
100+
@task.stub(queue="golang")
101+
def pull_xcoms_concurrently(): ...
102+
103+
104+
@dag(dag_id="concurrent_xcom_dag")
105+
def concurrent_xcom_dag():
106+
pull_xcoms_concurrently()
107+
108+
109+
concurrent_xcom_dag()
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
// Package concurrentxcom holds the pull_xcoms_concurrently task in its own
19+
// package, so main.go can register tasks defined across packages with one
20+
// RegisterDags.
21+
package concurrentxcom
22+
23+
import (
24+
"errors"
25+
"fmt"
26+
"log/slog"
27+
"reflect"
28+
"sync"
29+
"time"
30+
31+
"github.com/apache/airflow/go-sdk/pkg/api"
32+
"github.com/apache/airflow/go-sdk/sdk"
33+
)
34+
35+
const (
36+
numXComs = 10
37+
// perItemWork is the per-item work the goroutines overlap.
38+
perItemWork = 150 * time.Millisecond
39+
)
40+
41+
// PullXComsConcurrently pulls a batch of XComs sequentially then concurrently
42+
// (one goroutine per item), proving the injected sdk.Client is concurrency-safe,
43+
// and returns both timings.
44+
func PullXComsConcurrently(ctx sdk.TIRunContext, client sdk.Client, log *slog.Logger) (any, error) {
45+
ti := ctx.TaskInstance()
46+
// PushXCom needs only the ids off the TaskInstance, not the UUID.
47+
apiTI := api.TaskInstance{
48+
DagId: ti.DagID,
49+
RunId: ti.RunID,
50+
TaskId: ti.TaskID,
51+
MapIndex: ti.MapIndex,
52+
}
53+
54+
keys := make([]string, numXComs)
55+
for i := range keys {
56+
keys[i] = fmt.Sprintf("item_%d", i)
57+
if err := client.PushXCom(ctx, apiTI, keys[i], i); err != nil {
58+
return nil, fmt.Errorf("seeding xcom %s: %w", keys[i], err)
59+
}
60+
}
61+
62+
pull := func(key string) (any, error) {
63+
v, err := client.GetXCom(ctx, ti.DagID, ti.RunID, ti.TaskID, nil, key, nil)
64+
if err != nil {
65+
return nil, err
66+
}
67+
time.Sleep(perItemWork)
68+
return v, nil
69+
}
70+
71+
seqResults := make([]any, numXComs)
72+
seqStart := time.Now()
73+
for i, key := range keys {
74+
v, err := pull(key)
75+
if err != nil {
76+
return nil, fmt.Errorf("sequential pull %s: %w", key, err)
77+
}
78+
seqResults[i] = v
79+
}
80+
sequential := time.Since(seqStart)
81+
82+
concResults := make([]any, numXComs)
83+
errs := make([]error, numXComs)
84+
concStart := time.Now()
85+
var wg sync.WaitGroup
86+
for i, key := range keys {
87+
wg.Add(1)
88+
go func() {
89+
defer wg.Done()
90+
concResults[i], errs[i] = pull(key)
91+
}()
92+
}
93+
wg.Wait()
94+
concurrent := time.Since(concStart)
95+
if err := errors.Join(errs...); err != nil {
96+
return nil, fmt.Errorf("concurrent pulls failed: %w", err)
97+
}
98+
99+
for i := range concResults {
100+
if !reflect.DeepEqual(concResults[i], seqResults[i]) {
101+
return nil, fmt.Errorf(
102+
"concurrent result %d = %v, want %v",
103+
i,
104+
concResults[i],
105+
seqResults[i],
106+
)
107+
}
108+
}
109+
110+
log.InfoContext(ctx, "pulled xcoms concurrently",
111+
"num_xcoms", numXComs,
112+
"sequential_ms", sequential.Milliseconds(),
113+
"concurrent_ms", concurrent.Milliseconds(),
114+
)
115+
return map[string]any{
116+
"num_xcoms": numXComs,
117+
"sequential_ms": sequential.Milliseconds(),
118+
"concurrent_ms": concurrent.Milliseconds(),
119+
}, nil
120+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package concurrentxcom
19+
20+
import (
21+
"context"
22+
"log/slog"
23+
"sync"
24+
"testing"
25+
26+
"github.com/stretchr/testify/assert"
27+
28+
"github.com/apache/airflow/go-sdk/pkg/api"
29+
"github.com/apache/airflow/go-sdk/sdk"
30+
)
31+
32+
// mockXComClient is a mutex-guarded in-memory sdk.Client so the concurrent
33+
// GetXCom goroutines are race-free under `go test -race`.
34+
type mockXComClient struct {
35+
mu sync.RWMutex
36+
values map[string]any
37+
}
38+
39+
func newMockXComClient() *mockXComClient {
40+
return &mockXComClient{values: make(map[string]any)}
41+
}
42+
43+
func (m *mockXComClient) PushXCom(
44+
ctx context.Context,
45+
ti api.TaskInstance,
46+
key string,
47+
value any,
48+
) error {
49+
m.mu.Lock()
50+
defer m.mu.Unlock()
51+
m.values[key] = value
52+
return nil
53+
}
54+
55+
func (m *mockXComClient) GetXCom(
56+
ctx context.Context,
57+
dagId, runId, taskId string,
58+
mapIndex *int,
59+
key string,
60+
value any,
61+
) (any, error) {
62+
m.mu.RLock()
63+
defer m.mu.RUnlock()
64+
v, ok := m.values[key]
65+
if !ok {
66+
return nil, sdk.XComNotFound
67+
}
68+
return v, nil
69+
}
70+
71+
func (m *mockXComClient) GetVariable(ctx context.Context, key string) (string, error) {
72+
panic("unimplemented")
73+
}
74+
75+
func (m *mockXComClient) UnmarshalJSONVariable(ctx context.Context, key string, pointer any) error {
76+
panic("unimplemented")
77+
}
78+
79+
func (m *mockXComClient) GetConnection(ctx context.Context, connID string) (sdk.Connection, error) {
80+
panic("unimplemented")
81+
}
82+
83+
var _ sdk.Client = (*mockXComClient)(nil)
84+
85+
func Test_PullXComsConcurrently(t *testing.T) {
86+
ctx := sdk.NewTIRunContext(
87+
context.Background(),
88+
sdk.TaskInstance{
89+
DagID: "concurrent_xcom_dag",
90+
RunID: "run",
91+
TaskID: "pull_xcoms_concurrently",
92+
},
93+
sdk.DagRun{},
94+
)
95+
96+
result, err := PullXComsConcurrently(ctx, newMockXComClient(), slog.Default())
97+
assert.NoError(t, err)
98+
99+
m, ok := result.(map[string]any)
100+
assert.True(t, ok)
101+
assert.Equal(t, numXComs, m["num_xcoms"])
102+
103+
sequential := m["sequential_ms"].(int64)
104+
concurrent := m["concurrent_ms"].(int64)
105+
assert.Greater(t, sequential, int64(0))
106+
assert.Greater(t, concurrent, int64(0))
107+
// The per-item work overlaps across goroutines, so concurrent beats sequential.
108+
assert.Less(t, concurrent, sequential)
109+
}

go-sdk/example/bundle/main.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626

2727
v1 "github.com/apache/airflow/go-sdk/bundle/bundlev1"
2828
"github.com/apache/airflow/go-sdk/bundle/bundlev1/bundlev1server"
29+
"github.com/apache/airflow/go-sdk/example/bundle/concurrentxcom"
2930
"github.com/apache/airflow/go-sdk/sdk"
3031
)
3132

@@ -50,6 +51,10 @@ func (m *myBundle) RegisterDags(dagbag v1.Registry) error {
5051
simpleDag.AddTask(transform)
5152
simpleDag.AddTask(load)
5253

54+
// Tasks defined in other packages register through the same dagbag.
55+
concurrentDag := dagbag.AddDag("concurrent_xcom_dag")
56+
concurrentDag.AddTaskWithName("pull_xcoms_concurrently", concurrentxcom.PullXComsConcurrently)
57+
5358
return nil
5459
}
5560

0 commit comments

Comments
 (0)