Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ public enum MLAgentType {
FLOW,
CONVERSATIONAL,
CONVERSATIONAL_FLOW,
PLAN_EXECUTE_AND_REFLECT;
PLAN_EXECUTE_AND_REFLECT,
AG_UI;

public static MLAgentType from(String value) {
if (value == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.agui;

/**
* Constants for AG-UI implementation.
*
* Naming Conventions:
* AGUI_ROLE_* - Message role identifiers
* AGUI_PARAM_* - Internal parameter keys
* AGUI_FIELD_* - External API field names
* AGUI_EVENT_* - Event type identifiers
* AGUI_PREFIX_* - ID prefixes for generated identifiers
*/
public final class AGUIConstants {

// ========== Message Roles ==========

/** Role identifier for assistant messages */
public static final String AGUI_ROLE_ASSISTANT = "assistant";

/** Role identifier for user messages */
public static final String AGUI_ROLE_USER = "user";

/** Role identifier for tool result messages */
public static final String AGUI_ROLE_TOOL = "tool";

// ========== Parameter Keys (Internal) ==========

/** Parameter key for AG-UI thread identifier */
public static final String AGUI_PARAM_THREAD_ID = "agui_thread_id";

/** Parameter key for AG-UI run identifier */
public static final String AGUI_PARAM_RUN_ID = "agui_run_id";

/** Parameter key for AG-UI messages array */
public static final String AGUI_PARAM_MESSAGES = "agui_messages";

/** Parameter key for AG-UI tools array */
public static final String AGUI_PARAM_TOOLS = "agui_tools";

/** Parameter key for AG-UI context array */
public static final String AGUI_PARAM_CONTEXT = "agui_context";

/** Parameter key for AG-UI state object */
public static final String AGUI_PARAM_STATE = "agui_state";

/** Parameter key for AG-UI forwarded properties */
public static final String AGUI_PARAM_FORWARDED_PROPS = "agui_forwarded_props";

/** Parameter key for AG-UI tool call results */
public static final String AGUI_PARAM_TOOL_CALL_RESULTS = "agui_tool_call_results";

/** Parameter key for AG-UI assistant tool call messages */
public static final String AGUI_PARAM_ASSISTANT_TOOL_CALL_MESSAGES = "agui_assistant_tool_call_messages";

/** Parameter key for backend tool names (used for filtering) */
public static final String AGUI_PARAM_BACKEND_TOOL_NAMES = "backend_tool_names";

// ========== Field Names (External API) ==========

/** Field name for thread identifier in AG-UI input */
public static final String AGUI_FIELD_THREAD_ID = "threadId";

/** Field name for run identifier in AG-UI input */
public static final String AGUI_FIELD_RUN_ID = "runId";

/** Field name for messages array in AG-UI input */
public static final String AGUI_FIELD_MESSAGES = "messages";

/** Field name for tools array in AG-UI input */
public static final String AGUI_FIELD_TOOLS = "tools";

/** Field name for context array in AG-UI input */
public static final String AGUI_FIELD_CONTEXT = "context";

/** Field name for state object in AG-UI input */
public static final String AGUI_FIELD_STATE = "state";

/** Field name for forwarded properties in AG-UI input */
public static final String AGUI_FIELD_FORWARDED_PROPS = "forwardedProps";

/** Field name for message role */
public static final String AGUI_FIELD_ROLE = "role";

/** Field name for message content */
public static final String AGUI_FIELD_CONTENT = "content";

/** Field name for tool call identifier */
public static final String AGUI_FIELD_TOOL_CALL_ID = "toolCallId";

/** Field name for tool calls array */
public static final String AGUI_FIELD_TOOL_CALLS = "toolCalls";

/** Field name for message identifier */
public static final String AGUI_FIELD_ID = "id";

/** Field name for tool call type */
public static final String AGUI_FIELD_TYPE = "type";

/** Field name for function object in tool calls */
public static final String AGUI_FIELD_FUNCTION = "function";

/** Field name for function name */
public static final String AGUI_FIELD_NAME = "name";

/** Field name for function arguments */
public static final String AGUI_FIELD_ARGUMENTS = "arguments";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.agui;

import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTENT;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTEXT;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_FORWARDED_PROPS;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_MESSAGES;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ROLE;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_RUN_ID;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_STATE;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_THREAD_ID;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOLS;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALL_ID;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_FORWARDED_PROPS;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_MESSAGES;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_STATE;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOLS;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOL_CALL_RESULTS;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_ROLE_USER;
import static org.opensearch.ml.common.utils.StringUtils.getStringField;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;

import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class AGUIInputConverter {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any comprehensive input sanitization. This can be potential for injection attacks through malformed AG-UI JSON input


private static final Gson gson = new Gson();

public static boolean isAGUIInput(String inputJson) {
try {
JsonObject jsonObj = JsonParser.parseString(inputJson).getAsJsonObject();

// Check required fields exist
if (!jsonObj.has(AGUI_FIELD_THREAD_ID)
|| !jsonObj.has(AGUI_FIELD_RUN_ID)
|| !jsonObj.has(AGUI_FIELD_MESSAGES)
|| !jsonObj.has(AGUI_FIELD_TOOLS)) {
return false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a log?

}

// Validate messages is an array
JsonElement messages = jsonObj.get(AGUI_FIELD_MESSAGES);
if (!messages.isJsonArray()) {
return false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, let's add logs?

}

// Validate tools is an array
JsonElement tools = jsonObj.get(AGUI_FIELD_TOOLS);
if (!tools.isJsonArray()) {
return false;
}

return true;
} catch (Exception e) {
log.error("Failed to parse input as JSON for AG-UI detection", e);
return false;
}
}

public static AgentMLInput convertFromAGUIInput(String aguiInputJson, String agentId, String tenantId, boolean isAsync) {
try {
JsonObject aguiInput = JsonParser.parseString(aguiInputJson).getAsJsonObject();

String threadId = getStringField(aguiInput, AGUI_FIELD_THREAD_ID);
String runId = getStringField(aguiInput, AGUI_FIELD_RUN_ID);
JsonElement state = aguiInput.get(AGUI_FIELD_STATE);
JsonElement messages = aguiInput.get(AGUI_FIELD_MESSAGES);
JsonElement tools = aguiInput.get(AGUI_FIELD_TOOLS);
JsonElement context = aguiInput.get(AGUI_FIELD_CONTEXT);
JsonElement forwardedProps = aguiInput.get(AGUI_FIELD_FORWARDED_PROPS);

Map<String, String> parameters = new HashMap<>();
parameters.put(AGUI_PARAM_THREAD_ID, threadId);
parameters.put(AGUI_PARAM_RUN_ID, runId);

if (state != null) {
parameters.put(AGUI_PARAM_STATE, gson.toJson(state));
}

if (messages != null) {
parameters.put(AGUI_PARAM_MESSAGES, gson.toJson(messages));
extractUserQuestion(messages, parameters);
}

if (tools != null) {
parameters.put(AGUI_PARAM_TOOLS, gson.toJson(tools));
}

if (context != null) {
parameters.put(AGUI_PARAM_CONTEXT, gson.toJson(context));
}

if (forwardedProps != null) {
parameters.put(AGUI_PARAM_FORWARDED_PROPS, gson.toJson(forwardedProps));
}
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
AgentMLInput agentMLInput = new AgentMLInput(agentId, tenantId, FunctionName.AGENT, inputDataSet, isAsync);

log.debug("Converted AG-UI input to ML-Commons format for agent: {}", agentId);
return agentMLInput;

} catch (Exception e) {
log.error("Failed to convert AG-UI input to ML-Commons format", e);
throw new IllegalArgumentException("Invalid AG-UI input format", e);
}
}

private static void extractUserQuestion(JsonElement messages, Map<String, String> parameters) {
if (messages == null || !messages.isJsonArray()) {
throw new IllegalArgumentException("Invalid AG-UI messages");
}

try {
// Find the last user message to use as the current question
String lastUserMessage = null;
String toolCallResults = null;

for (JsonElement messageElement : messages.getAsJsonArray()) {
if (messageElement.isJsonObject()) {
JsonObject message = messageElement.getAsJsonObject();
JsonElement roleElement = message.get(AGUI_FIELD_ROLE);
JsonElement contentElement = message.get(AGUI_FIELD_CONTENT);
JsonElement toolCallIdElement = message.get(AGUI_FIELD_TOOL_CALL_ID);

if (roleElement != null
&& AGUI_ROLE_USER.equals(roleElement.getAsString())
&& contentElement != null
&& !contentElement.isJsonNull()) {

String content = contentElement.getAsString();

// Check if this is a tool call result (has toolCallId field)
if (toolCallIdElement != null && !toolCallIdElement.isJsonNull()) {
// This is a tool call result from frontend
String toolCallId = toolCallIdElement.getAsString();

// Create tool result structure
JsonObject toolResult = new JsonObject();
toolResult.addProperty("tool_call_id", toolCallId);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we put all these in a static variable?

toolResult.addProperty("content", content);

toolCallResults = gson.toJson(List.of(toolResult));
log.debug("Extracted tool call result from AG-UI messages: toolCallId={}, content={}", toolCallId, content);
} else {
// Regular user message
lastUserMessage = content;
}
}
}
}

// Set appropriate parameters based on what was found
if (toolCallResults != null) {
parameters.put(AGUI_PARAM_TOOL_CALL_RESULTS, toolCallResults);
log.debug("Detected AG-UI tool call results: {}", toolCallResults);
} else if (lastUserMessage != null) {
parameters.put("question", lastUserMessage);
log.debug("Extracted user question from AG-UI messages: {}", lastUserMessage);
} else {
throw new IllegalArgumentException("No user message found in AG-UI messages");
}
Comment on lines +173 to +181
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. nit: check if both are null and then throw error early
  2. is there a case where user message is found and toolCallResult is found? In that case, we only process the result?

} catch (Exception e) {
throw new IllegalArgumentException("Invalid AG-UI message format", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.agui;

import java.io.IOException;
import java.util.Map;

import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentFragment;
import org.opensearch.core.xcontent.XContentBuilder;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
@AllArgsConstructor
public abstract class BaseEvent implements ToXContentFragment, Writeable {

protected String type;
protected Long timestamp;
protected Map<String, Object> rawEvent;

public BaseEvent(StreamInput input) throws IOException {
this.type = input.readString();
this.timestamp = input.readOptionalLong();
if (input.readBoolean()) {
this.rawEvent = input.readMap();
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(type);
out.writeOptionalLong(timestamp);
if (rawEvent != null) {
out.writeBoolean(true);
out.writeMap(rawEvent);
} else {
out.writeBoolean(false);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("type", type);
if (timestamp != null) {
builder.field("timestamp", timestamp);
}
if (rawEvent != null) {
builder.field("rawEvent", rawEvent);
}
addEventSpecificFields(builder, params);
builder.endObject();
return builder;
}

protected abstract void addEventSpecificFields(XContentBuilder builder, Params params) throws IOException;

public String toJsonString() {
try {
return toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString();
} catch (IOException e) {
throw new RuntimeException("Failed to serialize event to JSON", e);
}
}
}
Loading
Loading