Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,15 @@ async def get_state(self, task_id: str) -> WorkflowState:
workflow_id=task_id,
)

async def send_event(self, agent: Agent, task: Task, event: Event) -> None:
async def send_event(self, agent: Agent, task: Task, event: Event, request: dict | None = None) -> None:
return await self._temporal_client.send_signal(
workflow_id=task.id,
signal=SignalName.RECEIVE_EVENT.value,
payload=SendEventParams(
agent=agent,
task=task,
event=event,
request=request,
).model_dump(),
)

Expand Down
1 change: 1 addition & 0 deletions src/agentex/lib/sdk/fastacp/impl/temporal_acp.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ async def handle_event_send(params: SendEventParams) -> None:
agent=params.agent,
task=params.task,
event=params.event,
request=params.request,
)

except Exception as e:
Expand Down
229 changes: 228 additions & 1 deletion tests/test_header_forwarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Any, override
import sys
import types
from datetime import datetime, timezone
from unittest.mock import AsyncMock, Mock

import pytest
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -43,8 +45,14 @@ class _StubTracer(_StubAsyncTracer):

from agentex.lib.core.services.adk.acp.acp import ACPService
from agentex.lib.sdk.fastacp.base.base_acp_server import BaseACPServer
from agentex.lib.types.acp import RPCMethod, SendMessageParams
from agentex.lib.types.acp import RPCMethod, SendMessageParams, SendEventParams
from agentex.types.task_message_content import TextContent
from agentex.lib.sdk.fastacp.impl.temporal_acp import TemporalACP
from agentex.lib.core.temporal.services.temporal_task_service import TemporalTaskService
from agentex.lib.environment_variables import EnvironmentVariables
from agentex.types.agent import Agent
from agentex.types.task import Task
from agentex.types.event import Event


class DummySpan:
Expand Down Expand Up @@ -312,3 +320,222 @@ def test_filter_headers_all_types() -> None:
assert result == expected



# ============================================================================
# Temporal Header Forwarding Tests
# ============================================================================

@pytest.fixture
def mock_temporal_client():
"""Create a mock TemporalClient"""
client = AsyncMock()
client.send_signal = AsyncMock(return_value=None)
return client


@pytest.fixture
def mock_env_vars():
"""Create mock environment variables"""
env_vars = Mock(spec=EnvironmentVariables)
env_vars.WORKFLOW_NAME = "test-workflow"
env_vars.WORKFLOW_TASK_QUEUE = "test-queue"
return env_vars


@pytest.fixture
def temporal_task_service(mock_temporal_client, mock_env_vars):
"""Create TemporalTaskService with mocked client"""
return TemporalTaskService(
temporal_client=mock_temporal_client,
env_vars=mock_env_vars,
)


@pytest.fixture
def sample_agent():
"""Create a sample agent"""
return Agent(
id="agent-123",
name="test-agent",
description="Test agent",
acp_type="agentic",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)


@pytest.fixture
def sample_task():
"""Create a sample task"""
return Task(id="task-456")


@pytest.fixture
def sample_event():
"""Create a sample event"""
return Event(
id="event-789",
agent_id="agent-123",
task_id="task-456",
sequence_id=1,
content=TextContent(author="user", content="Test message")
)


@pytest.mark.asyncio
async def test_temporal_task_service_send_event_with_headers(
temporal_task_service,
mock_temporal_client,
sample_agent,
sample_task,
sample_event
):
"""Test that TemporalTaskService forwards request headers in signal payload"""
# Given
request_headers = {
"x-user-oauth-credentials": "test-oauth-token",
"x-custom-header": "custom-value"
}
request = {"headers": request_headers}

# When
await temporal_task_service.send_event(
agent=sample_agent,
task=sample_task,
event=sample_event,
request=request
)

# Then
mock_temporal_client.send_signal.assert_called_once()
call_args = mock_temporal_client.send_signal.call_args

# Verify the signal was sent to the correct workflow
assert call_args.kwargs["workflow_id"] == sample_task.id
assert call_args.kwargs["signal"] == "receive_event"

# Verify the payload includes the request with headers
payload = call_args.kwargs["payload"]
assert "request" in payload
assert payload["request"] == request
assert payload["request"]["headers"] == request_headers


@pytest.mark.asyncio
async def test_temporal_task_service_send_event_without_headers(
temporal_task_service,
mock_temporal_client,
sample_agent,
sample_task,
sample_event
):
"""Test that TemporalTaskService handles missing request gracefully"""
# When - Send event without request parameter
await temporal_task_service.send_event(
agent=sample_agent,
task=sample_task,
event=sample_event,
request=None
)

# Then
mock_temporal_client.send_signal.assert_called_once()
call_args = mock_temporal_client.send_signal.call_args

# Verify the payload has request as None
payload = call_args.kwargs["payload"]
assert payload["request"] is None


@pytest.mark.asyncio
async def test_temporal_acp_integration_with_request_headers(
mock_temporal_client,
mock_env_vars,
sample_agent,
sample_task,
sample_event
):
"""Test end-to-end integration: TemporalACP -> TemporalTaskService -> TemporalClient signal"""
# Given - Create real TemporalTaskService with mocked client
task_service = TemporalTaskService(
temporal_client=mock_temporal_client,
env_vars=mock_env_vars,
)

# Create TemporalACP with real task service
temporal_acp = TemporalACP(
temporal_address="localhost:7233",
temporal_task_service=task_service,
)
temporal_acp._setup_handlers()

request_headers = {
"x-user-id": "user-123",
"authorization": "Bearer token",
"x-tenant-id": "tenant-456"
}
request = {"headers": request_headers}

# Create SendEventParams as TemporalACP would receive it
params = SendEventParams(
agent=sample_agent,
task=sample_task,
event=sample_event,
request=request
)

# When - Trigger the event handler via the decorated function
# The handler is registered via @temporal_acp.on_task_event_send
# We'll directly call the task service method as the handler does
await task_service.send_event(
agent=params.agent,
task=params.task,
event=params.event,
request=params.request
)

# Then - Verify the temporal client received the signal with request headers
mock_temporal_client.send_signal.assert_called_once()
call_args = mock_temporal_client.send_signal.call_args

# Verify signal payload includes request with headers
payload = call_args.kwargs["payload"]
assert payload["request"] == request
assert payload["request"]["headers"] == request_headers


@pytest.mark.asyncio
async def test_temporal_task_service_preserves_all_header_types(
temporal_task_service,
mock_temporal_client,
sample_agent,
sample_task,
sample_event
):
"""Test that various header types are preserved correctly"""
# Given - Headers with different patterns
request_headers = {
"x-user-oauth-credentials": "oauth-token-12345",
"authorization": "Bearer jwt-token",
"x-tenant-id": "tenant-999",
"x-request-id": "req-abc-123",
"x-custom-app-header": "custom-value"
}
request = {"headers": request_headers}

# When
await temporal_task_service.send_event(
agent=sample_agent,
task=sample_task,
event=sample_event,
request=request
)

# Then - Verify all headers are preserved in the signal payload
call_args = mock_temporal_client.send_signal.call_args
payload = call_args.kwargs["payload"]

assert payload["request"]["headers"] == request_headers
# Verify each header individually
for header_name, header_value in request_headers.items():
assert payload["request"]["headers"][header_name] == header_value