Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""E2E test for the Go SDK ``concurrent_xcom_dag`` example.

``pull_xcoms_concurrently`` (Go) pulls a batch of XComs sequentially then with one
goroutine per item, sharing the injected client, and fails if any result differs.
Asserts the task succeeds and the concurrent pull beats the sequential one.
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone

import pytest

from airflow_e2e_tests.e2e_test_utils.clients import AirflowClient

# The Go task seeds and pulls the batch twice; it is fast, but allow room for
# coordinator startup.
_GO_TASK_TIMEOUT = 300

_DAG_ID = "concurrent_xcom_dag"
_TASK_ID = "pull_xcoms_concurrently"


@dataclass
class _CompletedRun:
"""The single ``concurrent_xcom_dag`` run shared across this module's tests."""

client: AirflowClient
run_id: str
state: str
ti_states: dict[str, str]

def xcom(self, task_id: str, key: str = "return_value"):
return self.client.get_xcom_value(dag_id=_DAG_ID, task_id=task_id, run_id=self.run_id, key=key).get(
"value"
)


@pytest.fixture(scope="module")
def completed_run() -> _CompletedRun:
"""Trigger ``concurrent_xcom_dag`` once and wait for it to finish."""
client = AirflowClient()
resp = client.trigger_dag(_DAG_ID, json={"logical_date": datetime.now(timezone.utc).isoformat()})
run_id = resp["dag_run_id"]
state = client.wait_for_dag_run(dag_id=_DAG_ID, run_id=run_id, timeout=_GO_TASK_TIMEOUT)
ti_resp = client.get_task_instances(dag_id=_DAG_ID, run_id=run_id)
ti_states = {ti["task_id"]: ti.get("state") for ti in ti_resp.get("task_instances", [])}
return _CompletedRun(client=client, run_id=run_id, state=state, ti_states=ti_states)


def test_task_succeeded(completed_run: _CompletedRun):
"""Run and task succeed -- the task errors on any goroutine mismatch, so this
proves the injected client was used safely."""
assert completed_run.state == "success", (
f"expected the run to succeed; got {completed_run.state!r}. task states: {completed_run.ti_states}"
)
assert completed_run.ti_states.get(_TASK_ID) == "success", completed_run.ti_states


def test_concurrent_faster_than_sequential(completed_run: _CompletedRun):
"""Concurrent pull-and-process beats the sequential loop."""
value = completed_run.xcom(_TASK_ID)
assert isinstance(value, dict), (
f"Expected the task's XCom to be a mapping, got {value!r} ({type(value).__name__})"
)

sequential = value.get("sequential_ms")
concurrent = value.get("concurrent_ms")
assert isinstance(sequential, int), f"bad sequential_ms: {sequential!r}"
assert isinstance(concurrent, int), f"bad concurrent_ms: {concurrent!r}"
assert sequential > 0, f"bad sequential_ms: {sequential!r}"
assert concurrent > 0, f"bad concurrent_ms: {concurrent!r}"
assert concurrent < sequential, (
f"expected concurrent ({concurrent} ms) to beat sequential ({sequential} ms)"
)


def test_num_xcoms(completed_run: _CompletedRun):
"""The task reports the batch size it pulled."""
assert completed_run.xcom(_TASK_ID).get("num_xcoms") == 10
3 changes: 3 additions & 0 deletions go-sdk/cmd/airflow-go-pack/pack_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ sdk:
supervisor_schema_version: "2026-06-16"
source: "main.go"
dags:
concurrent_xcom_dag:
tasks:
- "pull_xcoms_concurrently"
simple_dag:
tasks:
- "extract"
Expand Down
22 changes: 19 additions & 3 deletions go-sdk/dags/go_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""
Python stub Dag mirroring the Go SDK example bundle (``go-sdk/example/bundle``).
Python stub Dags mirroring the Go SDK example bundle (``go-sdk/example/bundle``).

The graph sandwiches the Go tasks between two native Python tasks so the run
exercises XCom across the language boundary, the same way
Two Dags, both backed by the same Go bundle: ``simple_dag`` (extract/transform/
load, below) and ``concurrent_xcom_dag`` (one ``pull_xcoms_concurrently`` task
timing sequential vs goroutine XCom pulls).

``simple_dag`` sandwiches the Go tasks between two native Python tasks so the
run exercises XCom across the language boundary, the same way
``java-sdk/dags/java_examples.py`` does for the Java SDK::

python_task_1 >> extract >> transform >> [load, python_task_2]
Expand Down Expand Up @@ -91,3 +95,15 @@ def simple_dag():


simple_dag()


@task.stub(queue="golang")
def pull_xcoms_concurrently(): ...


@dag(dag_id="concurrent_xcom_dag")
def concurrent_xcom_dag():
pull_xcoms_concurrently()


concurrent_xcom_dag()
120 changes: 120 additions & 0 deletions go-sdk/example/bundle/concurrentxcom/concurrentxcom.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Package concurrentxcom holds the pull_xcoms_concurrently task in its own
// package, so main.go can register tasks defined across packages with one
// RegisterDags.
package concurrentxcom

import (
"errors"
"fmt"
"log/slog"
"reflect"
"sync"
"time"

"github.com/apache/airflow/go-sdk/pkg/api"
"github.com/apache/airflow/go-sdk/sdk"
)

const (
numXComs = 10
// perItemWork is the per-item work the goroutines overlap.
perItemWork = 150 * time.Millisecond
)

// PullXComsConcurrently pulls a batch of XComs sequentially then concurrently
// (one goroutine per item), exercising concurrent reads of the injected
// sdk.Client, and returns both timings.
func PullXComsConcurrently(ctx sdk.TIRunContext, client sdk.Client, log *slog.Logger) (any, error) {
ti := ctx.TaskInstance()
// PushXCom needs only the ids off the TaskInstance, not the UUID.
apiTI := api.TaskInstance{
DagId: ti.DagID,
RunId: ti.RunID,
TaskId: ti.TaskID,
MapIndex: ti.MapIndex,
}

keys := make([]string, numXComs)
for i := range keys {
keys[i] = fmt.Sprintf("item_%d", i)
if err := client.PushXCom(ctx, apiTI, keys[i], i); err != nil {
return nil, fmt.Errorf("seeding xcom %s: %w", keys[i], err)
}
}

pull := func(key string) (any, error) {
v, err := client.GetXCom(ctx, ti.DagID, ti.RunID, ti.TaskID, nil, key, nil)
if err != nil {
return nil, err
}
time.Sleep(perItemWork)
return v, nil
}

seqResults := make([]any, numXComs)
seqStart := time.Now()
for i, key := range keys {
v, err := pull(key)
if err != nil {
return nil, fmt.Errorf("sequential pull %s: %w", key, err)
}
seqResults[i] = v
}
sequential := time.Since(seqStart)

concResults := make([]any, numXComs)
errs := make([]error, numXComs)
concStart := time.Now()
var wg sync.WaitGroup
for i, key := range keys {
wg.Add(1)
go func() {
defer wg.Done()
concResults[i], errs[i] = pull(key)
}()
}
Comment on lines +86 to +92

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for i, key := range keys {
wg.Add(1)
go func() {
defer wg.Done()
concResults[i], errs[i] = pull(key)
}()
}
for i, key := range keys {
wg.Add(1)
go func(i int, key string) {
defer wg.Done()
concResults[i], errs[i] = pull(key)
}(i, key)
}

suggested by Claude to be explicit. I don’t know what Go people generally prefer (according to Claude, the syntax without arguments only works after 1.22) but passing explicit values to a lambda call is generally better in many languages.

wg.Wait()
concurrent := time.Since(concStart)
if err := errors.Join(errs...); err != nil {
return nil, fmt.Errorf("concurrent pulls failed: %w", err)
}

for i := range concResults {
if !reflect.DeepEqual(concResults[i], seqResults[i]) {
return nil, fmt.Errorf(
"concurrent result %d = %v, want %v",
i,
concResults[i],
seqResults[i],
)
}
}

log.InfoContext(ctx, "pulled xcoms concurrently",
"num_xcoms", numXComs,
"sequential_ms", sequential.Milliseconds(),
"concurrent_ms", concurrent.Milliseconds(),
)
return map[string]any{
"num_xcoms": numXComs,
"sequential_ms": sequential.Milliseconds(),
"concurrent_ms": concurrent.Milliseconds(),
}, nil
}
109 changes: 109 additions & 0 deletions go-sdk/example/bundle/concurrentxcom/concurrentxcom_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package concurrentxcom

import (
"context"
"log/slog"
"sync"
"testing"

"github.com/stretchr/testify/assert"

"github.com/apache/airflow/go-sdk/pkg/api"
"github.com/apache/airflow/go-sdk/sdk"
)

// mockXComClient is a mutex-guarded in-memory sdk.Client so the concurrent
// GetXCom goroutines are race-free under `go test -race`.
type mockXComClient struct {
mu sync.RWMutex
values map[string]any
}

func newMockXComClient() *mockXComClient {
return &mockXComClient{values: make(map[string]any)}
}

func (m *mockXComClient) PushXCom(
ctx context.Context,
ti api.TaskInstance,
key string,
value any,
) error {
m.mu.Lock()
defer m.mu.Unlock()
m.values[key] = value
return nil
}

func (m *mockXComClient) GetXCom(
ctx context.Context,
dagId, runId, taskId string,
mapIndex *int,
key string,
value any,
) (any, error) {
m.mu.RLock()
defer m.mu.RUnlock()
v, ok := m.values[key]
if !ok {
return nil, sdk.XComNotFound
}
return v, nil
}

func (m *mockXComClient) GetVariable(ctx context.Context, key string) (string, error) {
panic("unimplemented")
}

func (m *mockXComClient) UnmarshalJSONVariable(ctx context.Context, key string, pointer any) error {
panic("unimplemented")
}

func (m *mockXComClient) GetConnection(ctx context.Context, connID string) (sdk.Connection, error) {
panic("unimplemented")
}

var _ sdk.Client = (*mockXComClient)(nil)

func Test_PullXComsConcurrently(t *testing.T) {
ctx := sdk.NewTIRunContext(
context.Background(),
sdk.TaskInstance{
DagID: "concurrent_xcom_dag",
RunID: "run",
TaskID: "pull_xcoms_concurrently",
},
sdk.DagRun{},
)

result, err := PullXComsConcurrently(ctx, newMockXComClient(), slog.Default())
assert.NoError(t, err)

m, ok := result.(map[string]any)
assert.True(t, ok)
assert.Equal(t, numXComs, m["num_xcoms"])

sequential := m["sequential_ms"].(int64)
concurrent := m["concurrent_ms"].(int64)
assert.Greater(t, sequential, int64(0))
assert.Greater(t, concurrent, int64(0))
// The per-item work overlaps across goroutines, so concurrent beats sequential.
assert.Less(t, concurrent, sequential)
}
5 changes: 5 additions & 0 deletions go-sdk/example/bundle/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

v1 "github.com/apache/airflow/go-sdk/bundle/bundlev1"
"github.com/apache/airflow/go-sdk/bundle/bundlev1/bundlev1server"
"github.com/apache/airflow/go-sdk/example/bundle/concurrentxcom"
"github.com/apache/airflow/go-sdk/sdk"
)

Expand All @@ -50,6 +51,10 @@ func (m *myBundle) RegisterDags(dagbag v1.Registry) error {
simpleDag.AddTask(transform)
simpleDag.AddTask(load)

// Tasks defined in other packages register through the same dagbag.
concurrentDag := dagbag.AddDag("concurrent_xcom_dag")
concurrentDag.AddTaskWithName("pull_xcoms_concurrently", concurrentxcom.PullXComsConcurrently)

return nil
}

Expand Down