Skip to content

Commit 3bddf55

Browse files
committed
address comments
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent c7722c0 commit 3bddf55

File tree

6 files changed

+45
-15
lines changed

6 files changed

+45
-15
lines changed

common/src/main/java/org/opensearch/ml/common/agui/AGUIInputConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public static boolean isAGUIInput(String inputJson) {
7272

7373
return true;
7474
} catch (Exception e) {
75-
log.debug("Failed to parse input as JSON for AG-UI detection", e);
75+
log.error("Failed to parse input as JSON for AG-UI detection", e);
7676
return false;
7777
}
7878
}

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED;
99
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
10+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_ENABLED;
1011
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
1112
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
1213
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_ENABLED;
@@ -22,7 +23,6 @@
2223
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
2324
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED;
2425
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED;
25-
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_ENABLED;
2626

2727
import java.util.ArrayList;
2828
import java.util.List;

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,8 @@ public static Map<String, Tool> wrapFrontendToolsAsToolObjects(List<Map<String,
10661066
}
10671067

10681068
return wrappedTools;
1069+
}
1070+
10691071
public static Map<String, Object> createMemoryParams(
10701072
String question,
10711073
String memoryId,

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
import org.opensearch.core.action.ActionListener;
3838
import org.opensearch.core.xcontent.NamedXContentRegistry;
3939
import org.opensearch.ml.common.agent.MLAgent;
40-
import org.opensearch.ml.common.spi.memory.Memory;
40+
import org.opensearch.ml.common.memory.Memory;
4141
import org.opensearch.ml.common.spi.tools.Tool;
4242
import org.opensearch.ml.engine.encryptor.Encryptor;
4343
import org.opensearch.ml.engine.function_calling.FunctionCalling;
@@ -362,6 +362,7 @@ private void processAGUIMessages(Map<String, String> params, String llmInterface
362362
}
363363
} catch (Exception e) {
364364
log.error("Failed to process AG-UI messages to chat history", e);
365+
throw new IllegalArgumentException("Failed to process AG-UI messages to chat history", e);
365366
}
366367
}
367368

@@ -405,6 +406,7 @@ private void processAGUIContext(Map<String, String> params) {
405406

406407
} catch (Exception e) {
407408
log.error("Failed to process AG-UI context", e);
409+
throw new IllegalArgumentException("Failed to process AG-UI context", e);
408410
}
409411
}
410412
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
99
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
1010
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
11+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE;
1112
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_DISABLED_MESSAGE;
1213
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
1314
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
14-
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE;
1515
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID;
1616
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
1717
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TOOL_NAME;
@@ -140,6 +140,13 @@ MLExecuteTaskRequest getRequest(RestRequest request) throws IOException {
140140
);
141141
} else {
142142
input = MLInput.parse(parser, functionName.name());
143+
144+
if (!(input instanceof AgentMLInput)) {
145+
throw new IllegalArgumentException(
146+
String.format("Invalid input type. Expected: AgentMLInput, Received: %s", input.getClass().getSimpleName())
147+
);
148+
}
149+
143150
((AgentMLInput) input).setAgentId(agentId);
144151
((AgentMLInput) input).setTenantId(tenantId);
145152
((AgentMLInput) input).setIsAsync(async);

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_BACKEND_TOOL_NAMES;
1313
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID;
1414
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID;
15+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE;
1516
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
1617
import static org.opensearch.ml.plugin.MachineLearningPlugin.STREAM_EXECUTE_THREAD_POOL;
1718
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
1819
import static org.opensearch.ml.utils.MLExceptionUtils.STREAM_DISABLED_ERR_MSG;
19-
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE;
2020
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID;
2121
import static org.opensearch.ml.utils.RestActionUtils.isAsync;
2222
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;
@@ -182,22 +182,41 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
182182
BaseEvent runStartedEvent = new RunStartedEvent(threadId, runId);
183183
HttpChunk startChunk = createHttpChunk("data: " + runStartedEvent.toJsonString() + "\n\n", false);
184184
channel.sendChunk(startChunk);
185-
log.debug("RestMLExecuteStreamAction: Sent RUN_STARTED event - threadId={}, runId={}", threadId, runId);
185+
log.debug("AG-UI: RestMLExecuteStreamAction: Sent RUN_STARTED event - threadId={}, runId={}", threadId, runId);
186186
}
187187

188188
// Extract backend tool names from agent configuration and add to request for AG-UI filtering
189189
List<String> backendToolNames = extractBackendToolNamesFromAgent(agent);
190190
if (isAGUI && !backendToolNames.isEmpty()) {
191191
// Add backend tool names to request parameters so they're available during streaming
192-
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) ((AgentMLInput) mlExecuteTaskRequest
193-
.getInput()).getInputDataset();
194-
inputDataSet.getParameters().put(AGUI_PARAM_BACKEND_TOOL_NAMES, new Gson().toJson(backendToolNames));
195-
log
196-
.info(
197-
"AG-UI: Added {} backend tool names to request for streaming filter: {}",
198-
backendToolNames.size(),
199-
backendToolNames
200-
);
192+
try {
193+
if (!(mlExecuteTaskRequest.getInput() instanceof AgentMLInput)) {
194+
throw new IllegalArgumentException(
195+
"Invalid input type. Expected: AgentMLInput, Received: "
196+
+ mlExecuteTaskRequest.getInput().getClass().getSimpleName()
197+
);
198+
}
199+
AgentMLInput agentInput = (AgentMLInput) mlExecuteTaskRequest.getInput();
200+
201+
if (!(agentInput.getInputDataset() instanceof RemoteInferenceInputDataSet)) {
202+
throw new IllegalArgumentException(
203+
"Invalid dataset type. Expected: RemoteInferenceInputDataSet, Received: "
204+
+ agentInput.getInputDataset().getClass().getSimpleName()
205+
);
206+
}
207+
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentInput.getInputDataset();
208+
209+
inputDataSet.getParameters().put(AGUI_PARAM_BACKEND_TOOL_NAMES, new Gson().toJson(backendToolNames));
210+
log
211+
.info(
212+
"AG-UI: Added {} backend tool names to request for streaming filter: {}",
213+
backendToolNames.size(),
214+
backendToolNames
215+
);
216+
} catch (ClassCastException e) {
217+
log.error("Failed to cast input types for backend tool names extraction", e);
218+
throw new IllegalArgumentException("Invalid input type configuration for AG-UI request", e);
219+
}
201220
}
202221

203222
final CompletableFuture<HttpChunk> future = new CompletableFuture<>();

0 commit comments

Comments
 (0)