Skip to content

Commit d698896

Browse files
authored
AG-UI support in Agent Framework (#4347)
* AG-UI events Signed-off-by: Jiaping Zeng <[email protected]> * convert AG-UI input to agent parameters Signed-off-by: Jiaping Zeng <[email protected]> * add AG-UI tool use Signed-off-by: Jiaping Zeng <[email protected]> * add AG-UI processing Signed-off-by: Jiaping Zeng <[email protected]> * add feature flag Signed-off-by: Jiaping Zeng <[email protected]> * address comments Signed-off-by: Jiaping Zeng <[email protected]> --------- Signed-off-by: Jiaping Zeng <[email protected]>
1 parent f0c6e15 commit d698896

39 files changed

+2362
-120
lines changed

common/src/main/java/org/opensearch/ml/common/MLAgentType.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ public enum MLAgentType {
1111
FLOW,
1212
CONVERSATIONAL,
1313
CONVERSATIONAL_FLOW,
14-
PLAN_EXECUTE_AND_REFLECT;
14+
PLAN_EXECUTE_AND_REFLECT,
15+
AG_UI;
1516

1617
public static MLAgentType from(String value) {
1718
if (value == null) {
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.agui;
7+
8+
/**
9+
* Constants for AG-UI implementation.
10+
*
11+
* Naming Conventions:
12+
* AGUI_ROLE_* - Message role identifiers
13+
* AGUI_PARAM_* - Internal parameter keys
14+
* AGUI_FIELD_* - External API field names
15+
* AGUI_EVENT_* - Event type identifiers
16+
* AGUI_PREFIX_* - ID prefixes for generated identifiers
17+
*/
18+
public final class AGUIConstants {
19+
20+
// ========== Message Roles ==========
21+
22+
/** Role identifier for assistant messages */
23+
public static final String AGUI_ROLE_ASSISTANT = "assistant";
24+
25+
/** Role identifier for user messages */
26+
public static final String AGUI_ROLE_USER = "user";
27+
28+
/** Role identifier for tool result messages */
29+
public static final String AGUI_ROLE_TOOL = "tool";
30+
31+
// ========== Parameter Keys (Internal) ==========
32+
33+
/** Parameter key for AG-UI thread identifier */
34+
public static final String AGUI_PARAM_THREAD_ID = "agui_thread_id";
35+
36+
/** Parameter key for AG-UI run identifier */
37+
public static final String AGUI_PARAM_RUN_ID = "agui_run_id";
38+
39+
/** Parameter key for AG-UI messages array */
40+
public static final String AGUI_PARAM_MESSAGES = "agui_messages";
41+
42+
/** Parameter key for AG-UI tools array */
43+
public static final String AGUI_PARAM_TOOLS = "agui_tools";
44+
45+
/** Parameter key for AG-UI context array */
46+
public static final String AGUI_PARAM_CONTEXT = "agui_context";
47+
48+
/** Parameter key for AG-UI state object */
49+
public static final String AGUI_PARAM_STATE = "agui_state";
50+
51+
/** Parameter key for AG-UI forwarded properties */
52+
public static final String AGUI_PARAM_FORWARDED_PROPS = "agui_forwarded_props";
53+
54+
/** Parameter key for AG-UI tool call results */
55+
public static final String AGUI_PARAM_TOOL_CALL_RESULTS = "agui_tool_call_results";
56+
57+
/** Parameter key for AG-UI assistant tool call messages */
58+
public static final String AGUI_PARAM_ASSISTANT_TOOL_CALL_MESSAGES = "agui_assistant_tool_call_messages";
59+
60+
/** Parameter key for backend tool names (used for filtering) */
61+
public static final String AGUI_PARAM_BACKEND_TOOL_NAMES = "backend_tool_names";
62+
63+
// ========== Field Names (External API) ==========
64+
65+
/** Field name for thread identifier in AG-UI input */
66+
public static final String AGUI_FIELD_THREAD_ID = "threadId";
67+
68+
/** Field name for run identifier in AG-UI input */
69+
public static final String AGUI_FIELD_RUN_ID = "runId";
70+
71+
/** Field name for messages array in AG-UI input */
72+
public static final String AGUI_FIELD_MESSAGES = "messages";
73+
74+
/** Field name for tools array in AG-UI input */
75+
public static final String AGUI_FIELD_TOOLS = "tools";
76+
77+
/** Field name for context array in AG-UI input */
78+
public static final String AGUI_FIELD_CONTEXT = "context";
79+
80+
/** Field name for state object in AG-UI input */
81+
public static final String AGUI_FIELD_STATE = "state";
82+
83+
/** Field name for forwarded properties in AG-UI input */
84+
public static final String AGUI_FIELD_FORWARDED_PROPS = "forwardedProps";
85+
86+
/** Field name for message role */
87+
public static final String AGUI_FIELD_ROLE = "role";
88+
89+
/** Field name for message content */
90+
public static final String AGUI_FIELD_CONTENT = "content";
91+
92+
/** Field name for tool call identifier */
93+
public static final String AGUI_FIELD_TOOL_CALL_ID = "toolCallId";
94+
95+
/** Field name for tool calls array */
96+
public static final String AGUI_FIELD_TOOL_CALLS = "toolCalls";
97+
98+
/** Field name for message identifier */
99+
public static final String AGUI_FIELD_ID = "id";
100+
101+
/** Field name for tool call type */
102+
public static final String AGUI_FIELD_TYPE = "type";
103+
104+
/** Field name for function object in tool calls */
105+
public static final String AGUI_FIELD_FUNCTION = "function";
106+
107+
/** Field name for function name */
108+
public static final String AGUI_FIELD_NAME = "name";
109+
110+
/** Field name for function arguments */
111+
public static final String AGUI_FIELD_ARGUMENTS = "arguments";
112+
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.agui;
7+
8+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTENT;
9+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTEXT;
10+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_FORWARDED_PROPS;
11+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_MESSAGES;
12+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ROLE;
13+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_RUN_ID;
14+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_STATE;
15+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_THREAD_ID;
16+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOLS;
17+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALL_ID;
18+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT;
19+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_FORWARDED_PROPS;
20+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_MESSAGES;
21+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID;
22+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_STATE;
23+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID;
24+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOLS;
25+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOL_CALL_RESULTS;
26+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_ROLE_USER;
27+
import static org.opensearch.ml.common.utils.StringUtils.getStringField;
28+
29+
import java.util.HashMap;
30+
import java.util.List;
31+
import java.util.Map;
32+
33+
import org.opensearch.ml.common.FunctionName;
34+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
35+
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
36+
37+
import com.google.gson.Gson;
38+
import com.google.gson.JsonElement;
39+
import com.google.gson.JsonObject;
40+
import com.google.gson.JsonParser;
41+
42+
import lombok.extern.log4j.Log4j2;
43+
44+
@Log4j2
45+
public class AGUIInputConverter {
46+
47+
private static final Gson gson = new Gson();
48+
49+
public static boolean isAGUIInput(String inputJson) {
50+
try {
51+
JsonObject jsonObj = JsonParser.parseString(inputJson).getAsJsonObject();
52+
53+
// Check required fields exist
54+
if (!jsonObj.has(AGUI_FIELD_THREAD_ID)
55+
|| !jsonObj.has(AGUI_FIELD_RUN_ID)
56+
|| !jsonObj.has(AGUI_FIELD_MESSAGES)
57+
|| !jsonObj.has(AGUI_FIELD_TOOLS)) {
58+
return false;
59+
}
60+
61+
// Validate messages is an array
62+
JsonElement messages = jsonObj.get(AGUI_FIELD_MESSAGES);
63+
if (!messages.isJsonArray()) {
64+
return false;
65+
}
66+
67+
// Validate tools is an array
68+
JsonElement tools = jsonObj.get(AGUI_FIELD_TOOLS);
69+
if (!tools.isJsonArray()) {
70+
return false;
71+
}
72+
73+
return true;
74+
} catch (Exception e) {
75+
log.error("Failed to parse input as JSON for AG-UI detection", e);
76+
return false;
77+
}
78+
}
79+
80+
public static AgentMLInput convertFromAGUIInput(String aguiInputJson, String agentId, String tenantId, boolean isAsync) {
81+
try {
82+
JsonObject aguiInput = JsonParser.parseString(aguiInputJson).getAsJsonObject();
83+
84+
String threadId = getStringField(aguiInput, AGUI_FIELD_THREAD_ID);
85+
String runId = getStringField(aguiInput, AGUI_FIELD_RUN_ID);
86+
JsonElement state = aguiInput.get(AGUI_FIELD_STATE);
87+
JsonElement messages = aguiInput.get(AGUI_FIELD_MESSAGES);
88+
JsonElement tools = aguiInput.get(AGUI_FIELD_TOOLS);
89+
JsonElement context = aguiInput.get(AGUI_FIELD_CONTEXT);
90+
JsonElement forwardedProps = aguiInput.get(AGUI_FIELD_FORWARDED_PROPS);
91+
92+
Map<String, String> parameters = new HashMap<>();
93+
parameters.put(AGUI_PARAM_THREAD_ID, threadId);
94+
parameters.put(AGUI_PARAM_RUN_ID, runId);
95+
96+
if (state != null) {
97+
parameters.put(AGUI_PARAM_STATE, gson.toJson(state));
98+
}
99+
100+
if (messages != null) {
101+
parameters.put(AGUI_PARAM_MESSAGES, gson.toJson(messages));
102+
extractUserQuestion(messages, parameters);
103+
}
104+
105+
if (tools != null) {
106+
parameters.put(AGUI_PARAM_TOOLS, gson.toJson(tools));
107+
}
108+
109+
if (context != null) {
110+
parameters.put(AGUI_PARAM_CONTEXT, gson.toJson(context));
111+
}
112+
113+
if (forwardedProps != null) {
114+
parameters.put(AGUI_PARAM_FORWARDED_PROPS, gson.toJson(forwardedProps));
115+
}
116+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
117+
AgentMLInput agentMLInput = new AgentMLInput(agentId, tenantId, FunctionName.AGENT, inputDataSet, isAsync);
118+
119+
log.debug("Converted AG-UI input to ML-Commons format for agent: {}", agentId);
120+
return agentMLInput;
121+
122+
} catch (Exception e) {
123+
log.error("Failed to convert AG-UI input to ML-Commons format", e);
124+
throw new IllegalArgumentException("Invalid AG-UI input format", e);
125+
}
126+
}
127+
128+
private static void extractUserQuestion(JsonElement messages, Map<String, String> parameters) {
129+
if (messages == null || !messages.isJsonArray()) {
130+
throw new IllegalArgumentException("Invalid AG-UI messages");
131+
}
132+
133+
try {
134+
// Find the last user message to use as the current question
135+
String lastUserMessage = null;
136+
String toolCallResults = null;
137+
138+
for (JsonElement messageElement : messages.getAsJsonArray()) {
139+
if (messageElement.isJsonObject()) {
140+
JsonObject message = messageElement.getAsJsonObject();
141+
JsonElement roleElement = message.get(AGUI_FIELD_ROLE);
142+
JsonElement contentElement = message.get(AGUI_FIELD_CONTENT);
143+
JsonElement toolCallIdElement = message.get(AGUI_FIELD_TOOL_CALL_ID);
144+
145+
if (roleElement != null
146+
&& AGUI_ROLE_USER.equals(roleElement.getAsString())
147+
&& contentElement != null
148+
&& !contentElement.isJsonNull()) {
149+
150+
String content = contentElement.getAsString();
151+
152+
// Check if this is a tool call result (has toolCallId field)
153+
if (toolCallIdElement != null && !toolCallIdElement.isJsonNull()) {
154+
// This is a tool call result from frontend
155+
String toolCallId = toolCallIdElement.getAsString();
156+
157+
// Create tool result structure
158+
JsonObject toolResult = new JsonObject();
159+
toolResult.addProperty("tool_call_id", toolCallId);
160+
toolResult.addProperty("content", content);
161+
162+
toolCallResults = gson.toJson(List.of(toolResult));
163+
log.debug("Extracted tool call result from AG-UI messages: toolCallId={}, content={}", toolCallId, content);
164+
} else {
165+
// Regular user message
166+
lastUserMessage = content;
167+
}
168+
}
169+
}
170+
}
171+
172+
// Set appropriate parameters based on what was found
173+
if (toolCallResults != null) {
174+
parameters.put(AGUI_PARAM_TOOL_CALL_RESULTS, toolCallResults);
175+
log.debug("Detected AG-UI tool call results: {}", toolCallResults);
176+
} else if (lastUserMessage != null) {
177+
parameters.put("question", lastUserMessage);
178+
log.debug("Extracted user question from AG-UI messages: {}", lastUserMessage);
179+
} else {
180+
throw new IllegalArgumentException("No user message found in AG-UI messages");
181+
}
182+
} catch (Exception e) {
183+
throw new IllegalArgumentException("Invalid AG-UI message format", e);
184+
}
185+
}
186+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.agui;
7+
8+
import java.io.IOException;
9+
import java.util.Map;
10+
11+
import org.opensearch.common.xcontent.XContentFactory;
12+
import org.opensearch.core.common.io.stream.StreamInput;
13+
import org.opensearch.core.common.io.stream.StreamOutput;
14+
import org.opensearch.core.common.io.stream.Writeable;
15+
import org.opensearch.core.xcontent.ToXContent;
16+
import org.opensearch.core.xcontent.ToXContentFragment;
17+
import org.opensearch.core.xcontent.XContentBuilder;
18+
19+
import lombok.AllArgsConstructor;
20+
import lombok.Data;
21+
import lombok.NoArgsConstructor;
22+
23+
@Data
24+
@NoArgsConstructor
25+
@AllArgsConstructor
26+
public abstract class BaseEvent implements ToXContentFragment, Writeable {
27+
28+
protected String type;
29+
protected Long timestamp;
30+
protected Map<String, Object> rawEvent;
31+
32+
public BaseEvent(StreamInput input) throws IOException {
33+
this.type = input.readString();
34+
this.timestamp = input.readOptionalLong();
35+
if (input.readBoolean()) {
36+
this.rawEvent = input.readMap();
37+
}
38+
}
39+
40+
@Override
41+
public void writeTo(StreamOutput out) throws IOException {
42+
out.writeString(type);
43+
out.writeOptionalLong(timestamp);
44+
if (rawEvent != null) {
45+
out.writeBoolean(true);
46+
out.writeMap(rawEvent);
47+
} else {
48+
out.writeBoolean(false);
49+
}
50+
}
51+
52+
@Override
53+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
54+
builder.startObject();
55+
builder.field("type", type);
56+
if (timestamp != null) {
57+
builder.field("timestamp", timestamp);
58+
}
59+
if (rawEvent != null) {
60+
builder.field("rawEvent", rawEvent);
61+
}
62+
addEventSpecificFields(builder, params);
63+
builder.endObject();
64+
return builder;
65+
}
66+
67+
protected abstract void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException;
68+
69+
public String toJsonString() {
70+
try {
71+
return toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString();
72+
} catch (IOException e) {
73+
throw new RuntimeException("Failed to serialize event to JSON", e);
74+
}
75+
}
76+
}

0 commit comments

Comments
 (0)