Skip to content

Commit 12a3746

Browse files
b4sjoopyek-bot
authored andcommitted
wrap remote agent memory config into an object to avoid override (#4411)
1 parent 7912128 commit 12a3746

File tree

2 files changed

+52
-20
lines changed

2 files changed

+52
-20
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,22 +1044,52 @@ public static Map<String, Object> createMemoryParams(
10441044
memoryParams.put(TENANT_ID_FIELD, mlAgent.getTenantId());
10451045
}
10461046
if (requestParameters != null) {
1047-
String endpointParam = requestParameters.get(ENDPOINT_FIELD);
1048-
if (!Strings.isNullOrEmpty(endpointParam)) {
1049-
memoryParams.put(ENDPOINT_FIELD, endpointParam);
1050-
}
1051-
String regionParam = requestParameters.get(HttpConnector.REGION_FIELD);
1052-
if (!Strings.isNullOrEmpty(regionParam)) {
1053-
memoryParams.put(HttpConnector.REGION_FIELD, regionParam);
1054-
}
1055-
Map<String, String> credential = parseStringMapParameter(requestParameters.get(CREDENTIAL_FIELD), CREDENTIAL_FIELD);
1056-
if (credential != null && !credential.isEmpty()) {
1057-
memoryParams.put(CREDENTIAL_FIELD, credential);
1058-
}
1059-
// Extract user_id if provided
1060-
String userIdParam = requestParameters.get("user_id");
1061-
if (!Strings.isNullOrEmpty(userIdParam)) {
1062-
memoryParams.put("user_id", userIdParam);
1047+
// Check if parameters are wrapped in remote_agent_memory_configuration
1048+
String remoteMemoryConfigStr = requestParameters.get("remote_agent_memory_configuration");
1049+
if (!Strings.isNullOrEmpty(remoteMemoryConfigStr)) {
1050+
// Parse the remote_agent_memory_configuration JSON
1051+
try (
1052+
XContentParser parser = JsonXContent.jsonXContent
1053+
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, remoteMemoryConfigStr)
1054+
) {
1055+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
1056+
Map<String, Object> remoteMemoryConfig = parser.map();
1057+
1058+
// Extract memory_container_id
1059+
String memoryContainerIdParam = (String) remoteMemoryConfig.get(MEMORY_CONTAINER_ID_FIELD);
1060+
if (!Strings.isNullOrEmpty(memoryContainerIdParam)) {
1061+
memoryParams.put(MEMORY_CONTAINER_ID_FIELD, memoryContainerIdParam);
1062+
}
1063+
1064+
// Extract endpoint
1065+
String endpointParam = (String) remoteMemoryConfig.get(ENDPOINT_FIELD);
1066+
if (!Strings.isNullOrEmpty(endpointParam)) {
1067+
memoryParams.put(ENDPOINT_FIELD, endpointParam);
1068+
}
1069+
1070+
// Extract region
1071+
String regionParam = (String) remoteMemoryConfig.get(HttpConnector.REGION_FIELD);
1072+
if (!Strings.isNullOrEmpty(regionParam)) {
1073+
memoryParams.put(HttpConnector.REGION_FIELD, regionParam);
1074+
}
1075+
1076+
// Extract credential
1077+
Object credentialObj = remoteMemoryConfig.get(CREDENTIAL_FIELD);
1078+
if (credentialObj instanceof Map) {
1079+
Map<String, String> credential = (Map<String, String>) credentialObj;
1080+
if (!credential.isEmpty()) {
1081+
memoryParams.put(CREDENTIAL_FIELD, credential);
1082+
}
1083+
}
1084+
1085+
// Extract user_id if provided
1086+
String userIdParam = (String) remoteMemoryConfig.get("user_id");
1087+
if (!Strings.isNullOrEmpty(userIdParam)) {
1088+
memoryParams.put("user_id", userIdParam);
1089+
}
1090+
} catch (Exception e) {
1091+
log.error("Failed to parse remote_agent_memory_configuration", e);
1092+
}
10631093
}
10641094
memoryParams.put(TENANT_ID_FIELD, mlAgent.getTenantId());
10651095
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,12 +1201,14 @@ private String extractServiceName(String endpoint) {
12011201
if (endpoint.contains(".aoss.amazonaws.com")) {
12021202
return "aoss";
12031203
}
1204-
// For managed OpenSearch: https://xxx.us-west-2.es.amazonaws.com
1205-
if (endpoint.contains(".es.amazonaws.com")) {
1204+
// For managed OpenSearch (production, staging, integration)
1205+
if (endpoint.contains(".es.amazonaws.com")
1206+
|| endpoint.contains(".es-staging.amazonaws.com")
1207+
|| endpoint.contains(".es-integ.amazonaws.com")) {
12061208
return "es";
12071209
}
1208-
// Default to es for OpenSearch/Elasticsearch service
1209-
return "es";
1210+
// Default to aoss for other OpenSearch services
1211+
return "aoss";
12101212
}
12111213

12121214
/**

0 commit comments

Comments
 (0)