Skip to content

Commit 2a15074

Browse files
committed
make MLSdkAsyncHttpResponseHandler return IllegalArgumentException
Signed-off-by: Brian Flores <[email protected]>
1 parent 281c430 commit 2a15074

File tree

2 files changed

+46
-15
lines changed

2 files changed

+46
-15
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,8 @@ private void response() {
206206
ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard);
207207
tensors.setStatusCode(statusCode);
208208
actionListener.onResponse(new Tuple<>(executionContext.getSequence(), tensors));
209+
} catch (IllegalArgumentException e) {
210+
actionListener.onFailure(e);
209211
} catch (Exception e) {
210212
log.error("Failed to process response body: {}", body, e);
211213
actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e));

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.engine.algorithms.remote;
77

8+
import static org.junit.Assert.assertEquals;
89
import static org.mockito.Mockito.mock;
910
import static org.mockito.Mockito.times;
1011
import static org.mockito.Mockito.verify;
@@ -31,7 +32,6 @@
3132
import org.opensearch.ml.common.connector.ConnectorAction;
3233
import org.opensearch.ml.common.connector.HttpConnector;
3334
import org.opensearch.ml.common.connector.MLPostProcessFunction;
34-
import org.opensearch.ml.common.exception.MLException;
3535
import org.opensearch.ml.common.output.model.ModelTensors;
3636
import org.opensearch.script.ScriptService;
3737
import org.reactivestreams.Publisher;
@@ -191,7 +191,7 @@ public void test_onError() {
191191
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
192192
verify(actionListener).onFailure(captor.capture());
193193
assert captor.getValue() instanceof OpenSearchStatusException;
194-
assert captor.getValue().getMessage().equals("Error communicating with remote model: runtime exception");
194+
assertEquals("Error communicating with remote model: runtime exception", captor.getValue().getMessage());
195195
}
196196

197197
@Test
@@ -209,7 +209,7 @@ public void test_onSubscribe() {
209209
public void test_onNext() {
210210
test_onSubscribe();// set the subscription to non-null.
211211
responseSubscriber.onNext(ByteBuffer.wrap("hello world".getBytes()));
212-
assert mlSdkAsyncHttpResponseHandler.getResponseBody().toString().equals("hello world");
212+
assertEquals("hello world", mlSdkAsyncHttpResponseHandler.getResponseBody().toString());
213213
}
214214

215215
@Test
@@ -221,7 +221,7 @@ public void test_MLResponseSubscriber_onError() {
221221
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
222222
verify(actionListener, times(1)).onFailure(captor.capture());
223223
assert captor.getValue() instanceof OpenSearchStatusException;
224-
assert captor.getValue().getMessage().equals("Remote service returned error status 500 with empty body");
224+
assertEquals("Remote service returned error status 500 with empty body", captor.getValue().getMessage());
225225
}
226226

227227
@Test
@@ -283,7 +283,7 @@ public void test_onComplete_failed() {
283283
mlSdkAsyncHttpResponseHandler.onStream(stream);
284284
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
285285
verify(actionListener, times(1)).onFailure(captor.capture());
286-
assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED");
286+
assertEquals("Error from remote service: Model current status is: FAILED", captor.getValue().getMessage());
287287
assert captor.getValue().status().getStatus() == 500;
288288
}
289289

@@ -302,7 +302,7 @@ public void test_onComplete_empty_response_body() {
302302
mlSdkAsyncHttpResponseHandler.onStream(stream);
303303
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
304304
verify(actionListener, times(1)).onFailure(captor.capture());
305-
assert captor.getValue().getMessage().equals("Remote service returned empty response body");
305+
assertEquals("Remote service returned empty response body", captor.getValue().getMessage());
306306
}
307307

308308
@Test
@@ -380,14 +380,12 @@ public void test_onComplete_throttle_exception_onFailure() {
380380

381381
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(RemoteConnectorThrottlingException.class);
382382
verify(actionListener, times(1)).onFailure(captor.capture());
383-
assert captor
384-
.getValue()
385-
.getMessage()
386-
.equals(
387-
"Error from remote service: The request was denied due to remote server throttling. "
388-
+ "To change the retry policy and behavior, please update the connector client_config."
389-
);
390383
assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST;
384+
assertEquals(
385+
"Error from remote service: The request was denied due to remote server throttling. "
386+
+ "To change the retry policy and behavior, please update the connector client_config.",
387+
captor.getValue().getMessage()
388+
);
391389
}
392390

393391
@Test
@@ -416,8 +414,39 @@ public void test_onComplete_processOutputFail_onFailure() {
416414
};
417415
mlSdkAsyncHttpResponseHandler.onStream(stream);
418416

419-
ArgumentCaptor<MLException> captor = ArgumentCaptor.forClass(MLException.class);
417+
ArgumentCaptor<IllegalArgumentException> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
420418
verify(actionListener, times(1)).onFailure(captor.capture());
421-
assert captor.getValue().getMessage().equals("Fail to execute PREDICT in aws connector");
419+
assertEquals("no PREDICT action found", captor.getValue().getMessage());
420+
}
421+
422+
/**
423+
* Asserts that IllegalArgumentException is propagated where post-processing function fails
424+
* on response
425+
*/
426+
@Test
427+
public void onComplete_InvalidEmbeddingBedRockPostProcessingOccurs_IllegalArgumentExceptionThrown() {
428+
String invalidEmbeddingResponse = "{ \"embedding\": [[1]] }";
429+
430+
mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
431+
Publisher<ByteBuffer> stream = s -> {
432+
try {
433+
s.onSubscribe(mock(Subscription.class));
434+
s.onNext(ByteBuffer.wrap(invalidEmbeddingResponse.getBytes()));
435+
s.onComplete();
436+
} catch (Throwable e) {
437+
s.onError(e);
438+
}
439+
};
440+
mlSdkAsyncHttpResponseHandler.onStream(stream);
441+
442+
ArgumentCaptor<IllegalArgumentException> exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class);
443+
verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
444+
445+
// Error message
446+
assertEquals(
447+
"BedrockEmbeddingPostProcessFunction exception message should match",
448+
"The embedding should be a non-empty List containing Float values.",
449+
exceptionCaptor.getValue().getMessage()
450+
);
422451
}
423452
}

0 commit comments

Comments
 (0)