|
| 1 | +/* |
| 2 | + * Copyright OpenSearch Contributors |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +package org.opensearch.ml.common.agui; |
| 7 | + |
| 8 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTENT; |
| 9 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTEXT; |
| 10 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_FORWARDED_PROPS; |
| 11 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_MESSAGES; |
| 12 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ROLE; |
| 13 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_RUN_ID; |
| 14 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_STATE; |
| 15 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_THREAD_ID; |
| 16 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOLS; |
| 17 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALL_ID; |
| 18 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT; |
| 19 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_FORWARDED_PROPS; |
| 20 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_MESSAGES; |
| 21 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID; |
| 22 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_STATE; |
| 23 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID; |
| 24 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOLS; |
| 25 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOL_CALL_RESULTS; |
| 26 | +import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_ROLE_USER; |
| 27 | +import static org.opensearch.ml.common.utils.StringUtils.getStringField; |
| 28 | + |
| 29 | +import java.util.HashMap; |
| 30 | +import java.util.List; |
| 31 | +import java.util.Map; |
| 32 | + |
| 33 | +import org.opensearch.ml.common.FunctionName; |
| 34 | +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; |
| 35 | +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; |
| 36 | + |
| 37 | +import com.google.gson.Gson; |
| 38 | +import com.google.gson.JsonElement; |
| 39 | +import com.google.gson.JsonObject; |
| 40 | +import com.google.gson.JsonParser; |
| 41 | + |
| 42 | +import lombok.extern.log4j.Log4j2; |
| 43 | + |
| 44 | +@Log4j2 |
| 45 | +public class AGUIInputConverter { |
| 46 | + |
| 47 | + private static final Gson gson = new Gson(); |
| 48 | + |
| 49 | + public static boolean isAGUIInput(String inputJson) { |
| 50 | + try { |
| 51 | + JsonObject jsonObj = JsonParser.parseString(inputJson).getAsJsonObject(); |
| 52 | + |
| 53 | + // Check required fields exist |
| 54 | + if (!jsonObj.has(AGUI_FIELD_THREAD_ID) |
| 55 | + || !jsonObj.has(AGUI_FIELD_RUN_ID) |
| 56 | + || !jsonObj.has(AGUI_FIELD_MESSAGES) |
| 57 | + || !jsonObj.has(AGUI_FIELD_TOOLS)) { |
| 58 | + return false; |
| 59 | + } |
| 60 | + |
| 61 | + // Validate messages is an array |
| 62 | + JsonElement messages = jsonObj.get(AGUI_FIELD_MESSAGES); |
| 63 | + if (!messages.isJsonArray()) { |
| 64 | + return false; |
| 65 | + } |
| 66 | + |
| 67 | + // Validate tools is an array |
| 68 | + JsonElement tools = jsonObj.get(AGUI_FIELD_TOOLS); |
| 69 | + if (!tools.isJsonArray()) { |
| 70 | + return false; |
| 71 | + } |
| 72 | + |
| 73 | + return true; |
| 74 | + } catch (Exception e) { |
| 75 | + log.error("Failed to parse input as JSON for AG-UI detection", e); |
| 76 | + return false; |
| 77 | + } |
| 78 | + } |
| 79 | + |
| 80 | + public static AgentMLInput convertFromAGUIInput(String aguiInputJson, String agentId, String tenantId, boolean isAsync) { |
| 81 | + try { |
| 82 | + JsonObject aguiInput = JsonParser.parseString(aguiInputJson).getAsJsonObject(); |
| 83 | + |
| 84 | + String threadId = getStringField(aguiInput, AGUI_FIELD_THREAD_ID); |
| 85 | + String runId = getStringField(aguiInput, AGUI_FIELD_RUN_ID); |
| 86 | + JsonElement state = aguiInput.get(AGUI_FIELD_STATE); |
| 87 | + JsonElement messages = aguiInput.get(AGUI_FIELD_MESSAGES); |
| 88 | + JsonElement tools = aguiInput.get(AGUI_FIELD_TOOLS); |
| 89 | + JsonElement context = aguiInput.get(AGUI_FIELD_CONTEXT); |
| 90 | + JsonElement forwardedProps = aguiInput.get(AGUI_FIELD_FORWARDED_PROPS); |
| 91 | + |
| 92 | + Map<String, String> parameters = new HashMap<>(); |
| 93 | + parameters.put(AGUI_PARAM_THREAD_ID, threadId); |
| 94 | + parameters.put(AGUI_PARAM_RUN_ID, runId); |
| 95 | + |
| 96 | + if (state != null) { |
| 97 | + parameters.put(AGUI_PARAM_STATE, gson.toJson(state)); |
| 98 | + } |
| 99 | + |
| 100 | + if (messages != null) { |
| 101 | + parameters.put(AGUI_PARAM_MESSAGES, gson.toJson(messages)); |
| 102 | + extractUserQuestion(messages, parameters); |
| 103 | + } |
| 104 | + |
| 105 | + if (tools != null) { |
| 106 | + parameters.put(AGUI_PARAM_TOOLS, gson.toJson(tools)); |
| 107 | + } |
| 108 | + |
| 109 | + if (context != null) { |
| 110 | + parameters.put(AGUI_PARAM_CONTEXT, gson.toJson(context)); |
| 111 | + } |
| 112 | + |
| 113 | + if (forwardedProps != null) { |
| 114 | + parameters.put(AGUI_PARAM_FORWARDED_PROPS, gson.toJson(forwardedProps)); |
| 115 | + } |
| 116 | + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); |
| 117 | + AgentMLInput agentMLInput = new AgentMLInput(agentId, tenantId, FunctionName.AGENT, inputDataSet, isAsync); |
| 118 | + |
| 119 | + log.debug("Converted AG-UI input to ML-Commons format for agent: {}", agentId); |
| 120 | + return agentMLInput; |
| 121 | + |
| 122 | + } catch (Exception e) { |
| 123 | + log.error("Failed to convert AG-UI input to ML-Commons format", e); |
| 124 | + throw new IllegalArgumentException("Invalid AG-UI input format", e); |
| 125 | + } |
| 126 | + } |
| 127 | + |
| 128 | + private static void extractUserQuestion(JsonElement messages, Map<String, String> parameters) { |
| 129 | + if (messages == null || !messages.isJsonArray()) { |
| 130 | + throw new IllegalArgumentException("Invalid AG-UI messages"); |
| 131 | + } |
| 132 | + |
| 133 | + try { |
| 134 | + // Find the last user message to use as the current question |
| 135 | + String lastUserMessage = null; |
| 136 | + String toolCallResults = null; |
| 137 | + |
| 138 | + for (JsonElement messageElement : messages.getAsJsonArray()) { |
| 139 | + if (messageElement.isJsonObject()) { |
| 140 | + JsonObject message = messageElement.getAsJsonObject(); |
| 141 | + JsonElement roleElement = message.get(AGUI_FIELD_ROLE); |
| 142 | + JsonElement contentElement = message.get(AGUI_FIELD_CONTENT); |
| 143 | + JsonElement toolCallIdElement = message.get(AGUI_FIELD_TOOL_CALL_ID); |
| 144 | + |
| 145 | + if (roleElement != null |
| 146 | + && AGUI_ROLE_USER.equals(roleElement.getAsString()) |
| 147 | + && contentElement != null |
| 148 | + && !contentElement.isJsonNull()) { |
| 149 | + |
| 150 | + String content = contentElement.getAsString(); |
| 151 | + |
| 152 | + // Check if this is a tool call result (has toolCallId field) |
| 153 | + if (toolCallIdElement != null && !toolCallIdElement.isJsonNull()) { |
| 154 | + // This is a tool call result from frontend |
| 155 | + String toolCallId = toolCallIdElement.getAsString(); |
| 156 | + |
| 157 | + // Create tool result structure |
| 158 | + JsonObject toolResult = new JsonObject(); |
| 159 | + toolResult.addProperty("tool_call_id", toolCallId); |
| 160 | + toolResult.addProperty("content", content); |
| 161 | + |
| 162 | + toolCallResults = gson.toJson(List.of(toolResult)); |
| 163 | + log.debug("Extracted tool call result from AG-UI messages: toolCallId={}, content={}", toolCallId, content); |
| 164 | + } else { |
| 165 | + // Regular user message |
| 166 | + lastUserMessage = content; |
| 167 | + } |
| 168 | + } |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + // Set appropriate parameters based on what was found |
| 173 | + if (toolCallResults != null) { |
| 174 | + parameters.put(AGUI_PARAM_TOOL_CALL_RESULTS, toolCallResults); |
| 175 | + log.debug("Detected AG-UI tool call results: {}", toolCallResults); |
| 176 | + } else if (lastUserMessage != null) { |
| 177 | + parameters.put("question", lastUserMessage); |
| 178 | + log.debug("Extracted user question from AG-UI messages: {}", lastUserMessage); |
| 179 | + } else { |
| 180 | + throw new IllegalArgumentException("No user message found in AG-UI messages"); |
| 181 | + } |
| 182 | + } catch (Exception e) { |
| 183 | + throw new IllegalArgumentException("Invalid AG-UI message format", e); |
| 184 | + } |
| 185 | + } |
| 186 | +} |
0 commit comments