|
8 | 8 | import static org.junit.Assert.assertEquals; |
9 | 9 | import static org.junit.Assert.assertFalse; |
10 | 10 | import static org.junit.Assert.assertNotNull; |
| 11 | +import static org.junit.Assert.assertNull; |
11 | 12 | import static org.junit.Assert.assertThrows; |
12 | 13 | import static org.junit.Assert.assertTrue; |
13 | 14 | import static org.mockito.ArgumentMatchers.any; |
@@ -677,6 +678,42 @@ public void testSaveAndReturnFinalResult() { |
677 | 678 | assertEquals(finalResult, secondModelTensorList.get(0).getDataAsMap().get("response")); |
678 | 679 | } |
679 | 680 |
|
| 681 | + @Test |
| 682 | + public void testParseTensorDataMap() { |
| 683 | + // Test with response only |
| 684 | + Map<String, Object> dataMap = new HashMap<>(); |
| 685 | + dataMap.put("response", "test response"); |
| 686 | + ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build(); |
| 687 | + |
| 688 | + String result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(tensor); |
| 689 | + assertEquals("test response", result); |
| 690 | + |
| 691 | + // Test with additional info |
| 692 | + Map<String, Object> additionalInfo = new HashMap<>(); |
| 693 | + additionalInfo.put("trace1", "content1"); |
| 694 | + additionalInfo.put("trace2", "content2"); |
| 695 | + dataMap.put("additional_info", additionalInfo); |
| 696 | + |
| 697 | + result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(tensor); |
| 698 | + assertTrue(result.contains("test response")); |
| 699 | + assertTrue(result.contains("<step-traces>")); |
| 700 | + assertTrue(result.contains("<trace1>\ncontent1\n</trace1>")); |
| 701 | + assertTrue(result.contains("<trace2>\ncontent2\n</trace2>")); |
| 702 | + assertTrue(result.contains("</step-traces>")); |
| 703 | + |
| 704 | + // Test with null dataMap |
| 705 | + ModelTensor nullTensor = ModelTensor.builder().build(); |
| 706 | + assertNull(mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(nullTensor)); |
| 707 | + |
| 708 | + // No response field |
| 709 | + Map<String, Object> noResponseMap = new HashMap<>(); |
| 710 | + noResponseMap.put("additional_info", additionalInfo); |
| 711 | + ModelTensor noResponseTensor = ModelTensor.builder().dataAsMap(noResponseMap).build(); |
| 712 | + result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(noResponseTensor); |
| 713 | + assertTrue(result.contains("<step-traces>")); |
| 714 | + assertFalse(result.contains("test response")); |
| 715 | + } |
| 716 | + |
680 | 717 | @Test |
681 | 718 | public void testUpdateTaskWithExecutorAgentInfo() { |
682 | 719 | MLAgent mlAgent = createMLAgentWithTools(); |
@@ -765,4 +802,57 @@ public void testUpdateTaskWithExecutorAgentInfo() { |
765 | 802 | mlTaskUtilsMockedStatic.verify(() -> MLTaskUtils.updateMLTaskDirectly(eq(taskId), eq(taskUpdates), eq(client), any())); |
766 | 803 | } |
767 | 804 | } |
| 805 | + |
| 806 | + @Test |
| 807 | + public void testExecutionWithNullStepResult() { |
| 808 | + MLAgent mlAgent = createMLAgentWithTools(); |
| 809 | + |
| 810 | + // Setup LLM response for planning phase - returns steps to execute |
| 811 | + doAnswer(invocation -> { |
| 812 | + ActionListener<Object> listener = invocation.getArgument(2); |
| 813 | + ModelTensor modelTensor = ModelTensor |
| 814 | + .builder() |
| 815 | + .dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\"], \"result\":\"\"}")) |
| 816 | + .build(); |
| 817 | + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); |
| 818 | + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); |
| 819 | + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); |
| 820 | + listener.onResponse(mlTaskResponse); |
| 821 | + return null; |
| 822 | + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); |
| 823 | + |
| 824 | + // Setup executor response with tensor that has null dataMap - this will hit line 465 |
| 825 | + doAnswer(invocation -> { |
| 826 | + ActionListener<Object> listener = invocation.getArgument(2); |
| 827 | + ModelTensor memoryIdTensor = ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result("test_memory_id").build(); |
| 828 | + ModelTensor parentIdTensor = ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result("test_parent_id").build(); |
| 829 | + // This tensor will return null from parseTensorDataMap, hitting the stepResult != null check |
| 830 | + ModelTensor nullDataTensor = ModelTensor.builder().name("other").build(); |
| 831 | + ModelTensors modelTensors = ModelTensors |
| 832 | + .builder() |
| 833 | + .mlModelTensors(Arrays.asList(memoryIdTensor, parentIdTensor, nullDataTensor)) |
| 834 | + .build(); |
| 835 | + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); |
| 836 | + when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); |
| 837 | + listener.onResponse(mlExecuteTaskResponse); |
| 838 | + return null; |
| 839 | + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(MLExecuteTaskRequest.class), any()); |
| 840 | + |
| 841 | + Map<String, String> params = new HashMap<>(); |
| 842 | + params.put("question", "test question"); |
| 843 | + params.put("parent_interaction_id", "test_parent_interaction_id"); |
| 844 | + |
| 845 | + // Capture the exception in the listener |
| 846 | + doAnswer(invocation -> { |
| 847 | + Exception e = invocation.getArgument(0); |
| 848 | + assertTrue(e instanceof IllegalStateException); |
| 849 | + assertEquals("No valid response found in ReAct agent output", e.getMessage()); |
| 850 | + return null; |
| 851 | + }).when(agentActionListener).onFailure(any()); |
| 852 | + |
| 853 | + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); |
| 854 | + |
| 855 | + // Verify that onFailure was called with the expected exception |
| 856 | + verify(agentActionListener).onFailure(any(IllegalStateException.class)); |
| 857 | + } |
768 | 858 | } |
0 commit comments