Skip to content

Commit d8a9c75

Browse files
Add PER streaming support
Signed-off-by: Nathalie Jonathan <[email protected]>
1 parent 7d25d56 commit d8a9c75

File tree

6 files changed

+146
-41
lines changed

6 files changed

+146
-41
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ private void sendFinalAnswer(
719719
String finalAnswer
720720
) {
721721
// Send completion chunk for streaming
722-
streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId);
722+
streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId, null, null);
723723

724724
if (conversationIndexMemory != null) {
725725
String copyOfFinalAnswer = finalAnswer;
@@ -794,6 +794,10 @@ private static String constructLLMPrompt(Map<String, Tool> tools, Map<String, St
794794
@VisibleForTesting
795795
static Map<String, String> constructLLMParams(LLMSpec llm, Map<String, String> parameters) {
796796
Map<String, String> tmpParameters = new HashMap<>();
797+
798+
// Set agent type for Chat agent for streaming
799+
tmpParameters.put("agent_type", "chat");
800+
797801
if (llm.getParameters() != null) {
798802
tmpParameters.putAll(llm.getParameters());
799803
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import java.util.function.Consumer;
4444

4545
import org.apache.commons.text.StringSubstitutor;
46+
import org.opensearch.action.ActionRequest;
4647
import org.opensearch.action.StepListener;
4748
import org.opensearch.cluster.service.ClusterService;
4849
import org.opensearch.common.settings.Settings;
@@ -57,7 +58,6 @@
5758
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
5859
import org.opensearch.ml.common.exception.MLException;
5960
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
60-
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
6161
import org.opensearch.ml.common.output.model.ModelTensor;
6262
import org.opensearch.ml.common.output.model.ModelTensorOutput;
6363
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -66,8 +66,6 @@
6666
import org.opensearch.ml.common.transport.MLTaskResponse;
6767
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
6868
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
69-
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
70-
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
7169
import org.opensearch.ml.common.utils.StringUtils;
7270
import org.opensearch.ml.engine.encryptor.Encryptor;
7371
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
@@ -92,6 +90,7 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner {
9290
private final Map<String, Memory.Factory> memoryFactoryMap;
9391
private SdkClient sdkClient;
9492
private Encryptor encryptor;
93+
private StreamingWrapper streamingWrapper;
9594
// flag to track if task has been updated with executor memory ids or not
9695
private boolean taskUpdated = false;
9796
private final Map<String, Object> taskUpdates = new HashMap<>();
@@ -182,6 +181,9 @@ public MLPlanExecuteAndReflectAgentRunner(
182181

183182
@VisibleForTesting
184183
void setupPromptParameters(Map<String, String> params) {
184+
// Set agent type for PER agent for streaming
185+
params.put("agent_type", "per");
186+
185187
// populated depending on whether LLM is asked to plan or re-evaluate
186188
// removed here, so that error is thrown in case this field is not populated
187189
params.remove(PROMPT_FIELD);
@@ -273,6 +275,7 @@ void populatePrompt(Map<String, String> allParams) {
273275

274276
@Override
275277
public void run(MLAgent mlAgent, Map<String, String> apiParams, ActionListener<Object> listener, TransportChannel channel) {
278+
this.streamingWrapper = new StreamingWrapper(channel, client);
276279
Map<String, String> allParams = new HashMap<>();
277280
allParams.putAll(apiParams);
278281
allParams.putAll(mlAgent.getParameters());
@@ -387,16 +390,7 @@ private void executePlanningLoop(
387390
return;
388391
}
389392

390-
MLPredictionTaskRequest request = new MLPredictionTaskRequest(
391-
llm.getModelId(),
392-
RemoteInferenceMLInput
393-
.builder()
394-
.algorithm(FunctionName.REMOTE)
395-
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(allParams).build())
396-
.build(),
397-
null,
398-
allParams.get(TENANT_ID_FIELD)
399-
);
393+
ActionRequest request = streamingWrapper.createPredictionRequest(llm, allParams, allParams.get(TENANT_ID_FIELD));
400394

401395
StepListener<MLTaskResponse> planListener = new StepListener<>();
402396

@@ -550,8 +544,7 @@ private void executePlanningLoop(
550544
log.error("Failed to run deep research agent", e);
551545
finalListener.onFailure(e);
552546
});
553-
554-
client.execute(MLPredictionTaskAction.INSTANCE, request, planListener);
547+
streamingWrapper.executeRequest(request, planListener);
555548
}
556549

557550
@VisibleForTesting
@@ -689,6 +682,9 @@ void saveAndReturnFinalResult(
689682
}
690683

691684
memory.getMemoryManager().updateInteraction(parentInteractionId, updateContent, ActionListener.wrap(res -> {
685+
// Send completion chunk to close the streaming connection
686+
streamingWrapper
687+
.sendCompletionChunk(memory.getConversationId(), parentInteractionId, reactAgentMemoryId, reactParentInteractionId);
692688
List<ModelTensors> finalModelTensors = createModelTensors(
693689
memory.getConversationId(),
694690
parentInteractionId,

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,23 @@ public void executeRequest(ActionRequest request, ActionListener<MLTaskResponse>
8383
client.execute(MLPredictionTaskAction.INSTANCE, request, listener);
8484
}
8585

86-
public void sendCompletionChunk(String sessionId, String parentInteractionId) {
86+
public void sendCompletionChunk(
87+
String sessionId,
88+
String parentInteractionId,
89+
String executorMemoryId,
90+
String executorParentInteractionId
91+
) {
8792
if (!isStreaming) {
8893
return;
8994
}
90-
MLTaskResponse completionChunk = createStreamChunk("", sessionId, parentInteractionId, true);
95+
MLTaskResponse completionChunk = createStreamChunk(
96+
"",
97+
sessionId,
98+
parentInteractionId,
99+
executorMemoryId,
100+
executorParentInteractionId,
101+
true
102+
);
91103
try {
92104
channel.sendResponseBatch(completionChunk);
93105
} catch (Exception e) {
@@ -114,20 +126,29 @@ public void sendFinalResponse(
114126
public void sendToolResponse(String toolOutput, String sessionId, String parentInteractionId) {
115127
if (isStreaming) {
116128
try {
117-
MLTaskResponse toolChunk = createStreamChunk(toolOutput, sessionId, parentInteractionId, false);
129+
MLTaskResponse toolChunk = createStreamChunk(toolOutput, sessionId, parentInteractionId, null, null, false);
118130
channel.sendResponseBatch(toolChunk);
119131
} catch (Exception e) {
120132
log.error("Failed to send tool response chunk", e);
121133
}
122134
}
123135
}
124136

125-
private MLTaskResponse createStreamChunk(String toolOutput, String sessionId, String parentInteractionId, boolean isLast) {
137+
private MLTaskResponse createStreamChunk(
138+
String toolOutput,
139+
String sessionId,
140+
String parentInteractionId,
141+
String executorMemoryId,
142+
String executorParentInteractionId,
143+
boolean isLast
144+
) {
126145
List<ModelTensor> tensors = Arrays
127146
.asList(
128147
ModelTensor.builder().name("response").dataAsMap(Map.of("content", toolOutput, "is_last", isLast)).build(),
129148
ModelTensor.builder().name("memory_id").result(sessionId).build(),
130-
ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build()
149+
ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(),
150+
ModelTensor.builder().name("executor_agent_memory_id").result(executorMemoryId).build(),
151+
ModelTensor.builder().name("executor_agent_parent_interaction_id").result(executorParentInteractionId).build()
131152
);
132153

133154
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(tensors).build();

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ public void startStream(
9595
AtomicReference<String> toolUseId = new AtomicReference<>();
9696
StringBuilder toolInputAccumulator = new StringBuilder();
9797
AtomicReference<StreamState> currentState = new AtomicReference<>(StreamState.STREAMING_CONTENT);
98+
String agentType = parameters.get("agent_type");
99+
StringBuilder accumulatedContent = new StringBuilder();
98100

99101
// Build Bedrock client
100102
BedrockRuntimeAsyncClient bedrockClient = buildBedrockRuntimeAsyncClient();
@@ -128,7 +130,18 @@ public void startStream(
128130
log.debug("Tool execution in progress - keeping stream open");
129131
}
130132
}).subscriber(event -> {
131-
handleStreamEvent(event, listener, isStreamClosed, toolName, toolInput, toolUseId, toolInputAccumulator, currentState);
133+
handleStreamEvent(
134+
event,
135+
listener,
136+
isStreamClosed,
137+
toolName,
138+
toolInput,
139+
toolUseId,
140+
toolInputAccumulator,
141+
currentState,
142+
agentType,
143+
accumulatedContent
144+
);
132145
}).build();
133146

134147
// Start streaming
@@ -183,18 +196,29 @@ private void handleStreamEvent(
183196
AtomicReference<Map<String, Object>> toolInput,
184197
AtomicReference<String> toolUseId,
185198
StringBuilder toolInputAccumulator,
186-
AtomicReference<StreamState> currentState
199+
AtomicReference<StreamState> currentState,
200+
String agentType,
201+
StringBuilder accumulatedContent
187202
) {
188203
switch (currentState.get()) {
189204
case STREAMING_CONTENT:
190205
if (isToolUseDetected(event)) {
191206
currentState.set(StreamState.TOOL_CALL_DETECTED);
192207
extractToolInfo(event, toolName, toolUseId);
193208
} else if (isContentDelta(event)) {
194-
sendContentResponse(getTextContent(event), false, listener);
209+
String content = getTextContent(event);
210+
accumulatedContent.append(content);
211+
sendContentResponse(content, false, listener);
195212
} else if (isStreamComplete(event)) {
196-
currentState.set(StreamState.COMPLETED);
197-
sendCompletionResponse(isStreamClosed, listener);
213+
// For PER agent, we should keep the connection open after the planner LLM finishes
214+
if ("per".equals(agentType)) {
215+
sendPlannerResponse(false, listener, String.valueOf(accumulatedContent));
216+
currentState.set(StreamState.WAITING_FOR_TOOL_RESULT);
217+
log.info("PER agent planner phase completed - waiting for execution phase");
218+
} else {
219+
currentState.set(StreamState.COMPLETED);
220+
sendCompletionResponse(isStreamClosed, listener);
221+
}
198222
}
199223
break;
200224

@@ -225,6 +249,26 @@ private void handleStreamEvent(
225249
}
226250
}
227251

252+
private void sendPlannerResponse(
253+
boolean isStreamClosed,
254+
StreamPredictActionListener<MLTaskResponse, ?> listener,
255+
String plannerContent
256+
) {
257+
if (!isStreamClosed) {
258+
Map<String, Object> responseMap = new HashMap<>();
259+
responseMap.put("output", Map.of("message", Map.of("content", List.of(Map.of("text", plannerContent)))));
260+
261+
ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(responseMap).build();
262+
263+
ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build();
264+
265+
ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();
266+
267+
listener.onResponse(MLTaskResponse.builder().output(output).build());
268+
log.debug("Sent planner response for PER agent");
269+
}
270+
}
271+
228272
// TODO: refactor the event type checker methods
229273
private void extractToolInfo(ConverseStreamOutput event, AtomicReference<String> toolName, AtomicReference<String> toolUseId) {
230274
ContentBlockStartEvent startEvent = (ContentBlockStartEvent) event;

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ public void startStream(
7676
) {
7777
try {
7878
log.info("Creating SSE connection for streaming request");
79-
EventSourceListener listener = new HTTPEventSourceListener(actionListener, llmInterface);
79+
String agentType = parameters.get("agent_type");
80+
EventSourceListener listener = new HTTPEventSourceListener(actionListener, llmInterface, agentType);
8081
Request request = ConnectorUtils.buildOKHttpStreamingRequest(action, connector, parameters, payload);
8182

8283
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
@@ -100,16 +101,23 @@ public final class HTTPEventSourceListener extends EventSourceListener {
100101
private StreamPredictActionListener<MLTaskResponse, ?> streamActionListener;
101102
private final String llmInterface;
102103
private AtomicBoolean isStreamClosed;
104+
private final String agentType;
103105
private boolean functionCallInProgress = false;
104106
private boolean agentExecutionInProgress = false;
105107
private String accumulatedToolCallId = null;
106108
private String accumulatedToolName = null;
107109
private String accumulatedArguments = "";
110+
private StringBuilder accumulatedContent = new StringBuilder();
108111

109-
public HTTPEventSourceListener(StreamPredictActionListener<MLTaskResponse, ?> streamActionListener, String llmInterface) {
112+
public HTTPEventSourceListener(
113+
StreamPredictActionListener<MLTaskResponse, ?> streamActionListener,
114+
String llmInterface,
115+
String agentType
116+
) {
110117
this.streamActionListener = streamActionListener;
111118
this.llmInterface = llmInterface;
112119
this.isStreamClosed = new AtomicBoolean(false);
120+
this.agentType = agentType;
113121
}
114122

115123
/***
@@ -206,14 +214,20 @@ private void processStreamChunk(Map<String, Object> dataMap) {
206214
// Handle stop finish reason
207215
String finishReason = extractPath(dataMap, "$.choices[0].finish_reason");
208216
if ("stop".equals(finishReason)) {
209-
agentExecutionInProgress = false;
210-
sendCompletionResponse(isStreamClosed, streamActionListener);
217+
// For PER agent, we should keep the connection open after the planner LLM finishes
218+
if ("per".equals(agentType)) {
219+
completePlannerResponse();
220+
} else {
221+
agentExecutionInProgress = false;
222+
sendCompletionResponse(isStreamClosed, streamActionListener);
223+
}
211224
return;
212225
}
213226

214227
// Process content
215228
String content = extractPath(dataMap, "$.choices[0].delta.content");
216229
if (content != null && !content.isEmpty()) {
230+
accumulatedContent.append(content);
217231
sendContentResponse(content, false, streamActionListener);
218232
}
219233

@@ -268,6 +282,19 @@ private ModelTensorOutput createModelTensorOutput(Map<String, Object> responseDa
268282
return ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();
269283
}
270284

285+
private void completePlannerResponse() {
286+
String fullContent = accumulatedContent.toString().trim();
287+
288+
// Create compatible response format
289+
Map<String, Object> message = Map.of("content", fullContent);
290+
Map<String, Object> choice = Map.of("message", message);
291+
Map<String, Object> response = Map.of("choices", List.of(choice));
292+
293+
ModelTensorOutput output = createModelTensorOutput(response);
294+
streamActionListener.onResponse(new MLTaskResponse(output));
295+
agentExecutionInProgress = true;
296+
}
297+
271298
private void accumulateFunctionCall(List<?> toolCalls) {
272299
functionCallInProgress = true;
273300
for (Object toolCall : toolCalls) {

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.io.ByteArrayOutputStream;
2121
import java.io.IOException;
2222
import java.nio.ByteBuffer;
23+
import java.util.ArrayList;
2324
import java.util.LinkedHashMap;
2425
import java.util.List;
2526
import java.util.Locale;
@@ -346,21 +347,33 @@ private HttpChunk convertToHttpChunk(MLTaskResponse response) throws IOException
346347
// Regular response - extract values and build proper structure
347348
String memoryId = extractTensorResult(response, "memory_id");
348349
String parentInteractionId = extractTensorResult(response, "parent_interaction_id");
350+
String executorMemoryId = extractTensorResult(response, "executor_agent_memory_id");
351+
String executorParentInteractionId = extractTensorResult(response, "executor_agent_parent_interaction_id");
349352
String content = dataMap.containsKey("content") ? (String) dataMap.get("content") : "";
350353
isLast = dataMap.containsKey("is_last") ? Boolean.TRUE.equals(dataMap.get("is_last")) : false;
351354
boolean finalIsLast = isLast;
352355

353-
List<ModelTensor> orderedTensors = List
354-
.of(
355-
ModelTensor.builder().name("memory_id").result(memoryId).build(),
356-
ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(),
357-
ModelTensor.builder().name("response").dataAsMap(new LinkedHashMap<String, Object>() {
358-
{
359-
put("content", content);
360-
put("is_last", finalIsLast);
361-
}
362-
}).build()
363-
);
356+
List<ModelTensor> orderedTensors = new ArrayList<>();
357+
orderedTensors.add(ModelTensor.builder().name("memory_id").result(memoryId).build());
358+
orderedTensors.add(ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build());
359+
360+
if (executorMemoryId != null && !executorMemoryId.isEmpty()) {
361+
orderedTensors.add(ModelTensor.builder().name("executor_agent_memory_id").result(executorMemoryId).build());
362+
}
363+
364+
if (executorParentInteractionId != null && !executorParentInteractionId.isEmpty()) {
365+
orderedTensors
366+
.add(
367+
ModelTensor.builder().name("executor_agent_parent_interaction_id").result(executorParentInteractionId).build()
368+
);
369+
}
370+
371+
orderedTensors.add(ModelTensor.builder().name("response").dataAsMap(new LinkedHashMap<String, Object>() {
372+
{
373+
put("content", content);
374+
put("is_last", finalIsLast);
375+
}
376+
}).build());
364377

365378
ModelTensors tensors = ModelTensors.builder().mlModelTensors(orderedTensors).build();
366379
ModelTensorOutput tensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();

0 commit comments

Comments
 (0)