Skip to content

Commit b1f7291

Browse files
committed
fix failed UT
Signed-off-by: Sicheng Song <[email protected]>
1 parent b92e13d commit b1f7291

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,9 +1155,6 @@ private Connector createInlineConnector(String endpoint, String region, Map<Stri
11551155
// Extract tenant ID from role ARN if applicable
11561156
String tenantId = extractTenantIdFromRoleArn(serviceName, credentials);
11571157

1158-
// Extract tenant ID from role ARN if applicable
1159-
String tenantId = extractTenantIdFromRoleArn(serviceName, credentials);
1160-
11611158
// Create Memory Container API actions
11621159
List<ConnectorAction> actions = createMemoryContainerActions();
11631160

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ public void testRunWithIncludeOutputNotSet() {
253253
.llm(llmSpec)
254254
.memory(mlMemorySpec)
255255
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
256+
.tenantId("test_tenant")
256257
.build();
257258
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener, transportChannel);
258259
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
@@ -283,6 +284,7 @@ public void testRunWithIncludeOutputMLModel() {
283284
.llm(llmSpec)
284285
.memory(mlMemorySpec)
285286
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
287+
.tenantId("test_tenant")
286288
.build();
287289
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener, transportChannel);
288290
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
@@ -317,6 +319,7 @@ public void testRunWithIncludeOutputSet() {
317319
.memory(mlMemorySpec)
318320
.llm(llmSpec)
319321
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
322+
.tenantId("test_tenant")
320323
.build();
321324
HashMap<String, String> params = new HashMap<>();
322325
mlChatAgentRunner.run(mlAgent, params, agentActionListener, transportChannel);
@@ -358,6 +361,7 @@ public void testChatHistoryExcludeOngoingQuestion() {
358361
.llm(llmSpec)
359362
.description("mlagent description")
360363
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
364+
.tenantId("test_tenant")
361365
.build();
362366

363367
doAnswer(invocation -> {
@@ -414,6 +418,7 @@ private void testInteractions(String maxInteraction) {
414418
.llm(llmSpec)
415419
.description("mlagent description")
416420
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
421+
.tenantId("test_tenant")
417422
.build();
418423

419424
doAnswer(invocation -> {
@@ -446,6 +451,7 @@ public void testChatHistoryException() {
446451
.memory(mlMemorySpec)
447452
.llm(llmSpec)
448453
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
454+
.tenantId("test_tenant")
449455
.build();
450456

451457
doAnswer(invocation -> {
@@ -516,6 +522,7 @@ public void testToolNotFound() {
516522
.memory(mlMemorySpec)
517523
.llm(llmSpec)
518524
.name("TestAgent")
525+
.tenantId("test_tenant")
519526
.build();
520527

521528
// Create parameters for the agent with a non-existent tool
@@ -598,9 +605,10 @@ public void testToolParameters() {
598605
// Verify that the tool's run method was called.
599606
verify(firstTool).run(any(), any());
600607
// Verify the size of parameters passed in the tool run method.
608+
// Note: size is 19 because tenant_id is now passed to tools
601609
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
602610
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
603-
assertEquals(18, ((Map) argumentCaptor.getValue()).size());
611+
assertEquals(19, ((Map) argumentCaptor.getValue()).size());
604612

605613
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
606614
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
@@ -626,9 +634,10 @@ public void testToolUseOriginalInput() {
626634
// Verify that the tool's run method was called.
627635
verify(firstTool).run(any(), any());
628636
// Verify the size of parameters passed in the tool run method.
637+
// Note: size is 20 because tenant_id is now passed to tools
629638
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
630639
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
631-
assertEquals(19, ((Map) argumentCaptor.getValue()).size());
640+
assertEquals(20, ((Map) argumentCaptor.getValue()).size());
632641
assertEquals("raw input", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));
633642

634643
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
@@ -692,9 +701,10 @@ public void testToolConfig() {
692701
// Verify that the tool's run method was called.
693702
verify(firstTool).run(any(), any());
694703
// Verify the size of parameters passed in the tool run method.
704+
// Note: size is 20 because tenant_id is now passed to tools
695705
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
696706
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
697-
assertEquals(19, ((Map) argumentCaptor.getValue()).size());
707+
assertEquals(20, ((Map) argumentCaptor.getValue()).size());
698708
// The value of input should be "config_value".
699709
assertEquals("config_value", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));
700710

@@ -722,9 +732,10 @@ public void testToolConfigWithInputPlaceholder() {
722732
// Verify that the tool's run method was called.
723733
verify(firstTool).run(any(), any());
724734
// Verify the size of parameters passed in the tool run method.
735+
// Note: size is 20 because tenant_id is now passed to tools
725736
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
726737
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
727-
assertEquals(19, ((Map) argumentCaptor.getValue()).size());
738+
assertEquals(20, ((Map) argumentCaptor.getValue()).size());
728739
// The value of input should be replaced with the value associated with the key "key2" of the first tool.
729740
assertEquals("value2", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));
730741

@@ -785,6 +796,7 @@ public void testToolExecutionWithChatHistoryParameter() {
785796
.llm(llmSpec)
786797
.description("mlagent description")
787798
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
799+
.tenantId("test_tenant")
788800
.build();
789801

790802
doAnswer(invocation -> {
@@ -826,6 +838,7 @@ private MLAgent createMLAgentWithTools() {
826838
.tools(Arrays.asList(firstToolSpec))
827839
.memory(mlMemorySpec)
828840
.llm(llmSpec)
841+
.tenantId("test_tenant")
829842
.build();
830843
}
831844

@@ -845,6 +858,7 @@ private MLAgent createMLAgentWithToolsConfig(Map<String, String> configMap) {
845858
.tools(Arrays.asList(firstToolSpec))
846859
.memory(mlMemorySpec)
847860
.llm(llmSpec)
861+
.tenantId("test_tenant")
848862
.build();
849863
}
850864

@@ -919,6 +933,7 @@ public void testMaxIterationsReached() {
919933
.llm(llmSpec)
920934
.memory(mlMemorySpec)
921935
.tools(Arrays.asList(firstToolSpec))
936+
.tenantId("test_tenant")
922937
.build();
923938

924939
// Reset client mock for this test
@@ -972,6 +987,7 @@ public void testMaxIterationsReachedWithValidThought() {
972987
.llm(llmSpec)
973988
.memory(mlMemorySpec)
974989
.tools(Arrays.asList(firstToolSpec))
990+
.tenantId("test_tenant")
975991
.build();
976992

977993
// Reset client mock for this test
@@ -1067,6 +1083,7 @@ private MLAgent createMLAgentWithScratchpadTools() {
10671083
.tools(Arrays.asList(writeToolSpec, readToolSpec))
10681084
.memory(mlMemorySpec)
10691085
.llm(llmSpec)
1086+
.tenantId("test_tenant")
10701087
.build();
10711088
}
10721089

@@ -1096,6 +1113,7 @@ public void testMaxIterationsWithSummaryEnabled() {
10961113
.llm(llmSpec)
10971114
.memory(mlMemorySpec)
10981115
.tools(Arrays.asList(firstToolSpec))
1116+
.tenantId("test_tenant")
10991117
.build();
11001118

11011119
// Reset and setup fresh mocks
@@ -1157,6 +1175,7 @@ public void testMaxIterationsWithSummaryDisabled() {
11571175
.llm(llmSpec)
11581176
.memory(mlMemorySpec)
11591177
.tools(Arrays.asList(firstToolSpec))
1178+
.tenantId("test_tenant")
11601179
.build();
11611180

11621181
// Reset client mock for this test
@@ -1199,6 +1218,7 @@ public void testCreateMemoryAdapter_ConversationIndex() {
11991218
.type(MLAgentType.CONVERSATIONAL.name())
12001219
.llm(llmSpec)
12011220
.memory(memorySpec)
1221+
.tenantId("test_tenant")
12021222
.build();
12031223

12041224
Map<String, String> params = new HashMap<>();
@@ -1249,6 +1269,7 @@ public void testCreateMemoryAdapter_AgenticMemory() {
12491269
.type(MLAgentType.CONVERSATIONAL.name())
12501270
.llm(llmSpec)
12511271
.memory(memorySpec)
1272+
.tenantId("test_tenant")
12521273
.build();
12531274

12541275
Map<String, String> params = new HashMap<>();
@@ -1352,6 +1373,7 @@ public void testExtractSummaryFromResponse_ThrowsException_FallbackStrategyUsed(
13521373
.llm(llmSpec)
13531374
.memory(mlMemorySpec)
13541375
.tools(Arrays.asList(firstToolSpec))
1376+
.tenantId("test_tenant")
13551377
.build();
13561378

13571379
Mockito.reset(client);

0 commit comments

Comments
 (0)