-
Notifications
You must be signed in to change notification settings - Fork 186
AG-UI support in Agent Framework #4347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f820523
44e353e
0351e4e
1caab62
01cfec3
c7722c0
3bddf55
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.ml.common.agui; | ||
|
|
||
| /** | ||
| * Constants for AG-UI implementation. | ||
| * | ||
| * Naming Conventions: | ||
| * AGUI_ROLE_* - Message role identifiers | ||
| * AGUI_PARAM_* - Internal parameter keys | ||
| * AGUI_FIELD_* - External API field names | ||
| * AGUI_EVENT_* - Event type identifiers | ||
| * AGUI_PREFIX_* - ID prefixes for generated identifiers | ||
| */ | ||
| public final class AGUIConstants { | ||
|
|
||
| // ========== Message Roles ========== | ||
|
|
||
| /** Role identifier for assistant messages */ | ||
| public static final String AGUI_ROLE_ASSISTANT = "assistant"; | ||
|
|
||
| /** Role identifier for user messages */ | ||
| public static final String AGUI_ROLE_USER = "user"; | ||
|
|
||
| /** Role identifier for tool result messages */ | ||
| public static final String AGUI_ROLE_TOOL = "tool"; | ||
|
|
||
| // ========== Parameter Keys (Internal) ========== | ||
|
|
||
| /** Parameter key for AG-UI thread identifier */ | ||
| public static final String AGUI_PARAM_THREAD_ID = "agui_thread_id"; | ||
|
|
||
| /** Parameter key for AG-UI run identifier */ | ||
| public static final String AGUI_PARAM_RUN_ID = "agui_run_id"; | ||
|
|
||
| /** Parameter key for AG-UI messages array */ | ||
| public static final String AGUI_PARAM_MESSAGES = "agui_messages"; | ||
|
|
||
| /** Parameter key for AG-UI tools array */ | ||
| public static final String AGUI_PARAM_TOOLS = "agui_tools"; | ||
|
|
||
| /** Parameter key for AG-UI context array */ | ||
| public static final String AGUI_PARAM_CONTEXT = "agui_context"; | ||
|
|
||
| /** Parameter key for AG-UI state object */ | ||
| public static final String AGUI_PARAM_STATE = "agui_state"; | ||
|
|
||
| /** Parameter key for AG-UI forwarded properties */ | ||
| public static final String AGUI_PARAM_FORWARDED_PROPS = "agui_forwarded_props"; | ||
|
|
||
| /** Parameter key for AG-UI tool call results */ | ||
| public static final String AGUI_PARAM_TOOL_CALL_RESULTS = "agui_tool_call_results"; | ||
|
|
||
| /** Parameter key for AG-UI assistant tool call messages */ | ||
| public static final String AGUI_PARAM_ASSISTANT_TOOL_CALL_MESSAGES = "agui_assistant_tool_call_messages"; | ||
|
|
||
| /** Parameter key for backend tool names (used for filtering) */ | ||
| public static final String AGUI_PARAM_BACKEND_TOOL_NAMES = "backend_tool_names"; | ||
|
|
||
| // ========== Field Names (External API) ========== | ||
|
|
||
| /** Field name for thread identifier in AG-UI input */ | ||
| public static final String AGUI_FIELD_THREAD_ID = "threadId"; | ||
|
|
||
| /** Field name for run identifier in AG-UI input */ | ||
| public static final String AGUI_FIELD_RUN_ID = "runId"; | ||
|
|
||
| /** Field name for messages array in AG-UI input */ | ||
| public static final String AGUI_FIELD_MESSAGES = "messages"; | ||
|
|
||
| /** Field name for tools array in AG-UI input */ | ||
| public static final String AGUI_FIELD_TOOLS = "tools"; | ||
|
|
||
| /** Field name for context array in AG-UI input */ | ||
| public static final String AGUI_FIELD_CONTEXT = "context"; | ||
|
|
||
| /** Field name for state object in AG-UI input */ | ||
| public static final String AGUI_FIELD_STATE = "state"; | ||
|
|
||
| /** Field name for forwarded properties in AG-UI input */ | ||
| public static final String AGUI_FIELD_FORWARDED_PROPS = "forwardedProps"; | ||
|
|
||
| /** Field name for message role */ | ||
| public static final String AGUI_FIELD_ROLE = "role"; | ||
|
|
||
| /** Field name for message content */ | ||
| public static final String AGUI_FIELD_CONTENT = "content"; | ||
|
|
||
| /** Field name for tool call identifier */ | ||
| public static final String AGUI_FIELD_TOOL_CALL_ID = "toolCallId"; | ||
|
|
||
| /** Field name for tool calls array */ | ||
| public static final String AGUI_FIELD_TOOL_CALLS = "toolCalls"; | ||
|
|
||
| /** Field name for message identifier */ | ||
| public static final String AGUI_FIELD_ID = "id"; | ||
|
|
||
| /** Field name for tool call type */ | ||
| public static final String AGUI_FIELD_TYPE = "type"; | ||
|
|
||
| /** Field name for function object in tool calls */ | ||
| public static final String AGUI_FIELD_FUNCTION = "function"; | ||
|
|
||
| /** Field name for function name */ | ||
| public static final String AGUI_FIELD_NAME = "name"; | ||
|
|
||
| /** Field name for function arguments */ | ||
| public static final String AGUI_FIELD_ARGUMENTS = "arguments"; | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,186 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.ml.common.agui; | ||
|
|
||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTENT; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTEXT; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_FORWARDED_PROPS; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_MESSAGES; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ROLE; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_RUN_ID; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_STATE; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_THREAD_ID; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOLS; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALL_ID; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_FORWARDED_PROPS; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_MESSAGES; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_STATE; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOLS; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOL_CALL_RESULTS; | ||
| import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_ROLE_USER; | ||
| import static org.opensearch.ml.common.utils.StringUtils.getStringField; | ||
|
|
||
| import java.util.HashMap; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
|
|
||
| import org.opensearch.ml.common.FunctionName; | ||
| import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
| import org.opensearch.ml.common.input.execute.agent.AgentMLInput; | ||
|
|
||
| import com.google.gson.Gson; | ||
| import com.google.gson.JsonElement; | ||
| import com.google.gson.JsonObject; | ||
| import com.google.gson.JsonParser; | ||
|
|
||
| import lombok.extern.log4j.Log4j2; | ||
|
|
||
| @Log4j2 | ||
| public class AGUIInputConverter { | ||
|
|
||
| private static final Gson gson = new Gson(); | ||
|
|
||
| public static boolean isAGUIInput(String inputJson) { | ||
| try { | ||
| JsonObject jsonObj = JsonParser.parseString(inputJson).getAsJsonObject(); | ||
|
|
||
| // Check required fields exist | ||
| if (!jsonObj.has(AGUI_FIELD_THREAD_ID) | ||
| || !jsonObj.has(AGUI_FIELD_RUN_ID) | ||
| || !jsonObj.has(AGUI_FIELD_MESSAGES) | ||
| || !jsonObj.has(AGUI_FIELD_TOOLS)) { | ||
| return false; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add a log? |
||
| } | ||
|
|
||
| // Validate messages is an array | ||
| JsonElement messages = jsonObj.get(AGUI_FIELD_MESSAGES); | ||
| if (!messages.isJsonArray()) { | ||
| return false; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, let's add logs? |
||
| } | ||
|
|
||
| // Validate tools is an array | ||
| JsonElement tools = jsonObj.get(AGUI_FIELD_TOOLS); | ||
| if (!tools.isJsonArray()) { | ||
| return false; | ||
| } | ||
|
|
||
| return true; | ||
| } catch (Exception e) { | ||
| log.error("Failed to parse input as JSON for AG-UI detection", e); | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| public static AgentMLInput convertFromAGUIInput(String aguiInputJson, String agentId, String tenantId, boolean isAsync) { | ||
| try { | ||
| JsonObject aguiInput = JsonParser.parseString(aguiInputJson).getAsJsonObject(); | ||
|
|
||
| String threadId = getStringField(aguiInput, AGUI_FIELD_THREAD_ID); | ||
| String runId = getStringField(aguiInput, AGUI_FIELD_RUN_ID); | ||
| JsonElement state = aguiInput.get(AGUI_FIELD_STATE); | ||
| JsonElement messages = aguiInput.get(AGUI_FIELD_MESSAGES); | ||
| JsonElement tools = aguiInput.get(AGUI_FIELD_TOOLS); | ||
| JsonElement context = aguiInput.get(AGUI_FIELD_CONTEXT); | ||
| JsonElement forwardedProps = aguiInput.get(AGUI_FIELD_FORWARDED_PROPS); | ||
|
|
||
| Map<String, String> parameters = new HashMap<>(); | ||
| parameters.put(AGUI_PARAM_THREAD_ID, threadId); | ||
| parameters.put(AGUI_PARAM_RUN_ID, runId); | ||
|
|
||
| if (state != null) { | ||
| parameters.put(AGUI_PARAM_STATE, gson.toJson(state)); | ||
| } | ||
|
|
||
| if (messages != null) { | ||
| parameters.put(AGUI_PARAM_MESSAGES, gson.toJson(messages)); | ||
| extractUserQuestion(messages, parameters); | ||
| } | ||
|
|
||
| if (tools != null) { | ||
| parameters.put(AGUI_PARAM_TOOLS, gson.toJson(tools)); | ||
| } | ||
|
|
||
| if (context != null) { | ||
| parameters.put(AGUI_PARAM_CONTEXT, gson.toJson(context)); | ||
| } | ||
|
|
||
| if (forwardedProps != null) { | ||
| parameters.put(AGUI_PARAM_FORWARDED_PROPS, gson.toJson(forwardedProps)); | ||
| } | ||
| RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); | ||
| AgentMLInput agentMLInput = new AgentMLInput(agentId, tenantId, FunctionName.AGENT, inputDataSet, isAsync); | ||
|
|
||
| log.debug("Converted AG-UI input to ML-Commons format for agent: {}", agentId); | ||
| return agentMLInput; | ||
|
|
||
| } catch (Exception e) { | ||
| log.error("Failed to convert AG-UI input to ML-Commons format", e); | ||
| throw new IllegalArgumentException("Invalid AG-UI input format", e); | ||
| } | ||
| } | ||
|
|
||
| private static void extractUserQuestion(JsonElement messages, Map<String, String> parameters) { | ||
| if (messages == null || !messages.isJsonArray()) { | ||
| throw new IllegalArgumentException("Invalid AG-UI messages"); | ||
| } | ||
|
|
||
| try { | ||
| // Find the last user message to use as the current question | ||
| String lastUserMessage = null; | ||
| String toolCallResults = null; | ||
|
|
||
| for (JsonElement messageElement : messages.getAsJsonArray()) { | ||
| if (messageElement.isJsonObject()) { | ||
| JsonObject message = messageElement.getAsJsonObject(); | ||
| JsonElement roleElement = message.get(AGUI_FIELD_ROLE); | ||
| JsonElement contentElement = message.get(AGUI_FIELD_CONTENT); | ||
| JsonElement toolCallIdElement = message.get(AGUI_FIELD_TOOL_CALL_ID); | ||
|
|
||
| if (roleElement != null | ||
| && AGUI_ROLE_USER.equals(roleElement.getAsString()) | ||
| && contentElement != null | ||
| && !contentElement.isJsonNull()) { | ||
|
|
||
| String content = contentElement.getAsString(); | ||
|
|
||
| // Check if this is a tool call result (has toolCallId field) | ||
| if (toolCallIdElement != null && !toolCallIdElement.isJsonNull()) { | ||
| // This is a tool call result from frontend | ||
| String toolCallId = toolCallIdElement.getAsString(); | ||
|
|
||
| // Create tool result structure | ||
| JsonObject toolResult = new JsonObject(); | ||
| toolResult.addProperty("tool_call_id", toolCallId); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we put all these in a static variable? |
||
| toolResult.addProperty("content", content); | ||
|
|
||
| toolCallResults = gson.toJson(List.of(toolResult)); | ||
| log.debug("Extracted tool call result from AG-UI messages: toolCallId={}, content={}", toolCallId, content); | ||
| } else { | ||
| // Regular user message | ||
| lastUserMessage = content; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Set appropriate parameters based on what was found | ||
| if (toolCallResults != null) { | ||
| parameters.put(AGUI_PARAM_TOOL_CALL_RESULTS, toolCallResults); | ||
| log.debug("Detected AG-UI tool call results: {}", toolCallResults); | ||
| } else if (lastUserMessage != null) { | ||
| parameters.put("question", lastUserMessage); | ||
| log.debug("Extracted user question from AG-UI messages: {}", lastUserMessage); | ||
| } else { | ||
| throw new IllegalArgumentException("No user message found in AG-UI messages"); | ||
| } | ||
|
Comment on lines
+173
to
+181
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| } catch (Exception e) { | ||
| throw new IllegalArgumentException("Invalid AG-UI message format", e); | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.ml.common.agui; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.Map; | ||
|
|
||
| import org.opensearch.common.xcontent.XContentFactory; | ||
| import org.opensearch.core.common.io.stream.StreamInput; | ||
| import org.opensearch.core.common.io.stream.StreamOutput; | ||
| import org.opensearch.core.common.io.stream.Writeable; | ||
| import org.opensearch.core.xcontent.ToXContent; | ||
| import org.opensearch.core.xcontent.ToXContentFragment; | ||
| import org.opensearch.core.xcontent.XContentBuilder; | ||
|
|
||
| import lombok.AllArgsConstructor; | ||
| import lombok.Data; | ||
| import lombok.NoArgsConstructor; | ||
|
|
||
| @Data | ||
| @NoArgsConstructor | ||
| @AllArgsConstructor | ||
| public abstract class BaseEvent implements ToXContentFragment, Writeable { | ||
|
|
||
| protected String type; | ||
| protected Long timestamp; | ||
| protected Map<String, Object> rawEvent; | ||
|
|
||
| public BaseEvent(StreamInput input) throws IOException { | ||
| this.type = input.readString(); | ||
| this.timestamp = input.readOptionalLong(); | ||
| if (input.readBoolean()) { | ||
| this.rawEvent = input.readMap(); | ||
| } | ||
| } | ||
|
|
||
| @Override | ||
| public void writeTo(StreamOutput out) throws IOException { | ||
| out.writeString(type); | ||
| out.writeOptionalLong(timestamp); | ||
| if (rawEvent != null) { | ||
| out.writeBoolean(true); | ||
| out.writeMap(rawEvent); | ||
| } else { | ||
| out.writeBoolean(false); | ||
| } | ||
| } | ||
|
|
||
| @Override | ||
| public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
| builder.startObject(); | ||
| builder.field("type", type); | ||
| if (timestamp != null) { | ||
| builder.field("timestamp", timestamp); | ||
| } | ||
| if (rawEvent != null) { | ||
| builder.field("rawEvent", rawEvent); | ||
| } | ||
| addEventSpecificFields(builder, params); | ||
| builder.endObject(); | ||
| return builder; | ||
| } | ||
|
|
||
| protected abstract void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException; | ||
|
|
||
| public String toJsonString() { | ||
| try { | ||
| return toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString(); | ||
| } catch (IOException e) { | ||
| throw new RuntimeException("Failed to serialize event to JSON", e); | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see any comprehensive input sanitization. This can be potential for injection attacks through malformed AG-UI JSON input