Skip to content

Commit 247a88c

Browse files
add PlainNumberAdapter and corresponding tests for Gson in SearchIndexTool (#4133)
* fix search index tool Signed-off-by: xinyual <[email protected]> * remove Signed-off-by: xinyual <[email protected]> * add PlainDoubleAdapter and corresponding tests for Gson in SearchIndexTool Signed-off-by: juqnuanp <[email protected]> * move PlainDoubleAdapter to StringUtils and optimized it in SearchIndexTool Signed-off-by: juqnuanp <[email protected]> * Avoid import * in StringUtils Signed-off-by: juqnuanp <[email protected]> * simplify Gson initialization in StringUtils Signed-off-by: juqnuanp <[email protected]> * update ut Signed-off-by: juqnuanp <[email protected]> * update ut Signed-off-by: juqnuanp <[email protected]> * add float adapter Signed-off-by: juqnuanp <[email protected]> --------- Signed-off-by: xinyual <[email protected]> Signed-off-by: juqnuanp <[email protected]> Co-authored-by: xinyual <[email protected]>
1 parent 7bf4b54 commit 247a88c

File tree

4 files changed

+233
-5
lines changed

4 files changed

+233
-5
lines changed

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import static org.apache.commons.text.StringEscapeUtils.escapeJson;
99
import static org.opensearch.action.ValidateActions.addValidationError;
1010

11+
import java.io.IOException;
12+
import java.math.BigDecimal;
1113
import java.nio.ByteBuffer;
1214
import java.nio.charset.StandardCharsets;
1315
import java.security.AccessController;
@@ -38,11 +40,15 @@
3840
import com.fasterxml.jackson.databind.JsonNode;
3941
import com.fasterxml.jackson.databind.ObjectMapper;
4042
import com.google.gson.Gson;
43+
import com.google.gson.GsonBuilder;
4144
import com.google.gson.JsonElement;
4245
import com.google.gson.JsonObject;
4346
import com.google.gson.JsonParser;
4447
import com.google.gson.JsonSyntaxException;
48+
import com.google.gson.TypeAdapter;
4549
import com.google.gson.reflect.TypeToken;
50+
import com.google.gson.stream.JsonReader;
51+
import com.google.gson.stream.JsonWriter;
4652
import com.jayway.jsonpath.JsonPath;
4753
import com.jayway.jsonpath.PathNotFoundException;
4854
import com.networknt.schema.JsonSchema;
@@ -71,11 +77,15 @@ public class StringUtils {
7177

7278
public static final String SAFE_INPUT_DESCRIPTION = "can only contain letters, numbers, spaces, and basic punctuation (.,!?():@-_'/\")";
7379

74-
public static final Gson gson;
80+
public static final Gson gson = new Gson();
81+
public static final Gson PLAIN_NUMBER_GSON = new GsonBuilder()
82+
.serializeNulls()
83+
.registerTypeAdapter(Float.class, new PlainFloatAdapter())
84+
.registerTypeAdapter(float.class, new PlainFloatAdapter())
85+
.registerTypeAdapter(Double.class, new PlainDoubleAdapter())
86+
.registerTypeAdapter(double.class, new PlainDoubleAdapter())
87+
.create();
7588

76-
static {
77-
gson = new Gson();
78-
}
7989
public static final String TO_STRING_FUNCTION_NAME = ".toString()";
8090

8191
public static final ObjectMapper MAPPER = new ObjectMapper();
@@ -597,4 +607,49 @@ public static List<String> parseStringArrayToList(String jsonArrayString) {
597607
return Collections.emptyList();
598608
}
599609
}
610+
611+
/**
612+
* Custom Gson adapter for Double and Float type.
613+
* Serializes numbers without scientific notation.
614+
* Writes null for null, NaN, and Infinity values.
615+
* Deserializes JSON numbers back to Double and Float.
616+
*/
617+
private static class PlainDoubleAdapter extends TypeAdapter<Double> {
618+
@Override
619+
public void write(JsonWriter out, Double value) throws IOException {
620+
if (value == null || value.isNaN() || value.isInfinite()) {
621+
out.nullValue();
622+
return;
623+
}
624+
625+
BigDecimal bd = BigDecimal.valueOf(value).stripTrailingZeros();
626+
627+
out.jsonValue(bd.toPlainString());
628+
}
629+
630+
@Override
631+
public Double read(JsonReader in) throws IOException {
632+
return in.nextDouble();
633+
}
634+
}
635+
636+
public static class PlainFloatAdapter extends TypeAdapter<Float> {
637+
@Override
638+
public void write(JsonWriter out, Float value) throws IOException {
639+
if (value == null || value.isNaN() || value.isInfinite()) {
640+
out.nullValue();
641+
return;
642+
}
643+
644+
BigDecimal bd = new BigDecimal(Float.toString(value)).stripTrailingZeros();
645+
out.jsonValue(bd.toPlainString());
646+
}
647+
648+
@Override
649+
public Float read(JsonReader in) throws IOException {
650+
double d = in.nextDouble();
651+
float f = (float) d;
652+
return f;
653+
}
654+
}
600655
}

common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import static org.opensearch.ml.common.utils.StringUtils.*;
1717

1818
import java.io.IOException;
19+
import java.lang.reflect.Constructor;
1920
import java.util.ArrayList;
2021
import java.util.Arrays;
2122
import java.util.HashMap;
@@ -31,6 +32,9 @@
3132
import org.opensearch.OpenSearchParseException;
3233
import org.opensearch.action.ActionRequestValidationException;
3334

35+
import com.google.gson.JsonElement;
36+
import com.google.gson.TypeAdapter;
37+
import com.google.gson.reflect.TypeToken;
3438
import com.jayway.jsonpath.JsonPath;
3539

3640
public class StringUtilsTest {
@@ -949,4 +953,133 @@ public void testParseStringArrayToList_Null() {
949953
assertEquals(0, array.size());
950954
}
951955

956+
// reflect method for PlainDoubleAdapter
957+
private static TypeAdapter<Double> createPlainDoubleAdapter() {
958+
try {
959+
Class<?> clazz = Class.forName("org.opensearch.ml.common.utils.StringUtils$PlainDoubleAdapter");
960+
Constructor<?> constructor = clazz.getDeclaredConstructor();
961+
constructor.setAccessible(true);
962+
Object adapterInstance = constructor.newInstance();
963+
return (TypeAdapter<Double>) adapterInstance;
964+
} catch (Exception e) {
965+
throw new RuntimeException("Failed to create PlainDoubleAdapter via reflection", e);
966+
}
967+
}
968+
969+
@Test
970+
public void testSerializeScientificNotation_RemovesExponent() {
971+
Map<String, Object> data = Map.of("test1", 1e30, "test2", 1.2e3, "test3", 9.5e-3, "test4", 1.56e-30);
972+
973+
String json = StringUtils.PLAIN_NUMBER_GSON.toJson(data);
974+
975+
assertTrue(json.contains("1000000000000000000000000000000"));
976+
assertTrue(json.contains("1200"));
977+
assertTrue(json.contains("0.0095"));
978+
assertTrue(json.contains("0.00000000000000000000000000000156"));
979+
980+
}
981+
982+
@Test
983+
public void testSerializeInteger_RemovesDecimalPoint() {
984+
Map<String, Object> data = Map.of("intLike", 42.0);
985+
986+
String json = StringUtils.PLAIN_NUMBER_GSON.toJson(data);
987+
988+
assertTrue(json.contains("42"));
989+
assertFalse(json.contains("42.0"));
990+
}
991+
992+
@Test
993+
public void testSerializeNaNAndInfinity_BecomesNull() {
994+
Map<String, Double> data = new HashMap<>();
995+
data.put("nul", null);
996+
data.put("nan", Double.NaN);
997+
data.put("inf", Double.POSITIVE_INFINITY);
998+
data.put("ninf", Double.NEGATIVE_INFINITY);
999+
1000+
String json = StringUtils.PLAIN_NUMBER_GSON.toJson(data);
1001+
1002+
assertTrue(json.contains("\"nan\":null"));
1003+
assertTrue(json.contains("\"inf\":null"));
1004+
assertTrue(json.contains("\"ninf\":null"));
1005+
assertTrue(json.contains("\"nul\":null"));
1006+
1007+
assertFalse(json.contains("NaN"));
1008+
assertFalse(json.contains("Infinity"));
1009+
}
1010+
1011+
@Test
1012+
public void testDeserializeBackToDouble() {
1013+
String json = "{\"value\": 12345.6789}";
1014+
1015+
Map<?, ?> result = StringUtils.PLAIN_NUMBER_GSON.fromJson(json, Map.class);
1016+
1017+
Object value = result.get("value");
1018+
assertTrue(value instanceof Double);
1019+
assertEquals(12345.6789, (Double) value, 1e-7);
1020+
}
1021+
1022+
@Test
1023+
public void testQuotedScientificNotation_RemainsString() {
1024+
String json = "{\"code\":\"1e-6\"}";
1025+
1026+
Map<?, ?> result = StringUtils.PLAIN_NUMBER_GSON.fromJson(json, Map.class);
1027+
1028+
assertEquals("1e-6", result.get("code"));
1029+
}
1030+
1031+
@Test
1032+
public void testSerializeFloatScientificNotation_RemovesExponent_InPojo() {
1033+
java.util.Map<String, Float> data = new java.util.LinkedHashMap<>();
1034+
data.put("fObj", 1.23e-5f);
1035+
data.put("fPrim", 9.5e-3f);
1036+
1037+
String json = StringUtils.PLAIN_NUMBER_GSON.toJson(data);
1038+
1039+
assertTrue(json.contains("\"fObj\":0.0000123"));
1040+
1041+
assertTrue(json.contains("\"fPrim\":0.0095") || json.contains("\"fPrim\":9.5E-3") || json.contains("\"fPrim\":9.5e-3"));
1042+
}
1043+
1044+
@Test
1045+
public void testSerializeFloatNaNAndInfinity_BecomesNull_InPojo() {
1046+
java.util.Map<String, Float> data = new java.util.LinkedHashMap<>();
1047+
data.put("fObj", Float.NaN);
1048+
data.put("fPrimBox", Float.POSITIVE_INFINITY);
1049+
data.put("fNull", null);
1050+
1051+
String json = StringUtils.PLAIN_NUMBER_GSON.toJson(data);
1052+
1053+
assertTrue(json.contains("\"fObj\":null"));
1054+
assertTrue(json.contains("\"fNull\":null"));
1055+
assertTrue(json.contains("\"fPrimBox\":null") || !json.contains("\"fPrimBox\""));
1056+
}
1057+
1058+
@Test
1059+
public void testDeserializeScientificNotation_ToFloatAndPrimitive() {
1060+
String jsonObj = "{\"fObj\":1.23e-5}";
1061+
java.lang.reflect.Type mapType = new com.google.gson.reflect.TypeToken<java.util.Map<String, Float>>() {
1062+
}.getType();
1063+
java.util.Map<String, Float> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(jsonObj, mapType);
1064+
assertEquals(1.23e-5f, m.get("fObj"), 1e-9f);
1065+
1066+
String jsonArr = "[4.56e1]";
1067+
float[] arr = StringUtils.PLAIN_NUMBER_GSON.fromJson(jsonArr, float[].class);
1068+
assertEquals(45.6f, arr[0], 1e-6f);
1069+
}
1070+
1071+
@Test
1072+
public void testDeserializeNullFloat_ToNull() {
1073+
String json = "{\"fObj\":null,\"fPrim\":1.0}";
1074+
1075+
java.lang.reflect.Type mapType = new TypeToken<java.util.Map<String, JsonElement>>() {
1076+
}.getType();
1077+
java.util.Map<String, JsonElement> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(json, mapType);
1078+
1079+
assertTrue(m.containsKey("fObj"));
1080+
assertTrue(m.get("fObj").isJsonNull());
1081+
1082+
assertTrue(m.get("fPrim").isJsonPrimitive());
1083+
assertEquals(1.0f, m.get("fPrim").getAsFloat(), 1e-9f);
1084+
}
9521085
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.engine.tools;
77

88
import static org.opensearch.ml.common.CommonValue.*;
9+
import static org.opensearch.ml.common.utils.StringUtils.PLAIN_NUMBER_GSON;
910

1011
import java.io.IOException;
1112
import java.util.ArrayList;
@@ -175,7 +176,11 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
175176
if (jsonObject != null && jsonObject.has(INDEX_FIELD) && jsonObject.has(QUERY_FIELD)) {
176177
index = jsonObject.get(INDEX_FIELD).getAsString();
177178
JsonElement queryElement = jsonObject.get(QUERY_FIELD);
178-
query = queryElement == null ? null : queryElement.toString();
179+
180+
if (queryElement != null) {
181+
Object queryObject = PLAIN_NUMBER_GSON.fromJson(queryElement, Object.class);
182+
query = PLAIN_NUMBER_GSON.toJson(queryObject);
183+
}
179184
}
180185
} catch (JsonSyntaxException e) {
181186
log.error("Invalid JSON input: {}", input, e);

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.engine.tools;
77

8+
import static org.junit.Assert.assertArrayEquals;
89
import static org.junit.Assert.assertEquals;
910
import static org.junit.Assert.assertFalse;
1011
import static org.junit.Assert.assertTrue;
@@ -24,6 +25,7 @@
2425
import org.junit.Test;
2526
import org.mockito.ArgumentCaptor;
2627
import org.mockito.Mockito;
28+
import org.opensearch.action.search.SearchRequest;
2729
import org.opensearch.action.search.SearchResponse;
2830
import org.opensearch.common.settings.Settings;
2931
import org.opensearch.common.xcontent.json.JsonXContent;
@@ -409,4 +411,37 @@ public void testRunWithoutReturnFullResponse() {
409411
assertFalse(((String) result).contains("took"));
410412
}
411413

414+
@Test
415+
@SneakyThrows
416+
public void testRun_withMatchQuery_triggersPlainDoubleGson() {
417+
String input = "{\"index\":\"test-index\",\"query\":{}}";
418+
Map<String, String> params = Map.of("input", input);
419+
@SuppressWarnings("unchecked")
420+
ActionListener<String> listener = mock(ActionListener.class);
421+
422+
mockedSearchIndexTool.run(params, listener);
423+
424+
ArgumentCaptor<SearchRequest> cap = ArgumentCaptor.forClass(SearchRequest.class);
425+
verify(client, times(1)).search(cap.capture(), any());
426+
verify(client, never()).execute(any(), any(), any());
427+
428+
assertArrayEquals(new String[] { "test-index" }, cap.getValue().indices());
429+
}
430+
431+
@Test
432+
@SneakyThrows
433+
public void testRun_withRangeQuery_triggersPlainDoubleGson() {
434+
String input = "{\"index\":\"test-index\",\"query\":{}}";
435+
Map<String, String> params = Map.of("input", input);
436+
@SuppressWarnings("unchecked")
437+
ActionListener<String> listener = mock(ActionListener.class);
438+
439+
mockedSearchIndexTool.run(params, listener);
440+
441+
ArgumentCaptor<SearchRequest> cap = ArgumentCaptor.forClass(SearchRequest.class);
442+
verify(client, times(1)).search(cap.capture(), any());
443+
verify(client, never()).execute(any(), any(), any());
444+
445+
assertArrayEquals(new String[] { "test-index" }, cap.getValue().indices());
446+
}
412447
}

0 commit comments

Comments
 (0)