Skip to content

Commit a00b7de

Browse files
authored
[PER Agent] Performance: Process additional info as traces during reflection if provided (#4369)
1 parent 5ac2984 commit a00b7de

File tree

3 files changed

+142
-9
lines changed

3 files changed

+142
-9
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
99
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
10+
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
1011
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD;
1112
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD;
1213
import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly;
@@ -460,9 +461,9 @@ private void executePlanningLoop(
460461
results.put(PARENT_INTERACTION_ID_FIELD, tensor.getResult());
461462
break;
462463
default:
463-
Map<String, ?> dataMap = tensor.getDataAsMap();
464-
if (dataMap != null && dataMap.containsKey(RESPONSE_FIELD)) {
465-
results.put(STEP_RESULT_FIELD, (String) dataMap.get(RESPONSE_FIELD));
464+
String stepResult = parseTensorDataMap(tensor);
465+
if (stepResult != null) {
466+
results.put(STEP_RESULT_FIELD, stepResult);
466467
}
467468
}
468469
});
@@ -502,8 +503,17 @@ private void executePlanningLoop(
502503
}, e -> log.error("Failed to update task {} with executor memory ID", taskId, e)));
503504
}
504505

505-
completedSteps.add(String.format("\nStep %d: %s\n", stepsExecuted + 1, stepToExecute));
506-
completedSteps.add(String.format("\nStep %d Result: %s\n", stepsExecuted + 1, results.get(STEP_RESULT_FIELD)));
506+
completedSteps.add(String.format("\n<step-%d>\n%s\n</step-%d>\n", stepsExecuted + 1, stepToExecute, stepsExecuted + 1));
507+
completedSteps
508+
.add(
509+
String
510+
.format(
511+
"\n<step-%d-result>\n%s\n</step-%d-result>\n",
512+
stepsExecuted + 1,
513+
results.get(STEP_RESULT_FIELD),
514+
stepsExecuted + 1
515+
)
516+
);
507517

508518
saveTraceData(
509519
(ConversationIndexMemory) memory,
@@ -544,6 +554,39 @@ private void executePlanningLoop(
544554
client.execute(MLPredictionTaskAction.INSTANCE, request, planListener);
545555
}
546556

557+
@VisibleForTesting
558+
String parseTensorDataMap(ModelTensor tensor) {
559+
Map<String, ?> dataMap = tensor.getDataAsMap();
560+
if (dataMap == null) {
561+
return null;
562+
}
563+
564+
StringBuilder stepResult = new StringBuilder();
565+
if (dataMap.containsKey(RESPONSE_FIELD)) {
566+
stepResult.append((String) dataMap.get(RESPONSE_FIELD));
567+
}
568+
569+
if (dataMap.containsKey(INTERACTIONS_ADDITIONAL_INFO_FIELD)) {
570+
stepResult.append("\n<step-traces>\n");
571+
((Map<String, Object>) dataMap.get(INTERACTIONS_ADDITIONAL_INFO_FIELD))
572+
.forEach(
573+
(key, value) -> stepResult
574+
.append("<")
575+
.append(key)
576+
.append(">")
577+
.append("\n")
578+
.append(value)
579+
.append("\n")
580+
.append("</")
581+
.append(key)
582+
.append(">")
583+
);
584+
stepResult.append("\n</step-traces>\n");
585+
}
586+
587+
return stepResult.toString();
588+
}
589+
547590
@VisibleForTesting
548591
Map<String, Object> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
549592
Map<String, Object> modelOutput = new HashMap<>();

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ public class PromptTemplate {
2828
+ "${parameters."
2929
+ PLANNER_PROMPT_FIELD
3030
+ "} \n"
31-
+ "Objective: ${parameters."
31+
+ "Objective: ```${parameters."
3232
+ USER_PROMPT_FIELD
33-
+ "} \n\nRemember: Respond only in JSON format following the required schema.";
33+
+ "}``` \n\nRemember: Respond only in JSON format following the required schema.";
3434

3535
public static final String DEFAULT_REFLECT_PROMPT_TEMPLATE = "${parameters."
3636
+ DEFAULT_PROMPT_TOOLS_FIELD
@@ -41,10 +41,10 @@ public class PromptTemplate {
4141
+ "Objective: ```${parameters."
4242
+ USER_PROMPT_FIELD
4343
+ "}```\n\n"
44-
+ "Original plan:\n[${parameters."
44+
+ "Previous plan:\n[${parameters."
4545
+ STEPS_FIELD
4646
+ "}] \n\n"
47-
+ "You have currently executed the following steps from the original plan: \n[${parameters."
47+
+ "You have currently executed the following steps: \n[${parameters."
4848
+ COMPLETED_STEPS_FIELD
4949
+ "}] \n\n"
5050
+ "${parameters."

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.junit.Assert.assertEquals;
99
import static org.junit.Assert.assertFalse;
1010
import static org.junit.Assert.assertNotNull;
11+
import static org.junit.Assert.assertNull;
1112
import static org.junit.Assert.assertThrows;
1213
import static org.junit.Assert.assertTrue;
1314
import static org.mockito.ArgumentMatchers.any;
@@ -677,6 +678,42 @@ public void testSaveAndReturnFinalResult() {
677678
assertEquals(finalResult, secondModelTensorList.get(0).getDataAsMap().get("response"));
678679
}
679680

681+
@Test
682+
public void testParseTensorDataMap() {
683+
// Test with response only
684+
Map<String, Object> dataMap = new HashMap<>();
685+
dataMap.put("response", "test response");
686+
ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build();
687+
688+
String result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(tensor);
689+
assertEquals("test response", result);
690+
691+
// Test with additional info
692+
Map<String, Object> additionalInfo = new HashMap<>();
693+
additionalInfo.put("trace1", "content1");
694+
additionalInfo.put("trace2", "content2");
695+
dataMap.put("additional_info", additionalInfo);
696+
697+
result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(tensor);
698+
assertTrue(result.contains("test response"));
699+
assertTrue(result.contains("<step-traces>"));
700+
assertTrue(result.contains("<trace1>\ncontent1\n</trace1>"));
701+
assertTrue(result.contains("<trace2>\ncontent2\n</trace2>"));
702+
assertTrue(result.contains("</step-traces>"));
703+
704+
// Test with null dataMap
705+
ModelTensor nullTensor = ModelTensor.builder().build();
706+
assertNull(mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(nullTensor));
707+
708+
// No response field
709+
Map<String, Object> noResponseMap = new HashMap<>();
710+
noResponseMap.put("additional_info", additionalInfo);
711+
ModelTensor noResponseTensor = ModelTensor.builder().dataAsMap(noResponseMap).build();
712+
result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(noResponseTensor);
713+
assertTrue(result.contains("<step-traces>"));
714+
assertFalse(result.contains("test response"));
715+
}
716+
680717
@Test
681718
public void testUpdateTaskWithExecutorAgentInfo() {
682719
MLAgent mlAgent = createMLAgentWithTools();
@@ -765,4 +802,57 @@ public void testUpdateTaskWithExecutorAgentInfo() {
765802
mlTaskUtilsMockedStatic.verify(() -> MLTaskUtils.updateMLTaskDirectly(eq(taskId), eq(taskUpdates), eq(client), any()));
766803
}
767804
}
805+
806+
@Test
807+
public void testExecutionWithNullStepResult() {
808+
MLAgent mlAgent = createMLAgentWithTools();
809+
810+
// Setup LLM response for planning phase - returns steps to execute
811+
doAnswer(invocation -> {
812+
ActionListener<Object> listener = invocation.getArgument(2);
813+
ModelTensor modelTensor = ModelTensor
814+
.builder()
815+
.dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\"], \"result\":\"\"}"))
816+
.build();
817+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
818+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
819+
when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput);
820+
listener.onResponse(mlTaskResponse);
821+
return null;
822+
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any());
823+
824+
// Setup executor response with tensor that has null dataMap - this will hit line 465
825+
doAnswer(invocation -> {
826+
ActionListener<Object> listener = invocation.getArgument(2);
827+
ModelTensor memoryIdTensor = ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result("test_memory_id").build();
828+
ModelTensor parentIdTensor = ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result("test_parent_id").build();
829+
// This tensor will return null from parseTensorDataMap, hitting the stepResult != null check
830+
ModelTensor nullDataTensor = ModelTensor.builder().name("other").build();
831+
ModelTensors modelTensors = ModelTensors
832+
.builder()
833+
.mlModelTensors(Arrays.asList(memoryIdTensor, parentIdTensor, nullDataTensor))
834+
.build();
835+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
836+
when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput);
837+
listener.onResponse(mlExecuteTaskResponse);
838+
return null;
839+
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(MLExecuteTaskRequest.class), any());
840+
841+
Map<String, String> params = new HashMap<>();
842+
params.put("question", "test question");
843+
params.put("parent_interaction_id", "test_parent_interaction_id");
844+
845+
// Capture the exception in the listener
846+
doAnswer(invocation -> {
847+
Exception e = invocation.getArgument(0);
848+
assertTrue(e instanceof IllegalStateException);
849+
assertEquals("No valid response found in ReAct agent output", e.getMessage());
850+
return null;
851+
}).when(agentActionListener).onFailure(any());
852+
853+
mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener);
854+
855+
// Verify that onFailure was called with the expected exception
856+
verify(agentActionListener).onFailure(any(IllegalStateException.class));
857+
}
768858
}

0 commit comments

Comments
 (0)