Skip to content

Commit 2dd3690

Browse files
authored
Merge branch 'main' into fix-websearchtool-security-issue
2 parents fda5e1f + 266bcfe commit 2dd3690

File tree

23 files changed

+920
-148
lines changed

23 files changed

+920
-148
lines changed

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
package org.opensearch.ml.common.settings;
77

88
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_ENDPOINT_KEY;
9+
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL_KEY;
10+
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_GLOBAL_TENANT_ID_KEY;
911
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_REGION_KEY;
1012
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_SERVICE_NAME_KEY;
1113
import static org.opensearch.remote.metadata.common.CommonValue.REMOTE_METADATA_TYPE_KEY;
@@ -372,4 +374,14 @@ private MLCommonsSettings() {}
372374
.boolSetting("plugins.ml_commons.agentic_memory_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
373375
public static final String ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE =
374376
"The Agentic Memory APIs are not enabled. To enable, please update the setting " + ML_COMMONS_AGENTIC_MEMORY_ENABLED.getKey();
377+
378+
public static final Setting<String> REMOTE_METADATA_GLOBAL_TENANT_ID = Setting
379+
.simpleString("plugins.ml-commons." + REMOTE_METADATA_GLOBAL_TENANT_ID_KEY, Setting.Property.NodeScope, Setting.Property.Final);
380+
381+
public static final Setting<String> REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL = Setting
382+
.simpleString(
383+
"plugins.ml-commons." + REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL_KEY,
384+
Setting.Property.NodeScope,
385+
Setting.Property.Final
386+
);
375387
}

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import static org.apache.commons.text.StringEscapeUtils.escapeJson;
99
import static org.opensearch.action.ValidateActions.addValidationError;
1010

11+
import java.io.IOException;
12+
import java.math.BigDecimal;
1113
import java.nio.ByteBuffer;
1214
import java.nio.charset.StandardCharsets;
1315
import java.security.AccessController;
@@ -38,11 +40,15 @@
3840
import com.fasterxml.jackson.databind.JsonNode;
3941
import com.fasterxml.jackson.databind.ObjectMapper;
4042
import com.google.gson.Gson;
43+
import com.google.gson.GsonBuilder;
4144
import com.google.gson.JsonElement;
4245
import com.google.gson.JsonObject;
4346
import com.google.gson.JsonParser;
4447
import com.google.gson.JsonSyntaxException;
48+
import com.google.gson.TypeAdapter;
4549
import com.google.gson.reflect.TypeToken;
50+
import com.google.gson.stream.JsonReader;
51+
import com.google.gson.stream.JsonWriter;
4652
import com.jayway.jsonpath.JsonPath;
4753
import com.jayway.jsonpath.PathNotFoundException;
4854
import com.networknt.schema.JsonSchema;
@@ -71,11 +77,15 @@ public class StringUtils {
7177

7278
public static final String SAFE_INPUT_DESCRIPTION = "can only contain letters, numbers, spaces, and basic punctuation (.,!?():@-_'/\")";
7379

74-
public static final Gson gson;
80+
public static final Gson gson = new Gson();
81+
public static final Gson PLAIN_NUMBER_GSON = new GsonBuilder()
82+
.serializeNulls()
83+
.registerTypeAdapter(Float.class, new PlainFloatAdapter())
84+
.registerTypeAdapter(float.class, new PlainFloatAdapter())
85+
.registerTypeAdapter(Double.class, new PlainDoubleAdapter())
86+
.registerTypeAdapter(double.class, new PlainDoubleAdapter())
87+
.create();
7588

76-
static {
77-
gson = new Gson();
78-
}
7989
public static final String TO_STRING_FUNCTION_NAME = ".toString()";
8090

8191
public static final ObjectMapper MAPPER = new ObjectMapper();
@@ -597,4 +607,49 @@ public static List<String> parseStringArrayToList(String jsonArrayString) {
597607
return Collections.emptyList();
598608
}
599609
}
610+
611+
/**
612+
* Custom Gson adapter for Double and Float type.
613+
* Serializes numbers without scientific notation.
614+
* Writes null for null, NaN, and Infinity values.
615+
* Deserializes JSON numbers back to Double and Float.
616+
*/
617+
private static class PlainDoubleAdapter extends TypeAdapter<Double> {
618+
@Override
619+
public void write(JsonWriter out, Double value) throws IOException {
620+
if (value == null || value.isNaN() || value.isInfinite()) {
621+
out.nullValue();
622+
return;
623+
}
624+
625+
BigDecimal bd = BigDecimal.valueOf(value).stripTrailingZeros();
626+
627+
out.jsonValue(bd.toPlainString());
628+
}
629+
630+
@Override
631+
public Double read(JsonReader in) throws IOException {
632+
return in.nextDouble();
633+
}
634+
}
635+
636+
public static class PlainFloatAdapter extends TypeAdapter<Float> {
637+
@Override
638+
public void write(JsonWriter out, Float value) throws IOException {
639+
if (value == null || value.isNaN() || value.isInfinite()) {
640+
out.nullValue();
641+
return;
642+
}
643+
644+
BigDecimal bd = new BigDecimal(Float.toString(value)).stripTrailingZeros();
645+
out.jsonValue(bd.toPlainString());
646+
}
647+
648+
@Override
649+
public Float read(JsonReader in) throws IOException {
650+
double d = in.nextDouble();
651+
float f = (float) d;
652+
return f;
653+
}
654+
}
600655
}

common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import static org.opensearch.ml.common.utils.StringUtils.*;
1717

1818
import java.io.IOException;
19+
import java.lang.reflect.Constructor;
1920
import java.util.ArrayList;
2021
import java.util.Arrays;
2122
import java.util.HashMap;
@@ -31,6 +32,9 @@
3132
import org.opensearch.OpenSearchParseException;
3233
import org.opensearch.action.ActionRequestValidationException;
3334

35+
import com.google.gson.JsonElement;
36+
import com.google.gson.TypeAdapter;
37+
import com.google.gson.reflect.TypeToken;
3438
import com.jayway.jsonpath.JsonPath;
3539

3640
public class StringUtilsTest {
@@ -949,4 +953,133 @@ public void testParseStringArrayToList_Null() {
949953
assertEquals(0, array.size());
950954
}
951955

956+
// reflect method for PlainDoubleAdapter
957+
private static TypeAdapter<Double> createPlainDoubleAdapter() {
958+
try {
959+
Class<?> clazz = Class.forName("org.opensearch.ml.common.utils.StringUtils$PlainDoubleAdapter");
960+
Constructor<?> constructor = clazz.getDeclaredConstructor();
961+
constructor.setAccessible(true);
962+
Object adapterInstance = constructor.newInstance();
963+
return (TypeAdapter<Double>) adapterInstance;
964+
} catch (Exception e) {
965+
throw new RuntimeException("Failed to create PlainDoubleAdapter via reflection", e);
966+
}
967+
}
968+
969+
@Test
970+
public void testSerializeScientificNotation_RemovesExponent() {
971+
Map<String, Object> data = Map.of("test1", 1e30, "test2", 1.2e3, "test3", 9.5e-3, "test4", 1.56e-30);
972+
973+
String json = StringUtils.PLAIN_NUMBER_GSON.toJson(data);
974+
975+
assertTrue(json.contains("1000000000000000000000000000000"));
976+
assertTrue(json.contains("1200"));
977+
assertTrue(json.contains("0.0095"));
978+
assertTrue(json.contains("0.00000000000000000000000000000156"));
979+
980+
}
981+
982+
@Test
983+
public void testSerializeInteger_RemovesDecimalPoint() {
984+
Map<String, Object> data = Map.of("intLike", 42.0);
985+
986+
String json = StringUtils.PLAIN_NUMBER_GSON.toJson(data);
987+
988+
assertTrue(json.contains("42"));
989+
assertFalse(json.contains("42.0"));
990+
}
991+
992+
@Test
993+
public void testSerializeNaNAndInfinity_BecomesNull() {
994+
Map<String, Double> data = new HashMap<>();
995+
data.put("nul", null);
996+
data.put("nan", Double.NaN);
997+
data.put("inf", Double.POSITIVE_INFINITY);
998+
data.put("ninf", Double.NEGATIVE_INFINITY);
999+
1000+
String json = StringUtils.PLAIN_NUMBER_GSON.toJson(data);
1001+
1002+
assertTrue(json.contains("\"nan\":null"));
1003+
assertTrue(json.contains("\"inf\":null"));
1004+
assertTrue(json.contains("\"ninf\":null"));
1005+
assertTrue(json.contains("\"nul\":null"));
1006+
1007+
assertFalse(json.contains("NaN"));
1008+
assertFalse(json.contains("Infinity"));
1009+
}
1010+
1011+
@Test
1012+
public void testDeserializeBackToDouble() {
1013+
String json = "{\"value\": 12345.6789}";
1014+
1015+
Map<?, ?> result = StringUtils.PLAIN_NUMBER_GSON.fromJson(json, Map.class);
1016+
1017+
Object value = result.get("value");
1018+
assertTrue(value instanceof Double);
1019+
assertEquals(12345.6789, (Double) value, 1e-7);
1020+
}
1021+
1022+
@Test
1023+
public void testQuotedScientificNotation_RemainsString() {
1024+
String json = "{\"code\":\"1e-6\"}";
1025+
1026+
Map<?, ?> result = StringUtils.PLAIN_NUMBER_GSON.fromJson(json, Map.class);
1027+
1028+
assertEquals("1e-6", result.get("code"));
1029+
}
1030+
1031+
@Test
1032+
public void testSerializeFloatScientificNotation_RemovesExponent_InPojo() {
1033+
java.util.Map<String, Float> data = new java.util.LinkedHashMap<>();
1034+
data.put("fObj", 1.23e-5f);
1035+
data.put("fPrim", 9.5e-3f);
1036+
1037+
String json = StringUtils.PLAIN_NUMBER_GSON.toJson(data);
1038+
1039+
assertTrue(json.contains("\"fObj\":0.0000123"));
1040+
1041+
assertTrue(json.contains("\"fPrim\":0.0095") || json.contains("\"fPrim\":9.5E-3") || json.contains("\"fPrim\":9.5e-3"));
1042+
}
1043+
1044+
@Test
1045+
public void testSerializeFloatNaNAndInfinity_BecomesNull_InPojo() {
1046+
java.util.Map<String, Float> data = new java.util.LinkedHashMap<>();
1047+
data.put("fObj", Float.NaN);
1048+
data.put("fPrimBox", Float.POSITIVE_INFINITY);
1049+
data.put("fNull", null);
1050+
1051+
String json = StringUtils.PLAIN_NUMBER_GSON.toJson(data);
1052+
1053+
assertTrue(json.contains("\"fObj\":null"));
1054+
assertTrue(json.contains("\"fNull\":null"));
1055+
assertTrue(json.contains("\"fPrimBox\":null") || !json.contains("\"fPrimBox\""));
1056+
}
1057+
1058+
@Test
1059+
public void testDeserializeScientificNotation_ToFloatAndPrimitive() {
1060+
String jsonObj = "{\"fObj\":1.23e-5}";
1061+
java.lang.reflect.Type mapType = new com.google.gson.reflect.TypeToken<java.util.Map<String, Float>>() {
1062+
}.getType();
1063+
java.util.Map<String, Float> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(jsonObj, mapType);
1064+
assertEquals(1.23e-5f, m.get("fObj"), 1e-9f);
1065+
1066+
String jsonArr = "[4.56e1]";
1067+
float[] arr = StringUtils.PLAIN_NUMBER_GSON.fromJson(jsonArr, float[].class);
1068+
assertEquals(45.6f, arr[0], 1e-6f);
1069+
}
1070+
1071+
@Test
1072+
public void testDeserializeNullFloat_ToNull() {
1073+
String json = "{\"fObj\":null,\"fPrim\":1.0}";
1074+
1075+
java.lang.reflect.Type mapType = new TypeToken<java.util.Map<String, JsonElement>>() {
1076+
}.getType();
1077+
java.util.Map<String, JsonElement> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(json, mapType);
1078+
1079+
assertTrue(m.containsKey("fObj"));
1080+
assertTrue(m.get("fObj").isJsonNull());
1081+
1082+
assertTrue(m.get("fPrim").isJsonPrimitive());
1083+
assertEquals(1.0f, m.get("fPrim").getAsFloat(), 1e-9f);
1084+
}
9521085
}

ml-algorithms/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ dependencies {
5454
// Multi-tenant SDK Client
5555
implementation "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
5656
implementation 'commons-beanutils:commons-beanutils:1.11.0'
57+
implementation "org.opensearch:opensearch-remote-metadata-sdk-ddb-client:${opensearch_build}"
5758

5859
def os = DefaultNativePlatform.currentOperatingSystem
5960
//arm/macos doesn't support GPU

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,15 @@ public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
150150
return predictable;
151151
}
152152

153+
public void deploy(MLModel mlModel, Map<String, Object> params, ActionListener<Predictable> listener) {
154+
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
155+
predictable.initModelAsync(mlModel, params, encryptor).thenAccept((b) -> listener.onResponse(predictable)).exceptionally(e -> {
156+
log.error("Failed to init init model", e);
157+
listener.onFailure(new RuntimeException(e));
158+
return null;
159+
});
160+
}
161+
153162
public MLExecutable deployExecute(MLModel mlModel, Map<String, Object> params) {
154163
MLExecutable executable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
155164
executable.initModel(mlModel, params);

ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.engine;
77

88
import java.util.Map;
9+
import java.util.concurrent.CompletionStage;
910

1011
import org.opensearch.core.action.ActionListener;
1112
import org.opensearch.ml.common.MLModel;
@@ -19,6 +20,8 @@
1920
*/
2021
public interface Predictable {
2122

23+
String METHOD_NOT_IMPLEMENTED_ERROR_MSG = "Method is not implemented";
24+
2225
/**
2326
* Predict with given input data and model.
2427
* Will reload model into memory with model content.
@@ -34,11 +37,11 @@ public interface Predictable {
3437
* @return predicted results
3538
*/
3639
default MLOutput predict(MLInput mlInput) {
37-
throw new IllegalStateException("Method is not implemented");
40+
throw new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG);
3841
}
3942

4043
default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
41-
actionListener.onFailure(new IllegalStateException("Method is not implemented"));
44+
actionListener.onFailure(new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG));
4245
}
4346

4447
/**
@@ -47,7 +50,13 @@ default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> action
4750
* @param params other parameters
4851
* @param encryptor encryptor
4952
*/
50-
void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor);
53+
default void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
54+
throw new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG);
55+
}
56+
57+
default CompletionStage<Boolean> initModelAsync(MLModel model, Map<String, Object> params, Encryptor encryptor) {
58+
throw new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG);
59+
}
5160

5261
/**
5362
* Close resources like deployed model.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,8 @@ private void response() {
206206
ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard);
207207
tensors.setStatusCode(statusCode);
208208
actionListener.onResponse(new Tuple<>(executionContext.getSequence(), tensors));
209+
} catch (IllegalArgumentException e) {
210+
actionListener.onFailure(e);
209211
} catch (Exception e) {
210212
log.error("Failed to process response body: {}", body, e);
211213
actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e));

0 commit comments

Comments
 (0)