Skip to content

Commit 984b13a

Browse files
authored
adding more unit tests (#4124)
Signed-off-by: Dhrubo Saha <[email protected]>
1 parent 1a0e1ff commit 984b13a

File tree

4 files changed

+484
-175
lines changed

4 files changed

+484
-175
lines changed

plugin/build.gradle

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,10 +334,6 @@ jacocoTestReport {
334334

335335
List<String> jacocoExclusions = [
336336
// TODO: add more unit test to meet the minimal test coverage.
337-
'org.opensearch.ml.constant.CommonValue',
338-
'org.opensearch.ml.indices.MLIndicesHandler',
339-
'org.opensearch.ml.rest.RestMLPredictionAction',
340-
'org.opensearch.ml.profile.MLModelProfile',
341337
'org.opensearch.ml.profile.MLPredictRequestStats',
342338
'org.opensearch.ml.action.deploy.TransportDeployModelAction',
343339
'org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction',
@@ -360,7 +356,6 @@ List<String> jacocoExclusions = [
360356
'org.opensearch.ml.task.MLTrainAndPredictTaskRunner',
361357
'org.opensearch.ml.task.MLExecuteTaskRunner',
362358
'org.opensearch.ml.action.profile.MLProfileTransportAction',
363-
'org.opensearch.ml.rest.RestMLPredictionAction',
364359
'org.opensearch.ml.breaker.DiskCircuitBreaker',
365360
'org.opensearch.ml.autoredeploy.MLModelAutoReDeployer.SearchRequestBuilderFactory',
366361
'org.opensearch.ml.action.training.TrainingITTests',
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.constant;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotNull;
10+
11+
import org.junit.Test;
12+
13+
public class CommonValueTests {
14+
15+
@Test
16+
public void testActionPrefix() {
17+
assertEquals("cluster:admin/opensearch/ml/", CommonValue.ACTION_PREFIX);
18+
}
19+
20+
@Test
21+
public void testConstructor() {
22+
// Test constructor to achieve full line coverage
23+
CommonValue commonValue = new CommonValue();
24+
assertNotNull(commonValue);
25+
}
26+
}
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.profile;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertFalse;
10+
import static org.junit.Assert.assertNull;
11+
import static org.junit.Assert.assertTrue;
12+
13+
import java.io.IOException;
14+
15+
import org.junit.Assert;
16+
import org.junit.Test;
17+
import org.opensearch.Version;
18+
import org.opensearch.common.io.stream.BytesStreamOutput;
19+
import org.opensearch.common.xcontent.XContentFactory;
20+
import org.opensearch.core.common.io.stream.StreamInput;
21+
import org.opensearch.core.xcontent.XContentBuilder;
22+
import org.opensearch.ml.common.model.MLModelState;
23+
24+
public class MLModelProfileTests {
25+
26+
@Test
27+
public void testConstructorAndGetters() {
28+
String[] targetNodes = { "node1", "node2" };
29+
String[] workerNodes = { "worker1", "worker2" };
30+
MLPredictRequestStats stats = MLPredictRequestStats.builder().count(10L).max(5.0).min(1.0).average(3.0).build();
31+
32+
MLModelProfile profile = MLModelProfile
33+
.builder()
34+
.modelState(MLModelState.DEPLOYED)
35+
.predictor("test-predictor")
36+
.targetWorkerNodes(targetNodes)
37+
.workerNodes(workerNodes)
38+
.modelInferenceStats(stats)
39+
.predictRequestStats(stats)
40+
.memSizeEstimationCPU(1024L)
41+
.memSizeEstimationGPU(2048L)
42+
.build();
43+
44+
assertEquals(MLModelState.DEPLOYED, profile.getModelState());
45+
assertEquals("test-predictor", profile.getPredictor());
46+
assertEquals(targetNodes, profile.getTargetWorkerNodes());
47+
assertEquals(workerNodes, profile.getWorkerNodes());
48+
assertEquals(stats, profile.getModelInferenceStats());
49+
assertEquals(stats, profile.getPredictRequestStats());
50+
assertEquals(Long.valueOf(1024L), profile.getMemSizeEstimationCPU());
51+
assertEquals(Long.valueOf(2048L), profile.getMemSizeEstimationGPU());
52+
assertNull(profile.getIsHidden());
53+
}
54+
55+
@Test
56+
public void testConstructorWithNullValues() {
57+
MLModelProfile profile = MLModelProfile.builder().build();
58+
59+
assertNull(profile.getModelState());
60+
assertNull(profile.getPredictor());
61+
assertNull(profile.getTargetWorkerNodes());
62+
assertNull(profile.getWorkerNodes());
63+
assertNull(profile.getModelInferenceStats());
64+
assertNull(profile.getPredictRequestStats());
65+
assertNull(profile.getMemSizeEstimationCPU());
66+
assertNull(profile.getMemSizeEstimationGPU());
67+
assertNull(profile.getIsHidden());
68+
}
69+
70+
@Test
71+
public void testSetIsHidden() {
72+
MLModelProfile profile = MLModelProfile.builder().build();
73+
profile.setIsHidden(true);
74+
assertTrue(profile.getIsHidden());
75+
76+
profile.setIsHidden(false);
77+
assertFalse(profile.getIsHidden());
78+
}
79+
80+
@Test
81+
public void testToXContentWithAllFields() throws IOException {
82+
String[] targetNodes = { "node1" };
83+
String[] workerNodes = { "worker1" };
84+
MLPredictRequestStats stats = MLPredictRequestStats.builder().count(10L).max(5.0).build();
85+
86+
MLModelProfile profile = MLModelProfile
87+
.builder()
88+
.modelState(MLModelState.DEPLOYED)
89+
.predictor("test-predictor")
90+
.targetWorkerNodes(targetNodes)
91+
.workerNodes(workerNodes)
92+
.modelInferenceStats(stats)
93+
.predictRequestStats(stats)
94+
.memSizeEstimationCPU(1024L)
95+
.memSizeEstimationGPU(2048L)
96+
.build();
97+
profile.setIsHidden(true);
98+
99+
XContentBuilder builder = XContentFactory.jsonBuilder();
100+
profile.toXContent(builder, null);
101+
102+
String json = builder.toString();
103+
assertTrue(json.contains("\"model_state\":\"DEPLOYED\""));
104+
assertTrue(json.contains("\"predictor\":\"test-predictor\""));
105+
assertTrue(json.contains("\"target_worker_nodes\":[\"node1\"]"));
106+
assertTrue(json.contains("\"worker_nodes\":[\"worker1\"]"));
107+
assertTrue(json.contains("\"model_inference_stats\""));
108+
assertTrue(json.contains("\"predict_request_stats\""));
109+
assertTrue(json.contains("\"memory_size_estimation_cpu\":1024"));
110+
assertTrue(json.contains("\"memory_size_estimation_gpu\":2048"));
111+
assertTrue(json.contains("\"is_hidden\":true"));
112+
}
113+
114+
@Test
115+
public void testToXContentWithNullFields() throws IOException {
116+
MLModelProfile profile = MLModelProfile.builder().build();
117+
118+
XContentBuilder builder = XContentFactory.jsonBuilder();
119+
profile.toXContent(builder, null);
120+
121+
String json = builder.toString();
122+
assertEquals("{}", json);
123+
}
124+
125+
@Test
126+
public void testToXContentWithIsHiddenFalse() throws IOException {
127+
MLModelProfile profile = MLModelProfile.builder().build();
128+
profile.setIsHidden(false);
129+
130+
XContentBuilder builder = XContentFactory.jsonBuilder();
131+
profile.toXContent(builder, null);
132+
133+
String json = builder.toString();
134+
assertEquals("{}", json); // is_hidden is only included when true
135+
}
136+
137+
@Test
138+
public void testStreamSerializationWithAllFields() throws IOException {
139+
String[] targetNodes = { "node1" };
140+
String[] workerNodes = { "worker1" };
141+
MLPredictRequestStats stats = MLPredictRequestStats.builder().count(10L).max(5.0).build();
142+
143+
MLModelProfile original = MLModelProfile
144+
.builder()
145+
.modelState(MLModelState.DEPLOYED)
146+
.predictor("test-predictor")
147+
.targetWorkerNodes(targetNodes)
148+
.workerNodes(workerNodes)
149+
.modelInferenceStats(stats)
150+
.predictRequestStats(stats)
151+
.memSizeEstimationCPU(1024L)
152+
.memSizeEstimationGPU(2048L)
153+
.build();
154+
original.setIsHidden(true);
155+
156+
BytesStreamOutput output = new BytesStreamOutput();
157+
output.setVersion(Version.CURRENT);
158+
original.writeTo(output);
159+
160+
StreamInput input = output.bytes().streamInput();
161+
input.setVersion(Version.CURRENT);
162+
MLModelProfile deserialized = new MLModelProfile(input);
163+
164+
assertEquals(original.getModelState(), deserialized.getModelState());
165+
assertEquals(original.getPredictor(), deserialized.getPredictor());
166+
Assert.assertNotNull(deserialized.getTargetWorkerNodes());
167+
assertEquals(original.getTargetWorkerNodes()[0], deserialized.getTargetWorkerNodes()[0]);
168+
Assert.assertNotNull(deserialized.getWorkerNodes());
169+
assertEquals(original.getWorkerNodes()[0], deserialized.getWorkerNodes()[0]);
170+
assertEquals(original.getMemSizeEstimationCPU(), deserialized.getMemSizeEstimationCPU());
171+
assertEquals(original.getMemSizeEstimationGPU(), deserialized.getMemSizeEstimationGPU());
172+
assertEquals(original.getIsHidden(), deserialized.getIsHidden());
173+
}
174+
175+
@Test
176+
public void testStreamSerializationWithNullFields() throws IOException {
177+
MLModelProfile original = MLModelProfile.builder().build();
178+
179+
BytesStreamOutput output = new BytesStreamOutput();
180+
output.setVersion(Version.CURRENT);
181+
original.writeTo(output);
182+
183+
StreamInput input = output.bytes().streamInput();
184+
input.setVersion(Version.CURRENT);
185+
MLModelProfile deserialized = new MLModelProfile(input);
186+
187+
assertNull(deserialized.getModelState());
188+
assertNull(deserialized.getPredictor());
189+
assertNull(deserialized.getTargetWorkerNodes());
190+
assertNull(deserialized.getWorkerNodes());
191+
assertNull(deserialized.getModelInferenceStats());
192+
assertNull(deserialized.getPredictRequestStats());
193+
assertNull(deserialized.getMemSizeEstimationCPU());
194+
assertNull(deserialized.getMemSizeEstimationGPU());
195+
}
196+
}

0 commit comments

Comments
 (0)