Skip to content

Commit 7eb2552

Browse files
authored
Make memory optional (#4438)
* remove memory Signed-off-by: Pavan Yekbote <[email protected]> * spotless Signed-off-by: Pavan Yekbote <[email protected]> * remove placeholder values Signed-off-by: Pavan Yekbote <[email protected]> --------- Signed-off-by: Pavan Yekbote <[email protected]>
1 parent 34f5135 commit 7eb2552

File tree

3 files changed

+78
-39
lines changed

3 files changed

+78
-39
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ public void execute(Input input, ActionListener<Output> listener, TransportChann
273273
if (memorySpec != null
274274
&& memorySpec.getType() != null
275275
&& memoryFactoryMap.containsKey(MLMemoryType.from(memorySpec.getType()).name())
276+
&& memoryFactoryMap != null
277+
&& !memoryFactoryMap.isEmpty()
276278
&& (memoryId == null || parentInteractionId == null)) {
277279
Map<String, Object> memoryParams = createMemoryParams(
278280
question,

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,11 @@ public void run(MLAgent mlAgent, Map<String, String> inputParams, ActionListener
196196
functionCalling.configure(params);
197197
}
198198

199+
if (mlAgent.getMemory() == null || memoryFactoryMap == null || memoryFactoryMap.isEmpty()) {
200+
runAgent(mlAgent, params, listener, null, null, functionCalling);
201+
return;
202+
}
203+
199204
String memoryType = MLMemoryType.from(mlAgent.getMemory().getType()).name();
200205
String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
201206
String appType = mlAgent.getAppType();
@@ -421,7 +426,7 @@ private void runReAct(
421426

422427
saveTraceData(
423428
memory,
424-
memory.getType(),
429+
memory != null ? memory.getType() : null,
425430
question,
426431
thoughtResponse,
427432
sessionId,
@@ -842,20 +847,18 @@ private void sendFinalAnswer(
842847

843848
public static List<ModelTensors> createModelTensors(String sessionId, String parentInteractionId) {
844849
List<ModelTensors> cotModelTensors = new ArrayList<>();
850+
List<ModelTensor> tensors = new ArrayList<>();
845851

846-
cotModelTensors
847-
.add(
848-
ModelTensors
849-
.builder()
850-
.mlModelTensors(
851-
List
852-
.of(
853-
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
854-
ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build()
855-
)
856-
)
857-
.build()
858-
);
852+
if (sessionId != null) {
853+
tensors.add(ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build());
854+
}
855+
if (parentInteractionId != null) {
856+
tensors.add(ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build());
857+
}
858+
859+
if (!tensors.isEmpty()) {
860+
cotModelTensors.add(ModelTensors.builder().mlModelTensors(tensors).build());
861+
}
859862
return cotModelTensors;
860863
}
861864

@@ -1013,18 +1016,22 @@ private void saveMessage(
10131016
boolean traceDisabled,
10141017
ActionListener listener
10151018
) {
1016-
ConversationIndexMessage msgTemp = ConversationIndexMessage
1017-
.conversationIndexMessageBuilder()
1018-
.type(memory.getType())
1019-
.question(question)
1020-
.response(finalAnswer)
1021-
.finalAnswer(isFinalAnswer)
1022-
.sessionId(sessionId)
1023-
.build();
1024-
if (traceDisabled) {
1025-
listener.onResponse(true);
1019+
if (memory != null) {
1020+
ConversationIndexMessage msgTemp = ConversationIndexMessage
1021+
.conversationIndexMessageBuilder()
1022+
.type(memory.getType())
1023+
.question(question)
1024+
.response(finalAnswer)
1025+
.finalAnswer(isFinalAnswer)
1026+
.sessionId(sessionId)
1027+
.build();
1028+
if (traceDisabled) {
1029+
listener.onResponse(true);
1030+
} else {
1031+
memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener);
1032+
}
10261033
} else {
1027-
memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener);
1034+
listener.onResponse(true);
10281035
}
10291036
}
10301037

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,12 @@ public void run(MLAgent mlAgent, Map<String, String> apiParams, ActionListener<O
298298
// planner prompt for the first call
299299
usePlannerPromptTemplate(allParams);
300300

301+
if (mlAgent.getMemory() == null || memoryFactoryMap == null || memoryFactoryMap.isEmpty()) {
302+
List<String> completedSteps = new ArrayList<>();
303+
setToolsAndRunAgent(mlAgent, allParams, completedSteps, null, null, listener);
304+
return;
305+
}
306+
301307
String memoryId = allParams.get(MEMORY_ID_FIELD);
302308
String memoryType = MLMemoryType.from(mlAgent.getMemory().getType()).name();
303309
String appType = mlAgent.getAppType();
@@ -589,7 +595,7 @@ private void executePlanningLoop(
589595

590596
saveTraceData(
591597
memory,
592-
memory.getType(),
598+
memory != null ? memory.getType() : null,
593599
stepToExecute,
594600
results.get(STEP_RESULT_FIELD),
595601
conversationId,
@@ -754,16 +760,38 @@ void saveAndReturnFinalResult(
754760
String input,
755761
ActionListener<Object> finalListener
756762
) {
757-
Map<String, Object> updateContent = new HashMap<>();
758-
updateContent.put(INTERACTIONS_RESPONSE_FIELD, finalResult);
763+
if (memory != null) {
764+
Map<String, Object> updateContent = new HashMap<>();
765+
updateContent.put(INTERACTIONS_RESPONSE_FIELD, finalResult);
759766

760-
if (input != null) {
761-
updateContent.put(INTERACTIONS_INPUT_FIELD, input);
762-
}
767+
if (input != null) {
768+
updateContent.put(INTERACTIONS_INPUT_FIELD, input);
769+
}
763770

764-
memory.update(parentInteractionId, updateContent, ActionListener.wrap(res -> {
771+
memory.update(parentInteractionId, updateContent, ActionListener.wrap(res -> {
772+
List<ModelTensors> finalModelTensors = createModelTensors(
773+
memory.getId(),
774+
parentInteractionId,
775+
reactAgentMemoryId,
776+
reactParentInteractionId
777+
);
778+
finalModelTensors
779+
.add(
780+
ModelTensors
781+
.builder()
782+
.mlModelTensors(
783+
List.of(ModelTensor.builder().name(RESPONSE_FIELD).dataAsMap(Map.of(RESPONSE_FIELD, finalResult)).build())
784+
)
785+
.build()
786+
);
787+
finalListener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
788+
}, e -> {
789+
log.error("Failed to update interaction with final result", e);
790+
finalListener.onFailure(e);
791+
}));
792+
} else {
765793
List<ModelTensors> finalModelTensors = createModelTensors(
766-
memory.getId(),
794+
null,
767795
parentInteractionId,
768796
reactAgentMemoryId,
769797
reactParentInteractionId
@@ -778,10 +806,7 @@ void saveAndReturnFinalResult(
778806
.build()
779807
);
780808
finalListener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
781-
}, e -> {
782-
log.error("Failed to update interaction with final result", e);
783-
finalListener.onFailure(e);
784-
}));
809+
}
785810
}
786811

787812
@VisibleForTesting
@@ -794,8 +819,13 @@ static List<ModelTensors> createModelTensors(
794819
List<ModelTensors> modelTensors = new ArrayList<>();
795820
List<ModelTensor> tensors = new ArrayList<>();
796821

797-
tensors.add(ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build());
798-
tensors.add(ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build());
822+
if (sessionId != null) {
823+
tensors.add(ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build());
824+
}
825+
826+
if (parentInteractionId != null) {
827+
tensors.add(ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build());
828+
}
799829

800830
if (reactAgentMemoryId != null && !reactAgentMemoryId.isEmpty()) {
801831
tensors.add(ModelTensor.builder().name(EXECUTOR_AGENT_MEMORY_ID_FIELD).result(reactAgentMemoryId).build());

0 commit comments

Comments
 (0)