Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions go/adk/pkg/a2a/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,33 @@ func (u *userIDInterceptor) Before(ctx context.Context, callCtx *a2asrv.CallCont
return ctx, nil
}

// newAgentMessage returns an agent-role message holding parts, with ContextID
// and TaskID copied from reqCtx. A2A permits leaving both ids off a message
// (the enclosing task is the canonical carrier), but setting them means a
// consumer that flattens task.history into standalone messages can tie each
// message back to its task without backfilling. Mirrors the Python kagent-adk
// event converter.
func newAgentMessage(reqCtx *a2asrv.RequestContext, parts ...a2atype.Part) *a2atype.Message {
	agentMsg := a2atype.NewMessage(a2atype.MessageRoleAgent, parts...)
	agentMsg.ContextID, agentMsg.TaskID = reqCtx.ContextID, reqCtx.TaskID
	return agentMsg
}

// newAgentStatusEvent wraps parts in an agent message (built by
// newAgentMessage, so it carries the request's context/task ids) and returns a
// TaskStateWorking status update event for reqCtx with that message attached.
//
// meta is assigned to both the message and the event without copying, so the
// two Metadata fields alias the caller's map; pass a clone if the caller goes
// on mutating its own map afterwards.
//
// This is the per-event seam where a streamed agent message becomes an emitted
// (and persisted) event, so the id stamping done here is what the send guard
// relies on when it later flattens task.history.
func newAgentStatusEvent(reqCtx *a2asrv.RequestContext, parts a2atype.ContentParts, meta map[string]any) *a2atype.TaskStatusUpdateEvent {
	agentMsg := newAgentMessage(reqCtx, parts...)
	agentMsg.Metadata = meta

	working := a2atype.NewStatusUpdateEvent(reqCtx, a2atype.TaskStateWorking, agentMsg)
	working.Metadata = meta
	return working
}

// Execute implements a2asrv.AgentExecutor.
// It follows the Python _handle_request pattern: set up session, handle HITL,
// convert inbound message, run the agent loop, and emit A2A events.
Expand Down Expand Up @@ -266,7 +293,7 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont
// Events with no content carry metadata only; still track invocationID/usage.
// Check for LLM error.
if adkEvent.ErrorCode != "" {
errMsg := a2atype.NewMessage(a2atype.MessageRoleAgent,
errMsg := newAgentMessage(reqCtx,
a2atype.TextPart{Text: fmt.Sprintf("LLM error: %s %s", adkEvent.ErrorCode, adkEvent.ErrorMessage)})
failed := a2atype.NewStatusUpdateEvent(reqCtx, a2atype.TaskStateFailed, errMsg)
failed.Final = true
Expand All @@ -278,7 +305,7 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont

// Check for LLM error (even with content present).
if adkEvent.ErrorCode != "" {
errMsg := a2atype.NewMessage(a2atype.MessageRoleAgent,
errMsg := newAgentMessage(reqCtx,
a2atype.TextPart{Text: fmt.Sprintf("LLM error: %s %s", adkEvent.ErrorCode, adkEvent.ErrorMessage)})
failed := a2atype.NewStatusUpdateEvent(reqCtx, a2atype.TaskStateFailed, errMsg)
failed.Final = true
Expand Down Expand Up @@ -326,10 +353,7 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont
if len(textOnly) > 0 {
mirrorMeta := maps.Clone(eventMeta)
mirrorMeta[adka2a.ToA2AMetaKey("partial")] = true
msg := a2atype.NewMessage(a2atype.MessageRoleAgent, textOnly...)
msg.Metadata = mirrorMeta
statusEv := a2atype.NewStatusUpdateEvent(reqCtx, a2atype.TaskStateWorking, msg)
statusEv.Metadata = mirrorMeta
statusEv := newAgentStatusEvent(reqCtx, textOnly, mirrorMeta)
if err := queue.Write(ctx, statusEv); err != nil {
return fmt.Errorf("failed to write partial status event: %w", err)
}
Expand All @@ -338,10 +362,7 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont
mirrorParts := a2aParts
if len(hitlParts) == 0 {
// Only mirror when not accumulating HITL parts (those go into input_required).
msg := a2atype.NewMessage(a2atype.MessageRoleAgent, mirrorParts...)
msg.Metadata = maps.Clone(eventMeta)
statusEv := a2atype.NewStatusUpdateEvent(reqCtx, a2atype.TaskStateWorking, msg)
statusEv.Metadata = maps.Clone(eventMeta)
statusEv := newAgentStatusEvent(reqCtx, mirrorParts, maps.Clone(eventMeta))
if err := queue.Write(ctx, statusEv); err != nil {
return fmt.Errorf("failed to write mirror status event: %w", err)
}
Expand All @@ -362,7 +383,7 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont
}

if runErr != nil {
errMsg := a2atype.NewMessage(a2atype.MessageRoleAgent, a2atype.TextPart{Text: runErr.Error()})
errMsg := newAgentMessage(reqCtx, a2atype.TextPart{Text: runErr.Error()})
failed := a2atype.NewStatusUpdateEvent(reqCtx, a2atype.TaskStateFailed, errMsg)
failed.Final = true
failed.Metadata = finalMeta
Expand All @@ -371,7 +392,7 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont

if len(hitlParts) > 0 {
// input_required: the agent is waiting for HITL decisions.
hitlMsg := a2atype.NewMessage(a2atype.MessageRoleAgent, hitlParts...)
hitlMsg := newAgentMessage(reqCtx, hitlParts...)
inputRequired := a2atype.NewStatusUpdateEvent(reqCtx, a2atype.TaskStateInputRequired, hitlMsg)
inputRequired.Final = true
inputRequired.Metadata = finalMeta
Expand Down
67 changes: 67 additions & 0 deletions go/adk/pkg/a2a/executor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package a2a

import (
"testing"

a2atype "github.com/a2aproject/a2a-go/a2a"
"github.com/a2aproject/a2a-go/a2asrv"
)

// TestNewAgentMessage_StampsContextAndTaskID verifies agent messages carry the
// request's context and task ids. A2A allows omitting them (the task is the
// canonical carrier), but stamping them lets consumers that flatten task.history
// key each message to its task without backfilling.
func TestNewAgentMessage_StampsContextAndTaskID(t *testing.T) {
reqCtx := &a2asrv.RequestContext{
ContextID: "ctx-xyz",
TaskID: a2atype.TaskID("task-xyz"),
}

msg := newAgentMessage(reqCtx, a2atype.TextPart{Text: "hello"})

if msg.ContextID != "ctx-xyz" {
t.Errorf("ContextID = %q, want %q", msg.ContextID, "ctx-xyz")
}
if msg.TaskID != a2atype.TaskID("task-xyz") {
t.Errorf("TaskID = %q, want %q", msg.TaskID, a2atype.TaskID("task-xyz"))
}
if msg.Role != a2atype.MessageRoleAgent {
t.Errorf("Role = %q, want %q", msg.Role, a2atype.MessageRoleAgent)
}
}

// TestNewAgentStatusEvent_MessageCarriesIDs verifies the per-event emission seam:
// the working status event the executor writes (and that is persisted into
// task.history) carries an agent message stamped with the request's context/task
// ids. This is the property the send guard depends on — without it the persisted
// message keys differently from its locally-streamed counterpart and falsely
// blocks the next send. Mirrors the Python converter test.
//
// It also checks that the metadata handed to newAgentStatusEvent is visible on
// both the message and the event, since the helper assigns it to both.
func TestNewAgentStatusEvent_MessageCarriesIDs(t *testing.T) {
	reqCtx := &a2asrv.RequestContext{
		ContextID: "ctx-xyz",
		TaskID:    a2atype.TaskID("task-xyz"),
	}
	meta := map[string]any{"k": "v"}

	ev := newAgentStatusEvent(reqCtx, a2atype.ContentParts{a2atype.TextPart{Text: "hi"}}, meta)

	if ev.Status.State != a2atype.TaskStateWorking {
		t.Errorf("State = %q, want %q", ev.Status.State, a2atype.TaskStateWorking)
	}
	// Fatal rather than Error: every message assertion below dereferences it.
	if ev.Status.Message == nil {
		t.Fatal("status message is nil")
	}
	if ev.Status.Message.ContextID != "ctx-xyz" {
		t.Errorf("message ContextID = %q, want %q", ev.Status.Message.ContextID, "ctx-xyz")
	}
	if ev.Status.Message.TaskID != a2atype.TaskID("task-xyz") {
		t.Errorf("message TaskID = %q, want %q", ev.Status.Message.TaskID, a2atype.TaskID("task-xyz"))
	}
	// The event itself also carries the ids (from reqCtx), matching the message.
	if ev.ContextID != "ctx-xyz" {
		t.Errorf("event ContextID = %q, want %q", ev.ContextID, "ctx-xyz")
	}
	if ev.TaskID != a2atype.TaskID("task-xyz") {
		t.Errorf("event TaskID = %q, want %q", ev.TaskID, a2atype.TaskID("task-xyz"))
	}
	// The metadata passed in must reach both the message and the event;
	// previously meta was built and passed but never asserted.
	if got := ev.Status.Message.Metadata["k"]; got != "v" {
		t.Errorf("message Metadata[%q] = %v, want %q", "k", got, "v")
	}
	if got := ev.Metadata["k"]; got != "v" {
		t.Errorf("event Metadata[%q] = %v, want %q", "k", got, "v")
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def convert_event_to_a2a_message(
invocation_context: InvocationContext,
role: Role = Role.agent,
subagent_session_ids: Optional[Dict[str, str]] = None,
task_id: Optional[str] = None,
context_id: Optional[str] = None,
) -> Optional[Message]:
"""Converts an ADK event to an A2A message.

Expand All @@ -171,6 +173,12 @@ def convert_event_to_a2a_message(
subagent_session_ids: Optional mapping of tool name to pre-generated
subagent session ID. When provided, function_call DataParts for
matching tools will have the session ID stamped into their metadata.
task_id: Optional task ID stamped onto the message so it carries the
same identity as the enclosing task. A2A allows these to be omitted
(the task is the canonical carrier), but stamping them lets consumers
that flatten task.history into standalone messages key each message to
its task without backfilling.
context_id: Optional context ID stamped onto the message, as task_id.

Returns:
An A2A Message if the event has content, None otherwise.
Expand Down Expand Up @@ -198,7 +206,14 @@ def convert_event_to_a2a_message(

if a2a_parts:
message_metadata = _get_context_metadata(event, invocation_context)
return Message(message_id=str(uuid.uuid4()), role=role, parts=a2a_parts, metadata=message_metadata)
return Message(
message_id=str(uuid.uuid4()),
role=role,
parts=a2a_parts,
metadata=message_metadata,
task_id=task_id,
context_id=context_id,
)

except Exception as e:
logger.error("Failed to convert event to status message: %s", e)
Expand Down Expand Up @@ -342,7 +357,13 @@ def convert_event_to_a2a_events(
a2a_events.append(error_event)

# Handle regular message content
message = convert_event_to_a2a_message(event, invocation_context, subagent_session_ids=subagent_session_ids)
message = convert_event_to_a2a_message(
event,
invocation_context,
subagent_session_ids=subagent_session_ids,
task_id=task_id,
context_id=context_id,
)
if message:
running_event = _create_status_update_event(message, invocation_context, event, task_id, context_id)
a2a_events.append(running_event)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,21 @@ def test_convert_event_to_a2a_events(self):
error_code_key = get_kagent_metadata_key("error_code")
assert error_code_key in error_event.metadata
assert error_event.metadata[error_code_key] == str(genai_types.FinishReason.MALFORMED_FUNCTION_CALL)

def test_message_carries_task_and_context_ids(self):
    """A converted message is stamped with the task_id/context_id passed to
    the converter, so consumers that flatten task.history can key it to its
    task without backfilling."""
    ctx = _create_mock_invocation_context()
    text_part = genai_types.Part(text="hello world")
    adk_event = _create_mock_event(
        content=genai_types.Content(parts=[text_part]),
        invocation_id="test_invocation_ids",
    )

    a2a_events = convert_event_to_a2a_events(adk_event, ctx, task_id="task-xyz", context_id="ctx-xyz")

    # Keep only the status updates in the working state; exactly one is expected.
    working = []
    for candidate in a2a_events:
        if isinstance(candidate, TaskStatusUpdateEvent) and candidate.status.state == TaskState.working:
            working.append(candidate)
    assert len(working) == 1

    stamped = working[0].status.message
    assert stamped is not None
    assert stamped.task_id == "task-xyz"
    assert stamped.context_id == "ctx-xyz"
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,79 @@ describe("ChatInterface send guard", () => {
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
});

it("does not block the next send after a same-tab tool-call turn is persisted", async () => {
// Regression test for a send-guard false positive seen with tool calls. A
// same-tab turn that makes a tool call streams a ToolCallRequestEvent locally
// (keyed by its real contextId/taskId), but the backend persists agent
// messages with EMPTY contextId/taskId. When extractMessagesFromTasks read
// those ids with `??`, the empty strings survived (`"" ?? task.contextId` is
// still "") and the converted tool message, rebuilt with a fresh uuidv4
// messageId, keyed on ["message", <random>] — a different key on every
// extraction. The guard could then never match the persisted tool message
// against the local one, counted the backend as ahead, and falsely blocked
// the next send. This test pins the corrected outcome: the second send goes
// through and no stale-session toast is shown.
mockBackendTasks([completedTask("task-initial", initialTurn)]);
// Two streams are queued, one per send issued below, in order.
mockSendMessageStream
.mockResolvedValueOnce(streamOf(completedToolCallStatusEvent("session-1", "task-streamed", "shared-call")))
.mockResolvedValueOnce(streamOf(completedStatusEvent("next answer", "session-1", "task-next")));

renderExistingSession();

// Wait for the initial history to render before sending.
expect(await screen.findByText("initial answer")).toBeInTheDocument();

await sendText("tool question");
await waitFor(() => expect(mockSendMessageStream).toHaveBeenCalledTimes(1));

// Backend now reflects exactly that same-tab tool turn. The persisted agent
// tool message carries empty contextId/taskId, as the real backend stores it.
mockBackendTasks([
completedTask("task-initial", initialTurn),
completedTask("task-streamed", [
textMessage(sentMessage().messageId, "user", "tool question", "session-1", "task-streamed"),
toolCallMessage("backend-tool", "", "", "shared-call"),
]),
]);
await sendText("next question");

// The guard must let the second send reach the stream without raising the
// stale-session toast.
await waitFor(() => expect(mockSendMessageStream).toHaveBeenCalledTimes(2));
expect(mockToastInfo).not.toHaveBeenCalledWith(staleToastMessage);
});

it("does not block the next send after a same-tab text turn persisted with empty agent ids", async () => {
// Regression test for the general case behind "every message needs sending
// twice": the backend persists AGENT messages with empty contextId/taskId
// (A2A optional fields; the task is the canonical carrier). The
// locally-streamed agent message carries the task's real ids, so it keys on
// ["task", ...]. When the backend-extracted copy was pushed as-is with its
// empty ids, it keyed on ["message", <persisted id>] instead; the two never
// matched, so every turn (each has an agent text response) counted the
// backend as ahead and was falsely blocked. This test pins the corrected
// outcome: the second send goes through and no stale-session toast is shown.
mockBackendTasks([completedTask("task-initial", initialTurn)]);
// Two streams are queued, one per send issued below, in order.
mockSendMessageStream
.mockResolvedValueOnce(streamOf(completedStatusEvent("same tab answer", "session-1", "task-streamed")))
.mockResolvedValueOnce(streamOf(completedStatusEvent("next answer", "session-1", "task-next")));

renderExistingSession();

// Wait for the initial history to render before sending.
expect(await screen.findByText("initial answer")).toBeInTheDocument();

await sendText("same tab question");
await waitFor(() => expect(mockSendMessageStream).toHaveBeenCalledTimes(1));
// The streamed agent reply must be on screen, i.e. present in local state,
// before the backend snapshot is swapped in.
expect(await screen.findByText("same tab answer")).toBeInTheDocument();

// Backend reflects the turn, with the agent message persisted with empty
// contextId/taskId (the user message keeps its contextId, as the client stamps it).
mockBackendTasks([
completedTask("task-initial", initialTurn),
completedTask("task-streamed", [
textMessage(sentMessage().messageId, "user", "same tab question", "session-1", ""),
textMessage("same-tab-agent", "agent", "same tab answer", "", ""),
]),
]);
await sendText("next question");

// The guard must let the second send reach the stream without raising the
// stale-session toast.
await waitFor(() => expect(mockSendMessageStream).toHaveBeenCalledTimes(2));
expect(mockToastInfo).not.toHaveBeenCalledWith(staleToastMessage);
});

it("still blocks when the backend has a cross-tab message not visible locally", async () => {
mockBackendTasks([completedTask("task-initial", initialTurn)]);

Expand Down
28 changes: 21 additions & 7 deletions ui/src/lib/messageHandlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,18 @@ export function extractMessagesFromTasks(tasks: Task[]): Message[] {
// Agent messages: convert function_call / function_response DataParts to
// the same ToolCallRequestEvent / ToolCallExecutionEvent format produced
// by the live-stream handlers so the rendering component can display them.
const msgContextId = historyItem.contextId ?? task.contextId;
const msgTaskId = historyItem.taskId ?? task.id;
//
// Backfill contextId/taskId from the task when the history item omits them.
// Persisted agent messages (both tool and text) frequently carry empty
// strings here, and `??` would keep the empty string ("" is not nullish).
// The locally-streamed copies of these messages carry the task's real ids,
// so the send guard keys them on (contextId, taskId); without the backfill
// the backend-extracted copies fall back to messageId and never match the
// local ones, counting the backend as ahead and falsely blocking the next
// send on every turn. Treat "" as absent so both copies get the task's
// stable ids. Applied to every extracted agent message below.
const msgContextId = historyItem.contextId || task.contextId;
const msgTaskId = historyItem.taskId || task.id;
const source = getSourceFromMetadata(historyItem.metadata as ADKMetadata | undefined, "assistant");
const msgStats = getMessageTokenStats(historyItem.metadata as Record<string, unknown>);

Expand Down Expand Up @@ -149,12 +159,16 @@ export function extractMessagesFromTasks(tasks: Task[]): Message[] {
}
}

// Text messages (or any message without data parts): push with tokenStats.
// Text messages (or any message without data parts): push with tokenStats
// and the backfilled contextId/taskId so they key the same way the
// locally-streamed copy does.
if (!hasConvertedParts) {
messages.push(msgStats
? { ...historyItem, metadata: { ...(historyItem.metadata as object || {}), tokenStats: msgStats } }
: historyItem
);
messages.push({
...historyItem,
contextId: msgContextId,
taskId: msgTaskId,
...(msgStats ? { metadata: { ...(historyItem.metadata as object || {}), tokenStats: msgStats } } : {}),
});
}
}
}
Expand Down
Loading