From f8205232f1139aba48c5112fbcb6e8d810807bea Mon Sep 17 00:00:00 2001 From: Jiaping Zeng Date: Tue, 18 Nov 2025 04:27:19 -0800 Subject: [PATCH 1/6] AG-UI events Signed-off-by: Jiaping Zeng --- .../ml/common/agui/AGUIConstants.java | 112 ++++++++++++++++++ .../opensearch/ml/common/agui/BaseEvent.java | 76 ++++++++++++ .../ml/common/agui/MessagesSnapshotEvent.java | 48 ++++++++ .../ml/common/agui/RunErrorEvent.java | 54 +++++++++ .../ml/common/agui/RunFinishedEvent.java | 66 +++++++++++ .../ml/common/agui/RunStartedEvent.java | 52 ++++++++ .../common/agui/TextMessageContentEvent.java | 52 ++++++++ .../ml/common/agui/TextMessageEndEvent.java | 47 ++++++++ .../ml/common/agui/TextMessageStartEvent.java | 52 ++++++++ .../ml/common/agui/ToolCallArgsEvent.java | 52 ++++++++ .../ml/common/agui/ToolCallEndEvent.java | 47 ++++++++ .../ml/common/agui/ToolCallResultEvent.java | 64 ++++++++++ .../ml/common/agui/ToolCallStartEvent.java | 59 +++++++++ 13 files changed, 781 insertions(+) create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/AGUIConstants.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/BaseEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/MessagesSnapshotEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/RunErrorEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/RunFinishedEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/RunStartedEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/TextMessageContentEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/TextMessageEndEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/TextMessageStartEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/ToolCallArgsEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/ToolCallEndEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/ToolCallResultEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/ToolCallStartEvent.java diff --git a/common/src/main/java/org/opensearch/ml/common/agui/AGUIConstants.java b/common/src/main/java/org/opensearch/ml/common/agui/AGUIConstants.java new file mode 100644 index 0000000000..67a5700ed9 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/AGUIConstants.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +/** + * Constants for AG-UI implementation. + * + * Naming Conventions: + * AGUI_ROLE_* - Message role identifiers + * AGUI_PARAM_* - Internal parameter keys + * AGUI_FIELD_* - External API field names + * AGUI_EVENT_* - Event type identifiers + * AGUI_PREFIX_* - ID prefixes for generated identifiers + */ +public final class AGUIConstants { + + // ========== Message Roles ========== + + /** Role identifier for assistant messages */ + public static final String AGUI_ROLE_ASSISTANT = "assistant"; + + /** Role identifier for user messages */ + public static final String AGUI_ROLE_USER = "user"; + + /** Role identifier for tool result messages */ + public static final String AGUI_ROLE_TOOL = "tool"; + + // ========== Parameter Keys (Internal) ========== + + /** Parameter key for AG-UI thread identifier */ + public static final String AGUI_PARAM_THREAD_ID = "agui_thread_id"; + + /** Parameter key for AG-UI run identifier */ + public static final String AGUI_PARAM_RUN_ID = "agui_run_id"; + + /** Parameter key for AG-UI messages array */ + public static final String AGUI_PARAM_MESSAGES = "agui_messages"; + + /** Parameter key for AG-UI tools array */ + public static final String AGUI_PARAM_TOOLS = "agui_tools"; + + /** Parameter key for AG-UI context array */ + public static final String AGUI_PARAM_CONTEXT = "agui_context"; + + /** Parameter key for AG-UI state object */ + public static final String AGUI_PARAM_STATE = "agui_state"; + + /** Parameter key for AG-UI forwarded properties */ + public static final String AGUI_PARAM_FORWARDED_PROPS = "agui_forwarded_props"; + + /** Parameter key for AG-UI tool call results */ + public static final String AGUI_PARAM_TOOL_CALL_RESULTS = "agui_tool_call_results"; + + /** Parameter key for AG-UI assistant tool call messages */ + public static final String AGUI_PARAM_ASSISTANT_TOOL_CALL_MESSAGES = "agui_assistant_tool_call_messages"; + + /** Parameter key for backend tool names (used for filtering) */ + public static final String AGUI_PARAM_BACKEND_TOOL_NAMES = "backend_tool_names"; + + // ========== Field Names (External API) ========== + + /** Field name for thread identifier in AG-UI input */ + public static final String AGUI_FIELD_THREAD_ID = "threadId"; + + /** Field name for run identifier in AG-UI input */ + public static final String AGUI_FIELD_RUN_ID = "runId"; + + /** Field name for messages array in AG-UI input */ + public static final String AGUI_FIELD_MESSAGES = "messages"; + + /** Field name for tools array in AG-UI input */ + public static final String AGUI_FIELD_TOOLS = "tools"; + + /** Field name for context array in AG-UI input */ + public static final String AGUI_FIELD_CONTEXT = "context"; + + /** Field name for state object in AG-UI input */ + public static final String AGUI_FIELD_STATE = "state"; + + /** Field name for forwarded properties in AG-UI input */ + public static final String AGUI_FIELD_FORWARDED_PROPS = "forwardedProps"; + + /** Field name for message role */ + public static final String AGUI_FIELD_ROLE = "role"; + + /** Field name for message content */ + public static final String AGUI_FIELD_CONTENT = "content"; + + /** Field name for tool call identifier */ + public static final String AGUI_FIELD_TOOL_CALL_ID = "toolCallId"; + + /** Field name for tool calls array */ + public static final String AGUI_FIELD_TOOL_CALLS = "toolCalls"; + + /** Field name for message identifier */ + public static final String AGUI_FIELD_ID = "id"; + + /** Field name for tool call type */ + public static final String AGUI_FIELD_TYPE = "type"; + + /** Field name for function object in tool calls */ + public static final String AGUI_FIELD_FUNCTION = "function"; + + /** Field name for function name */ + public static final String AGUI_FIELD_NAME = "name"; + + /** Field name for function arguments */ + public static final String AGUI_FIELD_ARGUMENTS = "arguments"; +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/BaseEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/BaseEvent.java new file mode 100644 index 0000000000..72c4dfec40 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/BaseEvent.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public abstract class BaseEvent implements ToXContentFragment, Writeable { + + protected String type; + protected Long timestamp; + protected Map rawEvent; + + public BaseEvent(StreamInput input) throws IOException { + this.type = input.readString(); + this.timestamp = input.readOptionalLong(); + if (input.readBoolean()) { + this.rawEvent = input.readMap(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + out.writeOptionalLong(timestamp); + if (rawEvent != null) { + out.writeBoolean(true); + out.writeMap(rawEvent); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("type", type); + if (timestamp != null) { + builder.field("timestamp", timestamp); + } + if (rawEvent != null) { + builder.field("rawEvent", rawEvent); + } + addEventSpecificFields(builder, params); + builder.endObject(); + return builder; + } + + protected abstract void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException; + + public String toJsonString() { + try { + return toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString(); + } catch (IOException e) { + throw new RuntimeException("Failed to serialize event to JSON", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/MessagesSnapshotEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/MessagesSnapshotEvent.java new file mode 100644 index 0000000000..4b601ea754 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/MessagesSnapshotEvent.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class MessagesSnapshotEvent extends BaseEvent { + + public static final String TYPE = "MESSAGES_SNAPSHOT"; + + private List messages; + + public MessagesSnapshotEvent(List messages) { + super(TYPE, System.currentTimeMillis(), null); + this.messages = messages; + } + + public MessagesSnapshotEvent(StreamInput input) throws IOException { + super(input); + this.messages = input.readList(StreamInput::readGenericValue); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeCollection(messages, StreamOutput::writeGenericValue); + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("messages", messages); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/RunErrorEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/RunErrorEvent.java new file mode 100644 index 0000000000..d03e53167e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/RunErrorEvent.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class RunErrorEvent extends BaseEvent { + + public static final String TYPE = "RUN_ERROR"; + + private String message; + private String code; + + public RunErrorEvent(String message, String code) { + super(TYPE, System.currentTimeMillis(), null); + this.message = message != null ? message : ""; + this.code = code; + } + + public RunErrorEvent(StreamInput input) throws IOException { + super(input); + this.message = input.readString(); + this.code = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(message); + out.writeOptionalString(code); + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("message", message); + if (code != null) { + builder.field("code", code); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/RunFinishedEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/RunFinishedEvent.java new file mode 100644 index 0000000000..2f320df73a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/RunFinishedEvent.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class RunFinishedEvent extends BaseEvent { + + public static final String TYPE = "RUN_FINISHED"; + + private String threadId; + private String runId; + private Object result; + + public RunFinishedEvent(String threadId, String runId, Object result) { + super(TYPE, System.currentTimeMillis(), null); + this.threadId = threadId; + this.runId = runId; + this.result = result; + } + + public RunFinishedEvent(StreamInput input) throws IOException { + super(input); + this.threadId = input.readString(); + this.runId = input.readString(); + if (input.readBoolean()) { + this.result = input.readGenericValue(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(threadId); + out.writeString(runId); + if (result != null) { + out.writeBoolean(true); + out.writeGenericValue(result); + } else { + out.writeBoolean(false); + } + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("threadId", threadId); + builder.field("runId", runId); + if (result != null) { + builder.field("result", result); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/RunStartedEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/RunStartedEvent.java new file mode 100644 index 0000000000..aafe9c8037 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/RunStartedEvent.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class RunStartedEvent extends BaseEvent { + + public static final String TYPE = "RUN_STARTED"; + + private String threadId; + private String runId; + + public RunStartedEvent(String threadId, String runId) { + super(TYPE, System.currentTimeMillis(), null); + this.threadId = threadId; + this.runId = runId; + } + + public RunStartedEvent(StreamInput input) throws IOException { + super(input); + this.threadId = input.readString(); + this.runId = input.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(threadId); + out.writeString(runId); + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("threadId", threadId); + builder.field("runId", runId); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/TextMessageContentEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/TextMessageContentEvent.java new file mode 100644 index 0000000000..2657833759 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/TextMessageContentEvent.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class TextMessageContentEvent extends BaseEvent { + + public static final String TYPE = "TEXT_MESSAGE_CONTENT"; + + private String messageId; + private String delta; + + public TextMessageContentEvent(String messageId, String delta) { + super(TYPE, System.currentTimeMillis(), null); + this.messageId = messageId; + this.delta = delta != null ? delta : ""; + } + + public TextMessageContentEvent(StreamInput input) throws IOException { + super(input); + this.messageId = input.readString(); + this.delta = input.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(messageId); + out.writeString(delta); + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("messageId", messageId); + builder.field("delta", delta); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/TextMessageEndEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/TextMessageEndEvent.java new file mode 100644 index 0000000000..770507d4e8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/TextMessageEndEvent.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class TextMessageEndEvent extends BaseEvent { + + public static final String TYPE = "TEXT_MESSAGE_END"; + + private String messageId; + + public TextMessageEndEvent(String messageId) { + super(TYPE, System.currentTimeMillis(), null); + this.messageId = messageId; + } + + public TextMessageEndEvent(StreamInput input) throws IOException { + super(input); + this.messageId = input.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(messageId); + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("messageId", messageId); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/TextMessageStartEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/TextMessageStartEvent.java new file mode 100644 index 0000000000..72edc993b5 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/TextMessageStartEvent.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class TextMessageStartEvent extends BaseEvent { + + public static final String TYPE = "TEXT_MESSAGE_START"; + + private String messageId; + private String role; + + public TextMessageStartEvent(String messageId, String role) { + super(TYPE, System.currentTimeMillis(), null); + this.messageId = messageId; + this.role = role != null ? role : "assistant"; + } + + public TextMessageStartEvent(StreamInput input) throws IOException { + super(input); + this.messageId = input.readString(); + this.role = input.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(messageId); + out.writeString(role); + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("messageId", messageId); + builder.field("role", role); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/ToolCallArgsEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/ToolCallArgsEvent.java new file mode 100644 index 0000000000..5fbf4c420b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/ToolCallArgsEvent.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class ToolCallArgsEvent extends BaseEvent { + + public static final String TYPE = "TOOL_CALL_ARGS"; + + private String toolCallId; + private String delta; + + public ToolCallArgsEvent(String toolCallId, String delta) { + super(TYPE, System.currentTimeMillis(), null); + this.toolCallId = toolCallId; + this.delta = delta; + } + + public ToolCallArgsEvent(StreamInput input) throws IOException { + super(input); + this.toolCallId = input.readString(); + this.delta = input.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(toolCallId); + out.writeString(delta); + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("toolCallId", toolCallId); + builder.field("delta", delta); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/ToolCallEndEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/ToolCallEndEvent.java new file mode 100644 index 0000000000..1820f90523 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/ToolCallEndEvent.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class ToolCallEndEvent extends BaseEvent { + + public static final String TYPE = "TOOL_CALL_END"; + + private String toolCallId; + + public ToolCallEndEvent(String toolCallId) { + super(TYPE, System.currentTimeMillis(), null); + this.toolCallId = toolCallId; + } + + public ToolCallEndEvent(StreamInput input) throws IOException { + super(input); + this.toolCallId = input.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(toolCallId); + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("toolCallId", toolCallId); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/ToolCallResultEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/ToolCallResultEvent.java new file mode 100644 index 0000000000..0e2f9747de --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/ToolCallResultEvent.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class ToolCallResultEvent extends BaseEvent { + + public static final String TYPE = "TOOL_CALL_RESULT"; + + private String messageId; + private String toolCallId; + private String content; + private String role; + + public ToolCallResultEvent(String messageId, String toolCallId, String content) { + super(TYPE, System.currentTimeMillis(), null); + this.messageId = messageId; + this.toolCallId = toolCallId; + this.content = content; + this.role = "tool"; + } + + public ToolCallResultEvent(StreamInput input) throws IOException { + super(input); + this.messageId = input.readString(); + this.toolCallId = input.readString(); + this.content = input.readString(); + this.role = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(messageId); + out.writeString(toolCallId); + out.writeString(content); + out.writeOptionalString(role); + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("messageId", messageId); + builder.field("toolCallId", toolCallId); + builder.field("content", content); + if (role != null) { + builder.field("role", role); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agui/ToolCallStartEvent.java b/common/src/main/java/org/opensearch/ml/common/agui/ToolCallStartEvent.java new file mode 100644 index 0000000000..780ceaa1e8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/ToolCallStartEvent.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@EqualsAndHashCode(callSuper = true) +public class ToolCallStartEvent extends BaseEvent { + + public static final String TYPE = "TOOL_CALL_START"; + + private String toolCallId; + private String toolCallName; + private String parentMessageId; + + public ToolCallStartEvent(String toolCallId, String toolCallName, String parentMessageId) { + super(TYPE, System.currentTimeMillis(), null); + this.toolCallId = toolCallId; + this.toolCallName = toolCallName; + this.parentMessageId = parentMessageId; + } + + public ToolCallStartEvent(StreamInput input) throws IOException { + super(input); + this.toolCallId = input.readString(); + this.toolCallName = input.readString(); + this.parentMessageId = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(toolCallId); + out.writeString(toolCallName); + out.writeOptionalString(parentMessageId); + } + + @Override + protected void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException { + builder.field("toolCallId", toolCallId); + builder.field("toolCallName", toolCallName); + if (parentMessageId != null) { + builder.field("parentMessageId", parentMessageId); + } + } +} From 44e353ed3028e7dcf2344f2517c24469b2c3de0b Mon Sep 17 00:00:00 2001 From: Jiaping Zeng Date: Tue, 18 Nov 2025 15:36:38 -0800 Subject: [PATCH 2/6] convert AG-UI input to agent parameters Signed-off-by: Jiaping Zeng --- .../org/opensearch/ml/common/MLAgentType.java | 3 +- .../ml/common/agui/AGUIInputConverter.java | 186 ++++++++++++++++++ .../ml/common/output/MLOutputType.java | 3 +- .../ml/common/utils/StringUtils.java | 5 + .../algorithms/agent/AGUIFrontendTool.java | 83 ++++++++ .../remote/AwsConnectorExecutor.java | 2 +- .../remote/HttpJsonConnectorExecutor.java | 2 +- .../remote/RemoteConnectorExecutor.java | 6 + .../ml/rest/RestMLExecuteAction.java | 20 +- 9 files changed, 302 insertions(+), 8 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/agui/AGUIInputConverter.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AGUIFrontendTool.java diff --git a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java index 2dd2614634..5ebb3948eb 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java @@ -11,7 +11,8 @@ public enum MLAgentType { FLOW, CONVERSATIONAL, CONVERSATIONAL_FLOW, - PLAN_EXECUTE_AND_REFLECT; + PLAN_EXECUTE_AND_REFLECT, + AG_UI; public static MLAgentType from(String value) { if (value == null) { diff --git a/common/src/main/java/org/opensearch/ml/common/agui/AGUIInputConverter.java b/common/src/main/java/org/opensearch/ml/common/agui/AGUIInputConverter.java new file mode 100644 index 0000000000..8ae2cde044 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agui/AGUIInputConverter.java @@ -0,0 +1,186 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agui; + +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTENT; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTEXT; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_FORWARDED_PROPS; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_MESSAGES; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ROLE; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_RUN_ID; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_STATE; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_THREAD_ID; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOLS; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALL_ID; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_FORWARDED_PROPS; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_MESSAGES; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_STATE; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOLS; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOL_CALL_RESULTS; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_ROLE_USER; +import static org.opensearch.ml.common.utils.StringUtils.getStringField; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class AGUIInputConverter { + + private static final Gson gson = new Gson(); + + public static boolean isAGUIInput(String inputJson) { + try { + JsonObject jsonObj = JsonParser.parseString(inputJson).getAsJsonObject(); + + // Check required fields exist + if (!jsonObj.has(AGUI_FIELD_THREAD_ID) + || !jsonObj.has(AGUI_FIELD_RUN_ID) + || !jsonObj.has(AGUI_FIELD_MESSAGES) + || !jsonObj.has(AGUI_FIELD_TOOLS)) { + return false; + } + + // Validate messages is an array + JsonElement messages = jsonObj.get(AGUI_FIELD_MESSAGES); + if (!messages.isJsonArray()) { + return false; + } + + // Validate tools is an array + JsonElement tools = jsonObj.get(AGUI_FIELD_TOOLS); + if (!tools.isJsonArray()) { + return false; + } + + return true; + } catch (Exception e) { + log.debug("Failed to parse input as JSON for AG-UI detection", e); + return false; + } + } + + public static AgentMLInput convertFromAGUIInput(String aguiInputJson, String agentId, String tenantId, boolean isAsync) { + try { + JsonObject aguiInput = JsonParser.parseString(aguiInputJson).getAsJsonObject(); + + String threadId = getStringField(aguiInput, AGUI_FIELD_THREAD_ID); + String runId = getStringField(aguiInput, AGUI_FIELD_RUN_ID); + JsonElement state = aguiInput.get(AGUI_FIELD_STATE); + JsonElement messages = aguiInput.get(AGUI_FIELD_MESSAGES); + JsonElement tools = aguiInput.get(AGUI_FIELD_TOOLS); + JsonElement context = aguiInput.get(AGUI_FIELD_CONTEXT); + JsonElement forwardedProps = aguiInput.get(AGUI_FIELD_FORWARDED_PROPS); + + Map parameters = new HashMap<>(); + parameters.put(AGUI_PARAM_THREAD_ID, threadId); + parameters.put(AGUI_PARAM_RUN_ID, runId); + + if (state != null) { + parameters.put(AGUI_PARAM_STATE, gson.toJson(state)); + } + + if (messages != null) { + parameters.put(AGUI_PARAM_MESSAGES, gson.toJson(messages)); + extractUserQuestion(messages, parameters); + } + + if (tools != null) { + parameters.put(AGUI_PARAM_TOOLS, gson.toJson(tools)); + } + + if (context != null) { + parameters.put(AGUI_PARAM_CONTEXT, gson.toJson(context)); + } + + if (forwardedProps != null) { + parameters.put(AGUI_PARAM_FORWARDED_PROPS, gson.toJson(forwardedProps)); + } + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + AgentMLInput agentMLInput = new AgentMLInput(agentId, tenantId, FunctionName.AGENT, inputDataSet, isAsync); + + log.debug("Converted AG-UI input to ML-Commons format for agent: {}", agentId); + return agentMLInput; + + } catch (Exception e) { + log.error("Failed to convert AG-UI input to ML-Commons format", e); + throw new IllegalArgumentException("Invalid AG-UI input format", e); + } + } + + private static void extractUserQuestion(JsonElement messages, Map parameters) { + if (messages == null || !messages.isJsonArray()) { + throw new IllegalArgumentException("Invalid AG-UI messages"); + } + + try { + // Find the last user message to use as the current question + String lastUserMessage = null; + String toolCallResults = null; + + for (JsonElement messageElement : messages.getAsJsonArray()) { + if (messageElement.isJsonObject()) { + JsonObject message = messageElement.getAsJsonObject(); + JsonElement roleElement = message.get(AGUI_FIELD_ROLE); + JsonElement contentElement = message.get(AGUI_FIELD_CONTENT); + JsonElement toolCallIdElement = message.get(AGUI_FIELD_TOOL_CALL_ID); + + if (roleElement != null + && AGUI_ROLE_USER.equals(roleElement.getAsString()) + && contentElement != null + && !contentElement.isJsonNull()) { + + String content = contentElement.getAsString(); + + // Check if this is a tool call result (has toolCallId field) + if (toolCallIdElement != null && !toolCallIdElement.isJsonNull()) { + // This is a tool call result from frontend + String toolCallId = toolCallIdElement.getAsString(); + + // Create tool result structure + JsonObject toolResult = new JsonObject(); + toolResult.addProperty("tool_call_id", toolCallId); + toolResult.addProperty("content", content); + + toolCallResults = gson.toJson(List.of(toolResult)); + log.debug("Extracted tool call result from AG-UI messages: toolCallId={}, content={}", toolCallId, content); + } else { + // Regular user message + lastUserMessage = content; + } + } + } + } + + // Set appropriate parameters based on what was found + if (toolCallResults != null) { + parameters.put(AGUI_PARAM_TOOL_CALL_RESULTS, toolCallResults); + log.debug("Detected AG-UI tool call results: {}", toolCallResults); + } else if (lastUserMessage != null) { + parameters.put("question", lastUserMessage); + log.debug("Extracted user question from AG-UI messages: {}", lastUserMessage); + } else { + throw new IllegalArgumentException("No user message found in AG-UI messages"); + } + } catch (Exception e) { + throw new IllegalArgumentException("Invalid AG-UI message format", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/output/MLOutputType.java b/common/src/main/java/org/opensearch/ml/common/output/MLOutputType.java index d500183296..ed4fcc83bc 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/MLOutputType.java +++ b/common/src/main/java/org/opensearch/ml/common/output/MLOutputType.java @@ -11,5 +11,6 @@ public enum MLOutputType { SAMPLE_ALGO, MODEL_TENSOR, MCORR_TENSOR, - ML_TASK_OUTPUT + ML_TASK_OUTPUT, + AG_UI_OUTPUT } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index a3f1a3b416..96ab3a02a4 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -726,4 +726,9 @@ public Float read(JsonReader in) throws IOException { return f; } } + + public static String getStringField(JsonObject obj, String fieldName) { + JsonElement element = obj.get(fieldName); + return element != null && !element.isJsonNull() ? element.getAsString() : null; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AGUIFrontendTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AGUIFrontendTool.java new file mode 100644 index 0000000000..7e722fb29d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AGUIFrontendTool.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.util.Map; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.spi.tools.Tool; + +import lombok.extern.log4j.Log4j2; + +/** + * Placeholder tool for AG-UI frontend tools. + * Frontend tools are not executed on the backend - they are executed in the browser. + * This placeholder allows the LLM to see frontend tools in the unified tool list. + */ +@Log4j2 +public class AGUIFrontendTool implements Tool { + private final String toolName; + private final String toolDescription; + private final Map toolAttributes; + + public AGUIFrontendTool(String toolName, String toolDescription, Map toolAttributes) { + this.toolName = toolName; + this.toolDescription = toolDescription; + this.toolAttributes = toolAttributes; + } + + @Override + public String getName() { + return toolName; + } + + @Override + public void setName(String name) {} + + @Override + public String getDescription() { + return toolDescription; + } + + @Override + public void setDescription(String description) {} + + @Override + public Map getAttributes() { + return toolAttributes; + } + + @Override + public void setAttributes(Map attributes) {} + + @Override + @SuppressWarnings("unchecked") + public void run(Map parameters, ActionListener listener) { + log.debug("AG-UI: Frontend tool {} executed with parameters: {}", toolName, parameters); + String errorResult = String + .format( + "Error: Tool '%s' is a frontend tool and should be called via function calling in the final response, " + + "not during ReAct execution.", + toolName + ); + listener.onResponse((T) errorResult); + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + @Override + public String getType() { + return "AGUIFrontendTool"; + } + + @Override + public String getVersion() { + return "1.0.0"; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 3b53935aaf..b1fbc57d03 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -157,7 +157,7 @@ public void invokeRemoteServiceStream( llmInterface = StringEscapeUtils.unescapeJava(llmInterface); validateLLMInterface(llmInterface); - StreamingHandler handler = StreamingHandlerFactory.createHandler(llmInterface, connector, getHttpClient(), null); + StreamingHandler handler = StreamingHandlerFactory.createHandler(llmInterface, connector, getHttpClient(), null, parameters); handler.startStream(action, parameters, payload, actionListener); } catch (Exception e) { log.error("Failed to execute streaming", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 7804770258..9eb433decd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -157,7 +157,7 @@ public void invokeRemoteServiceStream( validateLLMInterface(llmInterface); StreamingHandler handler = StreamingHandlerFactory - .createHandler(llmInterface, connector, null, super.getConnectorClientConfig()); + .createHandler(llmInterface, connector, null, super.getConnectorClientConfig(), parameters); handler.startStream(action, parameters, payload, actionListener); } catch (Exception e) { log.error("Failed to execute streaming", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 5181e8d087..2ead5ea25d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -282,6 +282,12 @@ && getUserRateLimiterMap().get(user.getName()) != null String parentInteractionId = parameters.get("parent_interaction_id"); // TODO: find a better way to differentiate agent and predict request boolean isAgentRequest = (memoryId != null || parentInteractionId != null); + getLogger() + .info( + "RemoteConnectorExecutor: Creating StreamPredictActionListener - isAgentRequest={}, agentListener={}", + isAgentRequest, + agentListener != null ? "present" : "null" + ); StreamPredictActionListener streamListener = new StreamPredictActionListener<>( channel, isAgentRequest ? agentListener : null, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java index 6b293595c6..d6b57d7487 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java @@ -27,6 +27,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.agui.AGUIInputConverter; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; @@ -124,10 +125,21 @@ MLExecuteTaskRequest getRequest(RestRequest request) throws IOException { String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); String agentId = request.param(PARAMETER_AGENT_ID); functionName = FunctionName.AGENT; - input = MLInput.parse(parser, functionName.name()); - ((AgentMLInput) input).setAgentId(agentId); - ((AgentMLInput) input).setTenantId(tenantId); - ((AgentMLInput) input).setIsAsync(async); + + String requestBodyJson = request.contentOrSourceParam().v2().utf8ToString(); + if (AGUIInputConverter.isAGUIInput(requestBodyJson)) { + throw new IllegalArgumentException( + "AG-UI agents require streaming execution. " + + "Please use the streaming endpoint: POST /_plugins/_ml/agents/" + + agentId + + "/_execute/stream" + ); + } else { + input = MLInput.parse(parser, functionName.name()); + ((AgentMLInput) input).setAgentId(agentId); + ((AgentMLInput) input).setTenantId(tenantId); + ((AgentMLInput) input).setIsAsync(async); + } } else if (uri.startsWith(ML_BASE_URI + "/tools/")) { if (!mlFeatureEnabledSetting.isToolExecuteEnabled()) { throw new IllegalStateException(ML_COMMONS_EXECUTE_TOOL_DISABLED_MESSAGE); From 0351e4e87b940e1b7acf558d45edbab2b46c6273 Mon Sep 17 00:00:00 2001 From: Jiaping Zeng Date: Tue, 18 Nov 2025 15:42:52 -0800 Subject: [PATCH 3/6] add AG-UI tool use Signed-off-by: Jiaping Zeng --- .../streaming/StreamingHandlerFactory.java | 30 ++++++++++++---- ...rockConverseDeepseekR1FunctionCalling.java | 7 ++++ .../BedrockConverseFunctionCalling.java | 36 +++++++++++++++++++ .../function_calling/FunctionCalling.java | 8 +++++ ...penaiV1ChatCompletionsFunctionCalling.java | 5 +++ 5 files changed, 79 insertions(+), 7 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamingHandlerFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamingHandlerFactory.java index fb4ecf9fbf..90d8623b72 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamingHandlerFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamingHandlerFactory.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; import java.lang.reflect.Constructor; +import java.util.Map; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorClientConfig; @@ -23,24 +24,38 @@ public static StreamingHandler createHandler( Connector connector, SdkAsyncHttpClient httpClient, ConnectorClientConfig connectorClientConfig + ) { + return createHandler(llmInterface, connector, httpClient, connectorClientConfig, null); + } + + public static StreamingHandler createHandler( + String llmInterface, + Connector connector, + SdkAsyncHttpClient httpClient, + ConnectorClientConfig connectorClientConfig, + Map parameters ) { switch (llmInterface.toLowerCase()) { case LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE: - return createBedrockHandler(httpClient, connector); + return createBedrockHandler(httpClient, connector, parameters); case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS: - return createHttpHandler(llmInterface, connector, connectorClientConfig); + return createHttpHandler(llmInterface, connector, connectorClientConfig, parameters); default: throw new IllegalArgumentException("Unsupported LLM interface: " + llmInterface); } } - private static StreamingHandler createBedrockHandler(SdkAsyncHttpClient httpClient, Connector connector) { + private static StreamingHandler createBedrockHandler( + SdkAsyncHttpClient httpClient, + Connector connector, + Map parameters + ) { try { // Use reflection to avoid hard dependency Class handlerClass = Class.forName("org.opensearch.ml.engine.algorithms.remote.streaming.BedrockStreamingHandler"); Constructor constructor = handlerClass - .getConstructor(SdkAsyncHttpClient.class, Class.forName("org.opensearch.ml.common.connector.AwsConnector")); - return (StreamingHandler) constructor.newInstance(httpClient, connector); + .getConstructor(SdkAsyncHttpClient.class, Class.forName("org.opensearch.ml.common.connector.AwsConnector"), Map.class); + return (StreamingHandler) constructor.newInstance(httpClient, connector, parameters); } catch (ClassNotFoundException e) { throw new MLException("Bedrock streaming not available - Bedrock SDK not found", e); } catch (Exception e) { @@ -51,8 +66,9 @@ private static StreamingHandler createBedrockHandler(SdkAsyncHttpClient httpClie private static StreamingHandler createHttpHandler( String llmInterface, Connector connector, - ConnectorClientConfig connectorClientConfig + ConnectorClientConfig connectorClientConfig, + Map parameters ) { - return new HttpStreamingHandler(llmInterface, connector, connectorClientConfig); + return new HttpStreamingHandler(llmInterface, connector, connectorClientConfig, parameters); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/BedrockConverseDeepseekR1FunctionCalling.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/BedrockConverseDeepseekR1FunctionCalling.java index 7bc0a0c7d6..a5b63450e5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/BedrockConverseDeepseekR1FunctionCalling.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/BedrockConverseDeepseekR1FunctionCalling.java @@ -118,4 +118,11 @@ public List supply(List> toolResults) { return List.of(toolMessage); } + + @Override + public String formatAGUIToolCalls(String toolCallsJson) { + throw new UnsupportedOperationException( + "AG-UI is not yet supported with Deepseek R1, please use a different LLM interface such as bedrock/converse/claude or openai/v1/chat/completions." + ); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/BedrockConverseFunctionCalling.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/BedrockConverseFunctionCalling.java index 41c625b78d..d5b6660a00 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/BedrockConverseFunctionCalling.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/BedrockConverseFunctionCalling.java @@ -25,6 +25,7 @@ import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTION_TEMPLATE_TOOL_RESPONSE; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,6 +33,7 @@ import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.utils.StringUtils; +import com.google.gson.Gson; import com.jayway.jsonpath.JsonPath; import lombok.Data; @@ -121,6 +123,40 @@ public List supply(List> toolResults) { return List.of(toolMessage); } + @Override + public String formatAGUIToolCalls(String toolCallsJson) { + BedrockMessage assistantMessage = new BedrockMessage("assistant"); + Gson gson = new Gson(); + + try { + List toolCalls = gson.fromJson(toolCallsJson, List.class); + for (Object toolCallObj : toolCalls) { + Map toolCall = (Map) toolCallObj; + Map toolUse = new HashMap<>(); + toolUse.put("toolUseId", toolCall.get("id")); + + Map function = (Map) toolCall.get("function"); + if (function != null) { + toolUse.put("name", function.get("name")); + + String argumentsJson = (String) function.get("arguments"); + try { + Object argumentsObj = gson.fromJson(argumentsJson, Object.class); + toolUse.put("input", argumentsObj); + } catch (Exception e) { + toolUse.put("input", Map.of()); + } + } + + assistantMessage.getContent().add(Map.of("toolUse", toolUse)); + } + } catch (Exception e) { + return "{\"role\":\"assistant\",\"content\":[]}"; + } + + return assistantMessage.getResponse(); + } + @Data public static class ToolResult { private String toolUseId; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/FunctionCalling.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/FunctionCalling.java index 5482c0346f..2499dab685 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/FunctionCalling.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/FunctionCalling.java @@ -35,4 +35,12 @@ public interface FunctionCalling { * @return a LLMMessage containing tool results. */ List supply(List> toolResults); + + /** + * Format AG-UI tool calls into an assistant message in LLM-specific format. + * + * @param toolCallsJson JSON string containing array of tool calls from AG-UI. + * @return JSON string representing the assistant message with tool calls in LLM-specific format + */ + String formatAGUIToolCalls(String toolCallsJson); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/OpenaiV1ChatCompletionsFunctionCalling.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/OpenaiV1ChatCompletionsFunctionCalling.java index 3f37f54749..ed88ef98f3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/OpenaiV1ChatCompletionsFunctionCalling.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/OpenaiV1ChatCompletionsFunctionCalling.java @@ -117,4 +117,9 @@ public List supply(List> toolResults) { return messages; } + + @Override + public String formatAGUIToolCalls(String toolCallsJson) { + return "{\"role\":\"assistant\",\"tool_calls\":" + toolCallsJson + "}"; + } } From 1caab620b68a110d7aa67bbe8b9491e4f5f2b912 Mon Sep 17 00:00:00 2001 From: Jiaping Zeng Date: Tue, 18 Nov 2025 20:37:25 -0800 Subject: [PATCH 4/6] add AG-UI processing Signed-off-by: Jiaping Zeng --- .../engine/algorithms/agent/AgentUtils.java | 44 ++ .../algorithms/agent/MLAGUIAgentRunner.java | 410 ++++++++++++++++++ .../algorithms/agent/MLAgentExecutor.java | 11 + .../algorithms/agent/MLChatAgentRunner.java | 286 ++++++++++-- .../algorithms/agent/StreamingWrapper.java | 31 +- .../streaming/BaseStreamingHandler.java | 21 + .../streaming/BedrockStreamingHandler.java | 139 +++++- .../streaming/HttpStreamingHandler.java | 100 ++++- .../ml/rest/RestMLExecuteStreamAction.java | 207 +++++++-- 9 files changed, 1145 insertions(+), 104 deletions(-) create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 7de547127a..1fb6a14959 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -1014,4 +1014,48 @@ public static Tool createTool(Map toolFactories, Map> parseFrontendTools(String aguiTools) { + List> frontendTools = new ArrayList<>(); + if (aguiTools != null && !aguiTools.isEmpty() && !aguiTools.trim().equals("[]")) { + try { + Type listType = new TypeToken>>() { + }.getType(); + List> parsed = gson.fromJson(aguiTools, listType); + if (parsed != null) { + frontendTools.addAll(parsed); + } + } catch (Exception e) { + log.error("Failed to parse frontend tools: {}", e.getMessage()); + } + } + return frontendTools; + } + + public static Map wrapFrontendToolsAsToolObjects(List> frontendTools) { + Map wrappedTools = new HashMap<>(); + + for (Map frontendTool : frontendTools) { + String toolName = (String) frontendTool.get("name"); + String toolDescription = (String) frontendTool.get("description"); + + // Create frontend tool object with source marker + Map toolAttributes = new HashMap<>(); + toolAttributes.put("source", "frontend"); + toolAttributes.put("tool_definition", frontendTool); + + Object parameters = frontendTool.get("parameters"); + if (parameters != null) { + toolAttributes.put("input_schema", gson.toJson(parameters)); + } else { + Map emptySchema = Map.of("type", "object", "properties", Map.of()); + toolAttributes.put("input_schema", gson.toJson(emptySchema)); + } + + Tool frontendToolObj = new AGUIFrontendTool(toolName, toolDescription, toolAttributes); + wrappedTools.put(toolName, frontendToolObj); + } + + return wrappedTools; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java new file mode 100644 index 0000000000..14efd48de3 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java @@ -0,0 +1,410 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTENT; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ROLE; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALLS; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALL_ID; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_ASSISTANT_TOOL_CALL_MESSAGES; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_MESSAGES; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOL_CALL_RESULTS; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_ROLE_ASSISTANT; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_ROLE_TOOL; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_ROLE_USER; +import static org.opensearch.ml.common.utils.StringUtils.getStringField; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY_MESSAGE_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY_QUESTION_TEMPLATE; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY_RESPONSE_TEMPLATE; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.NEW_CHAT_HISTORY; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.text.StringSubstitutor; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.spi.memory.Memory; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.function_calling.FunctionCalling; +import org.opensearch.ml.engine.function_calling.FunctionCallingFactory; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.client.Client; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class MLAGUIAgentRunner implements MLAgentRunner { + + private final Client client; + private final Settings settings; + private final ClusterService clusterService; + private final NamedXContentRegistry xContentRegistry; + private final Map toolFactories; + private final Map memoryFactoryMap; + private final SdkClient sdkClient; + private final Encryptor encryptor; + + public MLAGUIAgentRunner( + Client client, + Settings settings, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + Map toolFactories, + Map memoryFactoryMap, + SdkClient sdkClient, + Encryptor encryptor + ) { + this.client = client; + this.settings = settings; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.toolFactories = toolFactories; + this.memoryFactoryMap = memoryFactoryMap; + this.sdkClient = sdkClient; + this.encryptor = encryptor; + } + + @Override + public void run(MLAgent mlAgent, Map params, ActionListener listener, TransportChannel channel) { + try { + String llmInterface = params.get(LLM_INTERFACE); + if (llmInterface == null && mlAgent.getParameters() != null) { + llmInterface = mlAgent.getParameters().get(LLM_INTERFACE); + } + + FunctionCalling functionCalling = FunctionCallingFactory.create(llmInterface); + if (functionCalling != null) { + functionCalling.configure(params); + } + + processAGUIMessages(params, llmInterface); + processAGUIContext(params); + + params.put("agent_type", "ag_ui"); + + MLAgentRunner conversationalRunner = new MLChatAgentRunner( + client, + settings, + clusterService, + xContentRegistry, + toolFactories, + memoryFactoryMap, + sdkClient, + encryptor + ); + + // Execute with streaming - events are generated in RestMLExecuteStreamAction + conversationalRunner.run(mlAgent, params, listener, channel); + + } catch (Exception e) { + log.error("Error starting AG-UI agent execution", e); + listener.onFailure(e); + } + } + + private void processAGUIMessages(Map params, String llmInterface) { + String aguiMessagesJson = params.get(AGUI_PARAM_MESSAGES); + if (aguiMessagesJson == null || aguiMessagesJson.isEmpty()) { + return; + } + + try { + JsonElement messagesElement = gson.fromJson(aguiMessagesJson, JsonElement.class); + + if (!messagesElement.isJsonArray()) { + log.warn("AG-UI messages is not a JSON array"); + return; + } + + JsonArray messageArray = messagesElement.getAsJsonArray(); + + for (int i = 0; i < messageArray.size(); i++) { + JsonElement msgElement = messageArray.get(i); + if (msgElement.isJsonObject()) { + JsonObject msg = msgElement.getAsJsonObject(); + String role = getStringField(msg, "role"); + String content = getStringField(msg, "content"); + boolean hasToolCalls = msg.has("toolCalls"); + boolean hasToolCallId = msg.has("toolCallId"); + log + .debug( + "AG-UI: Message[{}] - role: {}, hasToolCalls: {}, hasToolCallId: {}, content preview: {}", + i, + role, + hasToolCalls, + hasToolCallId, + content != null && content.length() > 50 ? content.substring(0, 50) + "..." : content + ); + } + } + + if (messageArray.size() <= 1) { + return; + } + + // Check for tool result messages and extract them + // Also track assistant messages with tool calls + List> toolResults = new ArrayList<>(); + List toolCallMessageIndices = new ArrayList<>(); + List toolResultMessageIndices = new ArrayList<>(); + List assistantToolCallMessages = new ArrayList<>(); + int lastToolResultIndex = -1; + + for (int i = 0; i < messageArray.size(); i++) { + JsonElement messageElement = messageArray.get(i); + if (messageElement.isJsonObject()) { + JsonObject message = messageElement.getAsJsonObject(); + String role = getStringField(message, AGUI_FIELD_ROLE); + + // Track and extract assistant messages with tool calls + if (AGUI_ROLE_ASSISTANT.equals(role) && message.has(AGUI_FIELD_TOOL_CALLS)) { + toolCallMessageIndices.add(i); + + // Extract tool calls from AG-UI message (AG-UI uses OpenAI-compatible format) + JsonElement toolCallsElement = message.get(AGUI_FIELD_TOOL_CALLS); + if (toolCallsElement != null && toolCallsElement.isJsonArray()) { + // Pass the JSON array directly to FunctionCalling for format conversion + String toolCallsJson = gson.toJson(toolCallsElement); + + FunctionCalling functionCalling = FunctionCallingFactory.create(llmInterface); + String assistantMessage = ""; + + if (functionCalling != null) { + // Use FunctionCalling to format the message in the correct LLM format + assistantMessage = functionCalling.formatAGUIToolCalls(toolCallsJson); + log.debug("AG-UI: Formatted assistant message using {}", functionCalling.getClass().getSimpleName()); + } else { + log.error("AG-UI: Invalid function calling configuration: {}", llmInterface); + } + + assistantToolCallMessages.add(assistantMessage); + log.debug("AG-UI: Extracted assistant message at index {}", i); + log.debug("AG-UI: Assistant message JSON: {}", assistantMessage); + } + } + + if (AGUI_ROLE_TOOL.equals(role)) { + String content = getStringField(message, AGUI_FIELD_CONTENT); + String toolCallId = getStringField(message, AGUI_FIELD_TOOL_CALL_ID); + + if (content != null && toolCallId != null) { + Map toolResult = new HashMap<>(); + toolResult.put("tool_call_id", toolCallId); + toolResult.put("content", content); + toolResults.add(toolResult); + toolResultMessageIndices.add(i); + lastToolResultIndex = i; + } + } + } + } + + // Only process the MOST RECENT tool execution + // Check if there are any assistant messages after the last tool result + boolean hasAssistantAfterToolResult = false; + if (lastToolResultIndex >= 0) { + for (int i = lastToolResultIndex + 1; i < messageArray.size(); i++) { + JsonElement messageElement = messageArray.get(i); + if (messageElement.isJsonObject()) { + JsonObject message = messageElement.getAsJsonObject(); + String role = getStringField(message, AGUI_FIELD_ROLE); + if (AGUI_ROLE_ASSISTANT.equals(role)) { + hasAssistantAfterToolResult = true; + break; + } + } + } + } + + boolean toolResultsAreRecent = !toolResults.isEmpty() && !hasAssistantAfterToolResult; + + if (!toolResults.isEmpty() && toolResultsAreRecent) { + // Only include the MOST RECENT tool execution (last tool call + result pair) + // Find the assistant message that corresponds to the last tool result + String lastToolCallMessage = null; + + // The last tool result should correspond to the last assistant message with tool_calls + if (!assistantToolCallMessages.isEmpty() && !toolCallMessageIndices.isEmpty()) { + lastToolCallMessage = assistantToolCallMessages.getLast(); + } + + // Only include the last tool result + Map lastToolResult = toolResults.getLast(); + List> recentToolResults = List.of(lastToolResult); + + String toolResultsJson = gson.toJson(recentToolResults); + params.put(AGUI_PARAM_TOOL_CALL_RESULTS, toolResultsJson); + + // Only pass the most recent assistant message with tool_calls + if (lastToolCallMessage != null) { + params.put(AGUI_PARAM_ASSISTANT_TOOL_CALL_MESSAGES, gson.toJson(List.of(lastToolCallMessage))); + } + } else if (!toolResults.isEmpty()) { + log + .info( + "AG-UI: Found {} tool results but they are not recent (last at index {}, total messages: {}), " + + "skipping from interactions", + toolResults.size(), + lastToolResultIndex, + messageArray.size() + ); + } + + String chatHistoryQuestionTemplate = params.get(CHAT_HISTORY_QUESTION_TEMPLATE); + String chatHistoryResponseTemplate = params.get(CHAT_HISTORY_RESPONSE_TEMPLATE); + + if (chatHistoryQuestionTemplate == null || chatHistoryResponseTemplate == null) { + + StringBuilder chatHistoryBuilder = new StringBuilder(); + + for (int i = 0; i < messageArray.size() - 1; i++) { + JsonElement messageElement = messageArray.get(i); + if (messageElement.isJsonObject()) { + JsonObject message = messageElement.getAsJsonObject(); + String role = getStringField(message, AGUI_FIELD_ROLE); + String content = getStringField(message, AGUI_FIELD_CONTENT); + + // Skip tool messages - they're not part of chat history + if (AGUI_ROLE_TOOL.equals(role)) { + continue; + } + + // Skip assistant messages with tool_calls - they're not part of chat history + if (AGUI_ROLE_ASSISTANT.equals(role) && message.has(AGUI_FIELD_TOOL_CALLS)) { + continue; + } + + // Include user messages and assistant messages with content (final answers) + if ((AGUI_ROLE_USER.equals(role) || AGUI_ROLE_ASSISTANT.equals(role)) && content != null && !content.isEmpty()) { + if (chatHistoryBuilder.length() > 0) { + chatHistoryBuilder.append("\n"); + } + chatHistoryBuilder.append(role.equals(AGUI_ROLE_USER) ? "Human: " : "Assistant: ").append(content); + } + } + } + + if (chatHistoryBuilder.length() > 0) { + params.put(NEW_CHAT_HISTORY, chatHistoryBuilder.toString()); + } + } else { + List chatHistory = new ArrayList<>(); + + for (int i = 0; i < messageArray.size() - 1; i++) { + JsonElement messageElement = messageArray.get(i); + if (messageElement.isJsonObject()) { + JsonObject message = messageElement.getAsJsonObject(); + String role = getStringField(message, AGUI_FIELD_ROLE); + String content = getStringField(message, AGUI_FIELD_CONTENT); + + // Skip tool messages - they're never part of chat history + if (AGUI_ROLE_TOOL.equals(role)) { + continue; + } + + if (AGUI_ROLE_USER.equals(role) && content != null && !content.isEmpty()) { + // When we have recent tool results, skip the user message that triggered the tool call + // This is the user message right before the assistant message with tool calls + if (toolResultsAreRecent && !toolCallMessageIndices.isEmpty()) { + int firstToolCallIndex = toolCallMessageIndices.get(0); + // Skip user messages that are at or after the first tool call + // (they're part of the current tool execution cycle, not historical chat) + if (i >= firstToolCallIndex - 1) { + continue; + } + } + + Map messageParams = new HashMap<>(); + messageParams.put("question", processTextDoc(content)); + StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); + String chatMessage = substitutor.replace(chatHistoryQuestionTemplate); + chatHistory.add(chatMessage); + } else if (AGUI_ROLE_ASSISTANT.equals(role)) { + // Skip ALL assistant messages with tool_calls - they're never part of chat history + // (matching backend behavior where only final answers are in chat history) + if (message.has(AGUI_FIELD_TOOL_CALLS)) { + // Skip - not part of chat history + } else if (content != null && !content.isEmpty()) { + // Regular assistant message with content (final answer) + Map messageParams = new HashMap<>(); + messageParams.put("response", processTextDoc(content)); + StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); + String chatMessage = substitutor.replace(chatHistoryResponseTemplate); + chatHistory.add(chatMessage); + } + } + } + } + + if (!chatHistory.isEmpty()) { + params.put(NEW_CHAT_HISTORY, String.join(", ", chatHistory) + ", "); + } + } + } catch (Exception e) { + log.error("Failed to process AG-UI messages to chat history", e); + } + } + + private void processAGUIContext(Map params) { + String aguiContextJson = params.get(AGUI_PARAM_CONTEXT); + + if (aguiContextJson == null || aguiContextJson.isEmpty()) { + return; + } + + try { + JsonElement contextElement = gson.fromJson(aguiContextJson, JsonElement.class); + + if (!contextElement.isJsonArray()) { + log.warn("AG-UI context is not a JSON array"); + return; + } + + JsonArray contextArray = contextElement.getAsJsonArray(); + if (contextArray.size() == 0) { + return; + } + + StringBuilder contextBuilder = new StringBuilder(); + + for (JsonElement contextItemElement : contextArray) { + if (contextItemElement.isJsonObject()) { + JsonObject contextItem = contextItemElement.getAsJsonObject(); + String description = getStringField(contextItem, "description"); + String value = getStringField(contextItem, "value"); + + if (description != null && value != null) { + contextBuilder.append("- ").append(description).append(": ").append(value).append("\n"); + } + } + } + + if (contextBuilder.length() > 0) { + params.put(CONTEXT, contextBuilder.toString()); + } + + } catch (Exception e) { + log.error("Failed to process AG-UI context", e); + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 1594506cf4..a9a3fb55d3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -653,6 +653,17 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { sdkClient, encryptor ); + case AG_UI: + return new MLAGUIAgentRunner( + client, + settings, + clusterService, + xContentRegistry, + toolFactories, + memoryFactoryMap, + sdkClient, + encryptor + ); default: throw new IllegalArgumentException("Unsupported agent type: " + mlAgent.getType()); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 103f3f89b3..83af7a1a74 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -5,6 +5,10 @@ package org.opensearch.ml.engine.algorithms.agent; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_ASSISTANT_TOOL_CALL_MESSAGES; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_BACKEND_TOOL_NAMES; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOLS; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOL_CALL_RESULTS; import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -31,11 +35,14 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolNames; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseFrontendTools; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.wrapFrontendToolsAsToolObjects; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY; +import java.lang.reflect.Type; import java.security.PrivilegedActionException; import java.util.ArrayList; import java.util.Collections; @@ -47,7 +54,6 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; @@ -83,6 +89,7 @@ import org.opensearch.transport.client.Client; import com.google.common.annotations.VisibleForTesting; +import com.google.gson.reflect.TypeToken; import lombok.Data; import lombok.NoArgsConstructor; @@ -260,24 +267,22 @@ private void runAgent( String sessionId, FunctionCalling functionCalling ) { - List toolSpecs = getMlToolSpecs(mlAgent, params); - - // Create a common method to handle both success and failure cases - Consumer> processTools = (allToolSpecs) -> { - Map tools = new HashMap<>(); - Map toolSpecMap = new HashMap<>(); - createTools(toolFactories, params, allToolSpecs, tools, toolSpecMap, mlAgent); - runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener, functionCalling); - }; + List> frontendTools = new ArrayList<>(); + + if (isAGUIAgent(params)) { + // Check if this is an AG-UI request with tool call results + String aguiToolCallResults = params.get(AGUI_PARAM_TOOL_CALL_RESULTS); + if (aguiToolCallResults != null && !aguiToolCallResults.isEmpty()) { + // Process tool call results from frontend + processAGUIToolResults(mlAgent, params, listener, memory, sessionId, functionCalling, aguiToolCallResults); + return; + } - // Fetch MCP tools and handle both success and failure cases - getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, ActionListener.wrap(mcpTools -> { - toolSpecs.addAll(mcpTools); - processTools.accept(toolSpecs); - }, e -> { - log.error("Failed to get MCP tools, continuing with base tools only", e); - processTools.accept(toolSpecs); - })); + // Parse frontend tools if present + String aguiTools = params.get(AGUI_PARAM_TOOLS); + frontendTools = parseFrontendTools(aguiTools); + } + processUnifiedTools(mlAgent, params, listener, memory, sessionId, functionCalling, frontendTools); } private void runReAct( @@ -289,7 +294,8 @@ private void runReAct( String sessionId, String tenantId, ActionListener listener, - FunctionCalling functionCalling + FunctionCalling functionCalling, + Map backendTools ) { Map tmpParameters = constructLLMParams(llm, parameters); String prompt = constructLLMPrompt(tools, tmpParameters); @@ -312,6 +318,7 @@ private void runReAct( AtomicReference lastAction = new AtomicReference<>(); AtomicReference lastActionInput = new AtomicReference<>(); AtomicReference lastToolSelectionResponse = new AtomicReference<>(); + AtomicReference lastToolCallId = new AtomicReference<>(); Map additionalInfo = new ConcurrentHashMap<>(); Map lastToolParams = new ConcurrentHashMap<>(); @@ -378,6 +385,7 @@ private void runReAct( lastAction.set(action); lastActionInput.set(actionInput); lastToolSelectionResponse.set(thoughtResponse); + lastToolCallId.set(toolCallId); traceTensors .add( @@ -419,28 +427,42 @@ private void runReAct( } if (tools.containsKey(action)) { - Map toolParams = constructToolParams( - tools, - toolSpecMap, - question, - lastActionInput, - action, - actionInput - ); - lastToolParams.clear(); - lastToolParams.putAll(toolParams); - runTool( - tools, - toolSpecMap, - tmpParameters, - (ActionListener) nextStepListener, - action, - actionInput, - toolParams, - interactions, - toolCallId, - functionCalling - ); + // Check if this is a backend tool - if it is, execute it normally in the ReAct loop + // If it's NOT a backend tool, it must be a frontend tool, so break out of the loop + boolean isBackendTool = backendTools != null && backendTools.containsKey(action); + + log.info("AG-UI: Tool execution request - action: {}, isBackendTool: {}", action, isBackendTool); + + if (!isBackendTool) { + // For frontend tool use, we close the response stream and wait for frontend tool result + if (streamingWrapper != null) { + streamingWrapper.sendRunFinishedAndCloseStream(sessionId, parentInteractionId); + } + } else { + // Handle backend tool normally + Map toolParams = constructToolParams( + tools, + toolSpecMap, + question, + lastActionInput, + action, + actionInput + ); + lastToolParams.clear(); + lastToolParams.putAll(toolParams); + runTool( + tools, + toolSpecMap, + tmpParameters, + (ActionListener) nextStepListener, + action, + actionInput, + toolParams, + interactions, + toolCallId, + functionCalling + ); + } } else { String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action); @@ -486,7 +508,17 @@ private void runReAct( } sessionMsgAnswerBuilder.append(outputToOutputString(filteredOutput)); - streamingWrapper.sendToolResponse(outputToOutputString(filteredOutput), sessionId, parentInteractionId); + + String toolOutputString = outputToOutputString(filteredOutput); + + if (streamingWrapper != null) { + if (isAGUIAgent(parameters)) { + streamingWrapper.sendBackendToolResult(lastToolCallId.get(), toolOutputString, sessionId, parentInteractionId); + } else { + streamingWrapper.sendToolResponse(toolOutputString, sessionId, parentInteractionId); + } + } + traceTensors .add( ModelTensors @@ -518,6 +550,7 @@ private void runReAct( ); return; } + ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); } @@ -933,4 +966,173 @@ private void saveMessage( memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener); } } + + private boolean isAGUIAgent(Map parameters) { + return parameters != null && parameters.containsKey("agent_type") && parameters.get("agent_type").equals("ag_ui"); + } + + /** + * Process unified tools - combines frontend and backend tools for LLM visibility + */ + private void processUnifiedTools( + MLAgent mlAgent, + Map params, + ActionListener listener, + Memory memory, + String sessionId, + FunctionCalling functionCalling, + List> frontendTools + ) { + // Always get backend tools + List backendToolSpecs = getMlToolSpecs(mlAgent, params); + + // Handle backend tool loading with MCP tools + getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, ActionListener.wrap(mcpTools -> { + // Add MCP tools to backend tools + backendToolSpecs.addAll(mcpTools); + + // Create backend tools map + Map backendToolsMap = new HashMap<>(); + Map toolSpecMap = new HashMap<>(); + createTools(toolFactories, params, backendToolSpecs, backendToolsMap, toolSpecMap, mlAgent); + + // Create unified tool list for function calling (frontend + backend) + processUnifiedToolsWithBackend( + mlAgent, + params, + listener, + memory, + sessionId, + functionCalling, + frontendTools, + backendToolsMap, + toolSpecMap + ); + }, e -> { + // Even if MCP tools fail, continue with base backend tools + + Map backendToolsMap = new HashMap<>(); + Map toolSpecMap = new HashMap<>(); + createTools(toolFactories, params, backendToolSpecs, backendToolsMap, toolSpecMap, mlAgent); + + processUnifiedToolsWithBackend( + mlAgent, + params, + listener, + memory, + sessionId, + functionCalling, + frontendTools, + backendToolsMap, + toolSpecMap + ); + })); + } + + /** + * Process unified tools with both frontend and backend tools ready + */ + private void processUnifiedToolsWithBackend( + MLAgent mlAgent, + Map params, + ActionListener listener, + Memory memory, + String sessionId, + FunctionCalling functionCalling, + List> frontendTools, + Map backendToolsMap, + Map toolSpecMap + ) { + // Create unified tool map by adding frontend tools as Tool objects + Map unifiedToolsMap = new HashMap<>(backendToolsMap); + unifiedToolsMap.putAll(wrapFrontendToolsAsToolObjects(frontendTools)); + + // Store backend tool names so streaming handler can filter them out from AG-UI events + if (!backendToolsMap.isEmpty()) { + List backendToolNames = new ArrayList<>(backendToolsMap.keySet()); + params.put(AGUI_PARAM_BACKEND_TOOL_NAMES, gson.toJson(backendToolNames)); + } + + // Call runReAct with unified tools - both frontend and backend tools will be visible to LLM + // Pass backendToolsMap so runReAct can distinguish between frontend and backend tools + runReAct( + mlAgent.getLlm(), + unifiedToolsMap, + toolSpecMap, + params, + memory, + sessionId, + mlAgent.getTenantId(), + listener, + functionCalling, + backendToolsMap + ); + } + + /** + * Process AG-UI tool call results from frontend execution + */ + private void processAGUIToolResults( + MLAgent mlAgent, + Map params, + ActionListener listener, + Memory memory, + String sessionId, + FunctionCalling functionCalling, + String aguiToolCallResults + ) { + try { + + Type listType = new TypeToken>>() { + }.getType(); + List> toolResults = gson.fromJson(aguiToolCallResults, listType); + + if (functionCalling != null && !toolResults.isEmpty()) { + List> formattedResults = new ArrayList<>(); + for (Map result : toolResults) { + Map formattedResult = new HashMap<>(); + formattedResult.put(TOOL_CALL_ID, result.get("tool_call_id")); + formattedResult.put(TOOL_RESULT, Map.of("text", result.get("content"))); + formattedResults.add(formattedResult); + } + + List llmMessages = functionCalling.supply(formattedResults); + + if (!llmMessages.isEmpty()) { + // Build interactions list: assistant message with tool_calls FIRST, then tool results + List interactions = new ArrayList<>(); + + String assistantToolCallMessagesJson = params.get(AGUI_PARAM_ASSISTANT_TOOL_CALL_MESSAGES); + if (assistantToolCallMessagesJson != null && !assistantToolCallMessagesJson.isEmpty()) { + Type listType2 = new TypeToken>() { + }.getType(); + List assistantMessages = gson.fromJson(assistantToolCallMessagesJson, listType2); + interactions.addAll(assistantMessages); + } + + for (LLMMessage llmMessage : llmMessages) { + interactions.add(llmMessage.getResponse()); + } + + Map updatedParams = new HashMap<>(params); + if (!interactions.isEmpty()) { + String interactionsValue = ", " + String.join(", ", interactions); + updatedParams.put(INTERACTIONS, interactionsValue); + } + + String aguiTools = params.get(AGUI_PARAM_TOOLS); + List> frontendTools = parseFrontendTools(aguiTools); + + processUnifiedTools(mlAgent, updatedParams, listener, memory, sessionId, functionCalling, frontendTools); + } else { + listener.onFailure(new RuntimeException("No LLM messages generated from tool results")); + } + } else { + listener.onFailure(new RuntimeException("No function calling interface or empty tool results")); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java index beb5e60f53..4f8b4219aa 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.returnFinalResponse; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -16,6 +17,9 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agui.BaseEvent; +import org.opensearch.ml.common.agui.RunFinishedEvent; +import org.opensearch.ml.common.agui.ToolCallResultEvent; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; import org.opensearch.ml.common.output.model.ModelTensor; @@ -34,7 +38,7 @@ @Log4j2 public class StreamingWrapper { private final TransportChannel channel; - private final boolean isStreaming; + private boolean isStreaming; private Client client; public StreamingWrapper(TransportChannel channel, org.opensearch.transport.client.Client client) { @@ -122,6 +126,31 @@ public void sendToolResponse(String toolOutput, String sessionId, String parentI } } + public void sendBackendToolResult(String toolCallId, String toolResult, String sessionId, String parentInteractionId) { + try { + BaseEvent toolCallResultEvent = new ToolCallResultEvent("msg_" + System.currentTimeMillis(), toolCallId, toolResult); + MLTaskResponse toolChunk = createStreamChunk(toolCallResultEvent.toJsonString(), sessionId, parentInteractionId, false); + channel.sendResponseBatch(toolChunk); + } catch (Exception e) { + log.error("Failed to send backend tool AGUI events for toolCallId '{}': {}", toolCallId, e.getMessage()); + sendToolResponse(toolResult, sessionId, parentInteractionId); + } + } + + public void sendRunFinishedAndCloseStream(String sessionId, String parentInteractionId) { + + BaseEvent runFinishedEvent = new RunFinishedEvent(sessionId, parentInteractionId, null); + List modelTensors = new ArrayList<>(); + Map dataMap = Map.of("content", runFinishedEvent.toJsonString(), "is_last", true); + + modelTensors.add(ModelTensor.builder().name("response").dataAsMap(dataMap).build()); + ModelTensorOutput output = ModelTensorOutput + .builder() + .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(modelTensors).build())) + .build(); + channel.sendResponseBatch(new MLTaskResponse(output)); + } + private MLTaskResponse createStreamChunk(String toolOutput, String sessionId, String parentInteractionId, boolean isLast) { List tensors = Arrays .asList( diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BaseStreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BaseStreamingHandler.java index 73a4f740ed..1b37c73e19 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BaseStreamingHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BaseStreamingHandler.java @@ -10,11 +10,15 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import org.opensearch.ml.common.agui.BaseEvent; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; +import lombok.extern.log4j.Log4j2; + +@Log4j2 public abstract class BaseStreamingHandler implements StreamingHandler { protected void sendContentResponse(String content, boolean isLast, StreamPredictActionListener actionListener) { @@ -35,4 +39,21 @@ protected void sendCompletionResponse(AtomicBoolean isStreamClosed, StreamPredic sendContentResponse("", true, actionListener); } } + + protected void sendAGUIEvent(BaseEvent event, boolean isLast, StreamPredictActionListener actionListener) { + if (event == null) { + return; + } + + List modelTensors = new ArrayList<>(); + Map dataMap = Map.of("content", event.toJsonString(), "is_last", isLast); + + modelTensors.add(ModelTensor.builder().name("response").dataAsMap(dataMap).build()); + ModelTensorOutput output = ModelTensorOutput + .builder() + .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(modelTensors).build())) + .build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + actionListener.onStreamResponse(response, isLast); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java index 0ec9cce537..8f041213d9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java @@ -21,6 +21,7 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.agui.*; import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.output.model.ModelTensor; @@ -66,8 +67,13 @@ public class BedrockStreamingHandler extends BaseStreamingHandler { private final SdkAsyncHttpClient httpClient; private final AwsConnector connector; + private final boolean isAGUIAgent; private static final String STOP_REASON_TOOL_USE = "StopReason=tool_use"; + // AG-UI message state for this LLM response (per-request scope) + private String messageId; + private boolean textMessageStarted = false; + private enum StreamState { STREAMING_CONTENT, TOOL_CALL_DETECTED, @@ -77,8 +83,18 @@ private enum StreamState { } public BedrockStreamingHandler(SdkAsyncHttpClient httpClient, AwsConnector connector) { + this(httpClient, connector, null); + } + + public BedrockStreamingHandler(SdkAsyncHttpClient httpClient, AwsConnector connector, Map parameters) { this.httpClient = httpClient; this.connector = connector; + + this.isAGUIAgent = parameters != null && (parameters.containsKey("agent_type") && parameters.get("agent_type").equals("ag_ui")); + + if (isAGUIAgent) { + log.debug("BedrockStreamingHandler: Detected AG-UI agent"); + } } @Override @@ -96,6 +112,13 @@ public void startStream( StringBuilder toolInputAccumulator = new StringBuilder(); AtomicReference currentState = new AtomicReference<>(StreamState.STREAMING_CONTENT); + // Initialize AG-UI message state for this LLM response + if (isAGUIAgent) { + messageId = "msg_" + System.currentTimeMillis() + "_" + System.nanoTime(); + textMessageStarted = false; + log.debug("AG-UI: Initialized messageId for LLM response: {}", messageId); + } + // Build Bedrock client BedrockRuntimeAsyncClient bedrockClient = buildBedrockRuntimeAsyncClient(); @@ -157,8 +180,35 @@ private boolean isClientError(Throwable error) { private ConverseStreamRequest buildConverseStreamRequest(String payload, Map parameters) { try { + log.debug("AG-UI: Building Bedrock request from payload: {}", payload); ObjectMapper mapper = new ObjectMapper(); JsonNode payloadJson = mapper.readTree(payload); + + // Log the messages array for debugging + if (payloadJson.has("messages")) { + JsonNode messagesArray = payloadJson.get("messages"); + log.debug("AG-UI: Messages array in payload: {}", messagesArray); + + // Check for consecutive messages with the same role (Bedrock doesn't allow this) + String previousRole = null; + for (int i = 0; i < messagesArray.size(); i++) { + JsonNode msg = messagesArray.get(i); + String currentRole = msg.has("role") ? msg.get("role").asText() : "unknown"; + if (previousRole != null && previousRole.equals(currentRole)) { + log + .warn( + "AG-UI: Found consecutive messages with same role '{}' at index {} and {}. Bedrock requires alternating roles!", + currentRole, + i - 1, + i + ); + } + previousRole = currentRole; + } + } else { + log.warn("AG-UI: No messages array found in payload!"); + } + return ConverseStreamRequest .builder() .modelId(parameters.get("model")) @@ -190,9 +240,44 @@ private void handleStreamEvent( if (isToolUseDetected(event)) { currentState.set(StreamState.TOOL_CALL_DETECTED); extractToolInfo(event, toolName, toolUseId); + + if (isAGUIAgent) { + // end current text message before sending tool events + BaseEvent textMessageEndEvent = new TextMessageEndEvent(messageId); + sendAGUIEvent(textMessageEndEvent, false, listener); + + BaseEvent toolCallStartEvent = new ToolCallStartEvent(toolUseId.get(), toolName.get(), messageId); + sendAGUIEvent(toolCallStartEvent, false, listener); + log.debug("AG-UI: Sent TOOL_CALL_START for messageId: {} and toolUseId: {}", messageId, toolUseId); + } } else if (isContentDelta(event)) { - sendContentResponse(getTextContent(event), false, listener); + String content = getTextContent(event); + + if (isAGUIAgent) { + if (!textMessageStarted) { + textMessageStarted = true; + BaseEvent textMessageStartEvent = new TextMessageStartEvent(toolName.get(), messageId); + sendAGUIEvent(textMessageStartEvent, false, listener); + log.debug("AG-UI: Sent TEXT_MESSAGE_START for messageId: {}", messageId); + } + + BaseEvent textMessageContentEvent = new TextMessageContentEvent(content, messageId); + sendAGUIEvent(textMessageContentEvent, false, listener); + log.debug("AG-UI: Sent TEXT_MESSAGE_CONTENT for messageId: {}", messageId); + } else { + sendContentResponse(content, false, listener); + } } else if (isStreamComplete(event)) { + if (isAGUIAgent && textMessageStarted) { + BaseEvent textMessageEndEvent = new TextMessageEndEvent(messageId); + sendAGUIEvent(textMessageEndEvent, false, listener); + log.debug("AG-UI: Sent TEXT_MESSAGE_END for messageId: {}", messageId); + + BaseEvent runFinishedEvent = new RunFinishedEvent("", "", null); + sendAGUIEvent(runFinishedEvent, true, listener); + log.debug("RestMLExecuteStreamAction: Added RUN_FINISHED event - ReAct loop completed"); + } + currentState.set(StreamState.COMPLETED); sendCompletionResponse(isStreamClosed, listener); } @@ -201,31 +286,52 @@ private void handleStreamEvent( case TOOL_CALL_DETECTED: if (isToolInputDelta(event)) { currentState.set(StreamState.ACCUMULATING_TOOL_INPUT); - accumulateToolInput(getToolInputFragment(event), toolInput, toolInputAccumulator); + String inputFragment = getToolInputFragment(event); + accumulateToolInput(inputFragment, toolInput, toolInputAccumulator); + + if (isAGUIAgent) { + BaseEvent toolCallArgsEvent = new ToolCallArgsEvent(toolUseId.get(), inputFragment); + sendAGUIEvent(toolCallArgsEvent, false, listener); + log.debug("AG-UI: Sent TOOL_CALL_ARGS for messageId: {}", messageId); + } else { + sendContentResponse(inputFragment, false, listener); + } } break; case ACCUMULATING_TOOL_INPUT: if (isToolInputDelta(event)) { - accumulateToolInput(getToolInputFragment(event), toolInput, toolInputAccumulator); + String inputFragment = getToolInputFragment(event); + accumulateToolInput(inputFragment, toolInput, toolInputAccumulator); + + if (isAGUIAgent) { + BaseEvent toolCallArgsEvent = new ToolCallArgsEvent(toolUseId.get(), inputFragment); + sendAGUIEvent(toolCallArgsEvent, false, listener); + log.debug("AG-UI: Sent TOOL_CALL_ARGS for messageId: {}", messageId); + } else { + sendContentResponse(inputFragment, false, listener); + } } else if (isToolInputComplete(event)) { + if (isAGUIAgent) { + BaseEvent toolCallEndEvent = new ToolCallEndEvent(toolUseId.get()); + sendAGUIEvent(toolCallEndEvent, false, listener); + log.debug("AG-UI: Sent TOOL_CALL_END event for tool '{}' after args completed", toolName.get()); + } + currentState.set(StreamState.WAITING_FOR_TOOL_RESULT); listener.onResponse(createToolUseResponse(toolName, toolInput, toolUseId)); } break; case WAITING_FOR_TOOL_RESULT: - // Don't close stream - wait for tool execution log.debug("Waiting for tool result - keeping stream open"); break; case COMPLETED: - // Stream already completed break; } } - // TODO: refactor the event type checker methods private void extractToolInfo(ConverseStreamOutput event, AtomicReference toolName, AtomicReference toolUseId) { ContentBlockStartEvent startEvent = (ContentBlockStartEvent) event; if (startEvent.start() != null && startEvent.start().toolUse() != null) { @@ -395,10 +501,31 @@ private List parseMessages(JsonNode messagesArray) { private Message buildMessage(JsonNode messageItem) { String role = messageItem.has("role") && messageItem.get("role") != null ? messageItem.get("role").asText() : "assistant"; + // Handle AG-UI tool result messages + if (isAGUIAgent && "tool".equals(role)) { + return buildToolResultMessage(messageItem); + } + List contentBlocks = buildContentBlocks(messageItem.get("content")); return Message.builder().role(role).content(contentBlocks).build(); } + private Message buildToolResultMessage(JsonNode toolMessage) { + String toolCallId = toolMessage.has("toolCallId") ? toolMessage.get("toolCallId").asText() : ""; + String content = toolMessage.has("content") ? toolMessage.get("content").asText() : ""; + + ContentBlock toolResultBlock = ContentBlock + .builder() + .toolResult( + ToolResultBlock.builder().toolUseId(toolCallId).content(ToolResultContentBlock.builder().text(content).build()).build() + ) + .build(); + + log.debug("AG-UI: Converted tool message to Bedrock format - toolUseId: {}, content length: {}", toolCallId, content.length()); + + return Message.builder().role("user").content(List.of(toolResultBlock)).build(); + } + private List buildContentBlocks(JsonNode contentArray) { List blocks = new ArrayList<>(); if (contentArray != null && contentArray.isArray()) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java index 15078dfccb..9105dc6f8f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java @@ -5,6 +5,8 @@ package org.opensearch.ml.engine.algorithms.remote.streaming; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; @@ -16,6 +18,10 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import org.opensearch.ml.common.agui.BaseEvent; +import org.opensearch.ml.common.agui.ToolCallArgsEvent; +import org.opensearch.ml.common.agui.ToolCallEndEvent; +import org.opensearch.ml.common.agui.ToolCallStartEvent; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorClientConfig; import org.opensearch.ml.common.exception.MLException; @@ -43,16 +49,25 @@ public class HttpStreamingHandler extends BaseStreamingHandler { private final Connector connector; private OkHttpClient okHttpClient; private String llmInterface; + private Map parameters; public HttpStreamingHandler(String llmInterface, Connector connector, ConnectorClientConfig connectorClientConfig) { + this(llmInterface, connector, connectorClientConfig, null); + } + + public HttpStreamingHandler( + String llmInterface, + Connector connector, + ConnectorClientConfig connectorClientConfig, + Map parameters + ) { this.connector = connector; this.llmInterface = llmInterface; + this.parameters = parameters; - // Get connector client configuration Duration connectionTimeout = Duration.ofSeconds(connectorClientConfig.getConnectionTimeout()); Duration readTimeout = Duration.ofSeconds(connectorClientConfig.getReadTimeout()); - // Initialize OkHttp client for SSE try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { this.okHttpClient = new OkHttpClient.Builder() @@ -76,7 +91,7 @@ public void startStream( ) { try { log.info("Creating SSE connection for streaming request"); - EventSourceListener listener = new HTTPEventSourceListener(actionListener, llmInterface); + EventSourceListener listener = new HTTPEventSourceListener(actionListener, llmInterface, parameters); Request request = ConnectorUtils.buildOKHttpStreamingRequest(action, connector, parameters, payload); AccessController.doPrivileged((PrivilegedExceptionAction) () -> { @@ -99,6 +114,7 @@ public void handleError(Throwable error, StreamPredictActionListener streamActionListener; private final String llmInterface; + private final boolean isAGUIAgent; private AtomicBoolean isStreamClosed; private boolean functionCallInProgress = false; private boolean agentExecutionInProgress = false; @@ -106,10 +122,21 @@ public final class HTTPEventSourceListener extends EventSourceListener { private String accumulatedToolName = null; private String accumulatedArguments = ""; - public HTTPEventSourceListener(StreamPredictActionListener streamActionListener, String llmInterface) { + public HTTPEventSourceListener( + StreamPredictActionListener streamActionListener, + String llmInterface, + Map parameters + ) { this.streamActionListener = streamActionListener; this.llmInterface = llmInterface; this.isStreamClosed = new AtomicBoolean(false); + + this.isAGUIAgent = parameters != null + && (parameters.containsKey(AGUI_PARAM_THREAD_ID) || parameters.containsKey(AGUI_PARAM_RUN_ID)); + + if (isAGUIAgent) { + log.debug("HttpStreamingHandler: Detected AG-UI agent"); + } } /*** @@ -203,7 +230,6 @@ private void handleDoneEvent() { } private void processStreamChunk(Map dataMap) { - // Handle stop finish reason String finishReason = extractPath(dataMap, "$.choices[0].finish_reason"); if ("stop".equals(finishReason)) { agentExecutionInProgress = false; @@ -211,20 +237,21 @@ private void processStreamChunk(Map dataMap) { return; } - // Process content String content = extractPath(dataMap, "$.choices[0].delta.content"); if (content != null && !content.isEmpty()) { sendContentResponse(content, false, streamActionListener); } - // Process tool call List toolCalls = extractPath(dataMap, "$.choices[0].delta.tool_calls"); if (toolCalls != null) { - accumulateFunctionCall(toolCalls); - sendContentResponse(StringUtils.toJson(toolCalls), false, streamActionListener); + if (isAGUIAgent) { + processAGUIToolCalls(toolCalls); + } else { + accumulateFunctionCall(toolCalls); + sendContentResponse(StringUtils.toJson(toolCalls), false, streamActionListener); + } } - // Handle tool_calls finish reason if ("tool_calls".equals(finishReason) && functionCallInProgress) { completeToolCall(); } @@ -240,15 +267,18 @@ private T extractPath(Map dataMap, String path) { private void completeToolCall() { agentExecutionInProgress = true; - String completeFunctionCall = buildCompleteFunctionCallResponse(); - // Send to client and agent - sendContentResponse(completeFunctionCall, false, streamActionListener); - Map response = gson.fromJson(completeFunctionCall, Map.class); - ModelTensorOutput output = createModelTensorOutput(response); - streamActionListener.onResponse(new MLTaskResponse(output)); + if (isAGUIAgent) { + BaseEvent toolCallEndEvent = new ToolCallEndEvent(accumulatedToolCallId); + sendAGUIEvent(toolCallEndEvent, true, streamActionListener); + } else { + String completeFunctionCall = buildCompleteFunctionCallResponse(); + sendContentResponse(completeFunctionCall, false, streamActionListener); + Map response = gson.fromJson(completeFunctionCall, Map.class); + ModelTensorOutput output = createModelTensorOutput(response); + streamActionListener.onResponse(new MLTaskResponse(output)); + } - // Reset state functionCallInProgress = false; } @@ -268,12 +298,46 @@ private ModelTensorOutput createModelTensorOutput(Map responseDa return ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); } + private void processAGUIToolCalls(List toolCalls) { + functionCallInProgress = true; + + for (Object toolCall : toolCalls) { + Map tcMap = (Map) toolCall; + + if (tcMap.containsKey("id")) { + String toolCallId = (String) tcMap.get("id"); + if (accumulatedToolCallId == null) { + accumulatedToolCallId = toolCallId; + } + } + + if (tcMap.containsKey("function")) { + Map func = (Map) tcMap.get("function"); + + if (func.containsKey("name")) { + String toolName = (String) func.get("name"); + if (accumulatedToolName == null) { + accumulatedToolName = toolName; + + BaseEvent startEvent = new ToolCallStartEvent(accumulatedToolCallId, toolName, null); + sendAGUIEvent(startEvent, false, streamActionListener); + } + } + + if (func.containsKey("arguments")) { + String argsDelta = (String) func.get("arguments"); + BaseEvent argsEvent = new ToolCallArgsEvent(accumulatedToolCallId, argsDelta); + sendAGUIEvent(argsEvent, false, streamActionListener); + } + } + } + } + private void accumulateFunctionCall(List toolCalls) { functionCallInProgress = true; for (Object toolCall : toolCalls) { Map tcMap = (Map) toolCall; - // Extract ID and name from first chunk if (tcMap.containsKey("id")) { accumulatedToolCallId = (String) tcMap.get("id"); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java index 4a07b79cab..e7fb80fac2 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java @@ -9,6 +9,9 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_BACKEND_TOOL_NAMES; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID; +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.plugin.MachineLearningPlugin.STREAM_EXECUTE_THREAD_POOL; import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG; @@ -20,6 +23,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; @@ -47,6 +51,8 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.agui.*; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; @@ -73,6 +79,10 @@ import org.opensearch.transport.client.node.NodeClient; import org.opensearch.transport.stream.StreamTransportResponse; +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; + import lombok.extern.log4j.Log4j2; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -161,6 +171,33 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client try { BytesReference completeContent = combineChunks(chunks); MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent); + boolean isAGUI = isAGUIAgent(mlExecuteTaskRequest); + + // Send RUN_STARTED event immediately for AG-UI agents (ReAct cycle begins) + if (isAGUI) { + String threadId = extractThreadId(mlExecuteTaskRequest); + String runId = extractRunId(mlExecuteTaskRequest); + + BaseEvent runStartedEvent = new RunStartedEvent(threadId, runId); + HttpChunk startChunk = createHttpChunk("data: " + runStartedEvent.toJsonString() + "\n\n", false); + channel.sendChunk(startChunk); + log.debug("RestMLExecuteStreamAction: Sent RUN_STARTED event - threadId={}, runId={}", threadId, runId); + } + + // Extract backend tool names from agent configuration and add to request for AG-UI filtering + List backendToolNames = extractBackendToolNamesFromAgent(agent); + if (isAGUI && !backendToolNames.isEmpty()) { + // Add backend tool names to request parameters so they're available during streaming + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) ((AgentMLInput) mlExecuteTaskRequest + .getInput()).getInputDataset(); + inputDataSet.getParameters().put(AGUI_PARAM_BACKEND_TOOL_NAMES, new Gson().toJson(backendToolNames)); + log + .info( + "AG-UI: Added {} backend tool names to request for streaming filter: {}", + backendToolNames.size(), + backendToolNames + ); + } final CompletableFuture future = new CompletableFuture<>(); StreamTransportResponseHandler handler = new StreamTransportResponseHandler() { @@ -170,7 +207,7 @@ public void handleStreamResponse(StreamTransportResponse streamR MLTaskResponse response = streamResponse.nextResponse(); if (response != null) { - HttpChunk responseChunk = convertToHttpChunk(response); + HttpChunk responseChunk = convertToHttpChunk(response, isAGUI); channel.sendChunk(responseChunk); // Recursively handle the next response @@ -319,61 +356,126 @@ MLExecuteTaskRequest getRequest(String agentId, RestRequest request, BytesRefere } String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); FunctionName functionName = FunctionName.AGENT; - Input input = MLInput.parse(parser, functionName.name()); - AgentMLInput agentInput = (AgentMLInput) input; - agentInput.setAgentId(agentId); - agentInput.setTenantId(tenantId); - agentInput.setIsAsync(async); - RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentInput.getInputDataset(); + + // Check if this is AG-UI input format + String requestBodyJson = content.utf8ToString(); + Input input; + if (AGUIInputConverter.isAGUIInput(requestBodyJson)) { + log.debug("AG-UI: Detected AG-UI input format for streaming agent: {}", agentId); + input = AGUIInputConverter.convertFromAGUIInput(requestBodyJson, agentId, tenantId, async); + } else { + input = MLInput.parse(parser, functionName.name()); + AgentMLInput agentInput = (AgentMLInput) input; + agentInput.setAgentId(agentId); + agentInput.setTenantId(tenantId); + agentInput.setIsAsync(async); + } + + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) ((AgentMLInput) input).getInputDataset(); inputDataSet.getParameters().put("stream", String.valueOf(true)); return new MLExecuteTaskRequest(functionName, input); } - private HttpChunk convertToHttpChunk(MLTaskResponse response) throws IOException { - String sseData; + private boolean isAGUIAgent(MLExecuteTaskRequest request) { + if (request.getInput() instanceof AgentMLInput agentInput) { + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentInput.getInputDataset(); + + // Check if this request came from AG-UI by looking for AG-UI specific parameters + return inputDataSet.getParameters().containsKey(AGUI_PARAM_THREAD_ID) + || inputDataSet.getParameters().containsKey(AGUI_PARAM_RUN_ID); + } + return false; + } + + private List extractBackendToolNamesFromAgent(MLAgent agent) { + List backendToolNames = new ArrayList<>(); + if (agent != null && agent.getTools() != null) { + for (MLToolSpec toolSpec : agent.getTools()) { + if (toolSpec.getName() != null) { + backendToolNames.add(toolSpec.getName()); + } + } + } + return backendToolNames; + } + + private String extractThreadId(MLExecuteTaskRequest request) { + if (request.getInput() instanceof AgentMLInput) { + AgentMLInput agentInput = (AgentMLInput) request.getInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentInput.getInputDataset(); + String threadId = inputDataSet.getParameters().get(AGUI_PARAM_THREAD_ID); + return threadId != null ? threadId : "thread_" + System.currentTimeMillis(); + } + return "thread_" + System.currentTimeMillis(); + } + + private String extractRunId(MLExecuteTaskRequest request) { + if (request.getInput() instanceof AgentMLInput) { + AgentMLInput agentInput = (AgentMLInput) request.getInput(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentInput.getInputDataset(); + String runId = inputDataSet.getParameters().get(AGUI_PARAM_RUN_ID); + return runId != null ? runId : "run_" + System.currentTimeMillis(); + } + return "run_" + System.currentTimeMillis(); + } + + private HttpChunk convertToHttpChunk(MLTaskResponse response, boolean isAGUIAgent) throws IOException { + String memoryId = ""; + String parentInteractionId = ""; + String content = ""; boolean isLast = false; try { Map dataMap = extractDataMap(response); if (dataMap.containsKey("error")) { - // Error response - String errorMessage = (String) dataMap.get("error"); - sseData = String.format("data: {\"error\": \"%s\"}\n\n", errorMessage.replace("\"", "\\\"").replace("\n", "\\n")); + // Error response - handle errors + content = (String) dataMap.get("error"); isLast = true; } else { // TODO: refactor to handle other types of agents // Regular response - extract values and build proper structure - String memoryId = extractTensorResult(response, "memory_id"); - String parentInteractionId = extractTensorResult(response, "parent_interaction_id"); - String content = dataMap.containsKey("content") ? (String) dataMap.get("content") : ""; - isLast = dataMap.containsKey("is_last") ? Boolean.TRUE.equals(dataMap.get("is_last")) : false; - boolean finalIsLast = isLast; - - List orderedTensors = List - .of( - ModelTensor.builder().name("memory_id").result(memoryId).build(), - ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(), - ModelTensor.builder().name("response").dataAsMap(new LinkedHashMap() { - { - put("content", content); - put("is_last", finalIsLast); - } - }).build() - ); - - ModelTensors tensors = ModelTensors.builder().mlModelTensors(orderedTensors).build(); - ModelTensorOutput tensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); - - XContentBuilder builder = XContentFactory.jsonBuilder(); - tensorOutput.toXContent(builder, ToXContent.EMPTY_PARAMS); - sseData = "data: " + builder.toString() + "\n\n"; + memoryId = extractTensorResult(response, "memory_id"); + parentInteractionId = extractTensorResult(response, "parent_interaction_id"); + content = dataMap.containsKey("content") ? (String) dataMap.get("content") : ""; + isLast = dataMap.containsKey("is_last") && Boolean.TRUE.equals(dataMap.get("is_last")); } } catch (Exception e) { log.error("Failed to process response", e); - sseData = "data: {\"error\": \"Processing failed\"}\n\n"; + content = "Processing failed"; isLast = true; } + + String finalContent = content; + boolean finalIsLast = isLast; + + // If this is an AG-UI agent, convert to AG-UI event format + if (isAGUIAgent) { + return convertToAGUIEvent(content, isLast); + } + + // Create ordered tensors + List orderedTensors = List + .of( + ModelTensor.builder().name("memory_id").result(memoryId).build(), + ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(), + ModelTensor.builder().name("response").dataAsMap(new LinkedHashMap() { + { + put("content", finalContent); + put("is_last", finalIsLast); + } + }).build() + ); + + ModelTensors tensors = ModelTensors.builder().mlModelTensors(orderedTensors).build(); + + ModelTensorOutput tensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + tensorOutput.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonData = builder.toString(); + + String sseData = "data: " + jsonData + "\n\n"; return createHttpChunk(sseData, isLast); } @@ -407,6 +509,37 @@ private String extractTensorResult(MLTaskResponse response, String tensorName) { return Map.of(); } + private HttpChunk convertToAGUIEvent(String content, boolean isLast) { + log + .debug( + "RestMLExecuteStreamAction: convertToAGUIEvent() called - contentLength={}, isLast={}", + content != null ? content.length() : "null", + isLast + ); + + StringBuilder sseResponse = new StringBuilder(); + + if (content != null && !content.isEmpty()) { + log.debug("RestMLExecuteStreamAction: Processing content: '{}'", content); + + try { + JsonElement element = JsonParser.parseString(content); + sseResponse.append("data: ").append(element).append("\n\n"); + log.debug("RestMLExecuteStreamAction: Processing json element: '{}'", element); + } catch (Exception e) { + log.error("Failed to process AG-UI events chunk content {}", content, e); + BaseEvent runErrorEvent = new RunErrorEvent(e.getMessage(), null); + sseResponse.append("data: ").append(runErrorEvent.toJsonString()).append("\n\n"); + } + } else { + log.warn("Received null or empty AG-UI content chunk"); + } + + String finalSse = sseResponse.toString(); + log.debug("RestMLExecuteStreamAction: Returning chunk - length={}", finalSse.length()); + return createHttpChunk(finalSse, isLast); + } + @VisibleForTesting BytesReference combineChunks(List chunks) { try { From 01cfec361a90f6ba33e31800df96f474d274b1c6 Mon Sep 17 00:00:00 2001 From: Jiaping Zeng Date: Tue, 18 Nov 2025 21:08:39 -0800 Subject: [PATCH 5/6] add feature flag Signed-off-by: Jiaping Zeng --- .../ml/common/settings/MLCommonsSettings.java | 6 ++++++ .../ml/common/settings/MLFeatureEnabledSetting.java | 13 +++++++++++++ .../opensearch/ml/plugin/MachineLearningPlugin.java | 3 ++- .../org/opensearch/ml/rest/RestMLExecuteAction.java | 4 ++++ .../ml/rest/RestMLExecuteStreamAction.java | 4 ++++ 5 files changed, 29 insertions(+), 1 deletion(-) diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java index c139ea4b68..669a138d17 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java @@ -483,4 +483,10 @@ private MLCommonsSettings() {} // Feature flag for streaming feature public static final Setting ML_COMMONS_STREAM_ENABLED = Setting .boolSetting(ML_PLUGIN_SETTING_PREFIX + "stream_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // Feature flag for AG-UI agent support + public static final Setting ML_COMMONS_AG_UI_ENABLED = Setting + .boolSetting(ML_PLUGIN_SETTING_PREFIX + "ag_ui_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final String ML_COMMONS_AG_UI_DISABLED_MESSAGE = + "The AG-UI agent feature is not enabled. To enable, please update the setting " + ML_COMMONS_AG_UI_ENABLED.getKey(); } diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java index 2642a12db6..cad12c1c54 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java @@ -22,6 +22,7 @@ import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_ENABLED; import java.util.ArrayList; import java.util.List; @@ -63,6 +64,8 @@ public class MLFeatureEnabledSetting { private volatile Boolean isStreamEnabled; + private volatile Boolean isAGUIEnabled; + private final List listeners = new ArrayList<>(); public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { @@ -83,6 +86,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) isAgenticMemoryEnabled = ML_COMMONS_AGENTIC_MEMORY_ENABLED.get(settings); isIndexInsightEnabled = ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED.get(settings); isStreamEnabled = ML_COMMONS_STREAM_ENABLED.get(settings); + isAGUIEnabled = ML_COMMONS_AG_UI_ENABLED.get(settings); clusterService .getClusterSettings() @@ -112,6 +116,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED, it -> isIndexInsightEnabled = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AG_UI_ENABLED, it -> isAGUIEnabled = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED, it -> { isStaticMetricCollectionEnabled = it; for (SettingsChangeListener listener : listeners) { @@ -245,4 +250,12 @@ public boolean isIndexInsightEnabled() { public boolean isStreamEnabled() { return isStreamEnabled; } + + /** + * Whether the AG-UI agent feature is enabled. If disabled, AG-UI agents will be blocked. + * @return whether the AG-UI agent feature is enabled. + */ + public boolean isAGUIEnabled() { + return isAGUIEnabled; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 62de34961e..9e22cac5ec 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -1364,7 +1364,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED, MLCommonsSettings.REMOTE_METADATA_GLOBAL_TENANT_ID, MLCommonsSettings.REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL, - MLCommonsSettings.ML_COMMONS_STREAM_ENABLED + MLCommonsSettings.ML_COMMONS_STREAM_ENABLED, + MLCommonsSettings.ML_COMMONS_AG_UI_ENABLED ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java index d6b57d7487..c21e0f9170 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_DISABLED_MESSAGE; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TOOL_NAME; @@ -128,6 +129,9 @@ MLExecuteTaskRequest getRequest(RestRequest request) throws IOException { String requestBodyJson = request.contentOrSourceParam().v2().utf8ToString(); if (AGUIInputConverter.isAGUIInput(requestBodyJson)) { + if (!mlFeatureEnabledSetting.isAGUIEnabled()) { + throw new IllegalStateException(ML_COMMONS_AG_UI_DISABLED_MESSAGE); + } throw new IllegalArgumentException( "AG-UI agents require streaming execution. " + "Please use the streaming endpoint: POST /_plugins/_ml/agents/" diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java index e7fb80fac2..3230b48369 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java @@ -16,6 +16,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.STREAM_EXECUTE_THREAD_POOL; import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.STREAM_DISABLED_ERR_MSG; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; import static org.opensearch.ml.utils.RestActionUtils.isAsync; import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; @@ -361,6 +362,9 @@ MLExecuteTaskRequest getRequest(String agentId, RestRequest request, BytesRefere String requestBodyJson = content.utf8ToString(); Input input; if (AGUIInputConverter.isAGUIInput(requestBodyJson)) { + if (!mlFeatureEnabledSetting.isAGUIEnabled()) { + throw new IllegalStateException(ML_COMMONS_AG_UI_DISABLED_MESSAGE); + } log.debug("AG-UI: Detected AG-UI input format for streaming agent: {}", agentId); input = AGUIInputConverter.convertFromAGUIInput(requestBodyJson, agentId, tenantId, async); } else { From 3bddf55ce6cb835577a671f8132c73daf54250fd Mon Sep 17 00:00:00 2001 From: Jiaping Zeng Date: Wed, 19 Nov 2025 01:22:18 -0800 Subject: [PATCH 6/6] address comments Signed-off-by: Jiaping Zeng --- .../ml/common/agui/AGUIInputConverter.java | 2 +- .../settings/MLFeatureEnabledSetting.java | 2 +- .../engine/algorithms/agent/AgentUtils.java | 2 + .../algorithms/agent/MLAGUIAgentRunner.java | 4 +- .../ml/rest/RestMLExecuteAction.java | 9 +++- .../ml/rest/RestMLExecuteStreamAction.java | 41 ++++++++++++++----- 6 files changed, 45 insertions(+), 15 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/agui/AGUIInputConverter.java b/common/src/main/java/org/opensearch/ml/common/agui/AGUIInputConverter.java index 8ae2cde044..5bd447a47e 100644 --- a/common/src/main/java/org/opensearch/ml/common/agui/AGUIInputConverter.java +++ b/common/src/main/java/org/opensearch/ml/common/agui/AGUIInputConverter.java @@ -72,7 +72,7 @@ public static boolean isAGUIInput(String inputJson) { return true; } catch (Exception e) { - log.debug("Failed to parse input as JSON for AG-UI detection", e); + log.error("Failed to parse input as JSON for AG-UI detection", e); return false; } } diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java index cad12c1c54..d045875b9a 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_ENABLED; @@ -22,7 +23,6 @@ import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED; -import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_ENABLED; import java.util.ArrayList; import java.util.List; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 371123d2e1..c794ffdf88 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -1066,6 +1066,8 @@ public static Map wrapFrontendToolsAsToolObjects(List createMemoryParams( String question, String memoryId, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java index 14efd48de3..80f8ea9a18 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java @@ -37,7 +37,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.agent.MLAgent; -import org.opensearch.ml.common.spi.memory.Memory; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.function_calling.FunctionCalling; @@ -362,6 +362,7 @@ private void processAGUIMessages(Map params, String llmInterface } } catch (Exception e) { log.error("Failed to process AG-UI messages to chat history", e); + throw new IllegalArgumentException("Failed to process AG-UI messages to chat history", e); } } @@ -405,6 +406,7 @@ private void processAGUIContext(Map params) { } catch (Exception e) { log.error("Failed to process AG-UI context", e); + throw new IllegalArgumentException("Failed to process AG-UI context", e); } } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java index c21e0f9170..1c0cbedc3c 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java @@ -8,10 +8,10 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_DISABLED_MESSAGE; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG; -import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TOOL_NAME; @@ -140,6 +140,13 @@ MLExecuteTaskRequest getRequest(RestRequest request) throws IOException { ); } else { input = MLInput.parse(parser, functionName.name()); + + if (!(input instanceof AgentMLInput)) { + throw new IllegalArgumentException( + String.format("Invalid input type. Expected: AgentMLInput, Received: %s", input.getClass().getSimpleName()) + ); + } + ((AgentMLInput) input).setAgentId(agentId); ((AgentMLInput) input).setTenantId(tenantId); ((AgentMLInput) input).setIsAsync(async); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java index 3230b48369..0749ae51da 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java @@ -12,11 +12,11 @@ import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_BACKEND_TOOL_NAMES; import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID; import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.plugin.MachineLearningPlugin.STREAM_EXECUTE_THREAD_POOL; import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.STREAM_DISABLED_ERR_MSG; -import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; import static org.opensearch.ml.utils.RestActionUtils.isAsync; import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; @@ -182,22 +182,41 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client BaseEvent runStartedEvent = new RunStartedEvent(threadId, runId); HttpChunk startChunk = createHttpChunk("data: " + runStartedEvent.toJsonString() + "\n\n", false); channel.sendChunk(startChunk); - log.debug("RestMLExecuteStreamAction: Sent RUN_STARTED event - threadId={}, runId={}", threadId, runId); + log.debug("AG-UI: RestMLExecuteStreamAction: Sent RUN_STARTED event - threadId={}, runId={}", threadId, runId); } // Extract backend tool names from agent configuration and add to request for AG-UI filtering List backendToolNames = extractBackendToolNamesFromAgent(agent); if (isAGUI && !backendToolNames.isEmpty()) { // Add backend tool names to request parameters so they're available during streaming - RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) ((AgentMLInput) mlExecuteTaskRequest - .getInput()).getInputDataset(); - inputDataSet.getParameters().put(AGUI_PARAM_BACKEND_TOOL_NAMES, new Gson().toJson(backendToolNames)); - log - .info( - "AG-UI: Added {} backend tool names to request for streaming filter: {}", - backendToolNames.size(), - backendToolNames - ); + try { + if (!(mlExecuteTaskRequest.getInput() instanceof AgentMLInput)) { + throw new IllegalArgumentException( + "Invalid input type. Expected: AgentMLInput, Received: " + + mlExecuteTaskRequest.getInput().getClass().getSimpleName() + ); + } + AgentMLInput agentInput = (AgentMLInput) mlExecuteTaskRequest.getInput(); + + if (!(agentInput.getInputDataset() instanceof RemoteInferenceInputDataSet)) { + throw new IllegalArgumentException( + "Invalid dataset type. Expected: RemoteInferenceInputDataSet, Received: " + + agentInput.getInputDataset().getClass().getSimpleName() + ); + } + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentInput.getInputDataset(); + + inputDataSet.getParameters().put(AGUI_PARAM_BACKEND_TOOL_NAMES, new Gson().toJson(backendToolNames)); + log + .info( + "AG-UI: Added {} backend tool names to request for streaming filter: {}", + backendToolNames.size(), + backendToolNames + ); + } catch (ClassCastException e) { + log.error("Failed to cast input types for backend tool names extraction", e); + throw new IllegalArgumentException("Invalid input type configuration for AG-UI request", e); + } } final CompletableFuture future = new CompletableFuture<>();