5
5
6
6
package org .opensearch .ml .engine .algorithms .remote ;
7
7
8
+ import static org .junit .Assert .assertEquals ;
8
9
import static org .mockito .Mockito .mock ;
9
10
import static org .mockito .Mockito .times ;
10
11
import static org .mockito .Mockito .verify ;
31
32
import org .opensearch .ml .common .connector .ConnectorAction ;
32
33
import org .opensearch .ml .common .connector .HttpConnector ;
33
34
import org .opensearch .ml .common .connector .MLPostProcessFunction ;
34
- import org .opensearch .ml .common .exception .MLException ;
35
35
import org .opensearch .ml .common .output .model .ModelTensors ;
36
36
import org .opensearch .script .ScriptService ;
37
37
import org .reactivestreams .Publisher ;
@@ -191,7 +191,7 @@ public void test_onError() {
191
191
ArgumentCaptor <Exception > captor = ArgumentCaptor .forClass (Exception .class );
192
192
verify (actionListener ).onFailure (captor .capture ());
193
193
assert captor .getValue () instanceof OpenSearchStatusException ;
194
- assert captor . getValue (). getMessage (). equals ( "Error communicating with remote model: runtime exception" );
194
+ assertEquals ( "Error communicating with remote model: runtime exception" , captor . getValue (). getMessage () );
195
195
}
196
196
197
197
@ Test
@@ -209,7 +209,7 @@ public void test_onSubscribe() {
209
209
public void test_onNext () {
210
210
test_onSubscribe ();// set the subscription to non-null.
211
211
responseSubscriber .onNext (ByteBuffer .wrap ("hello world" .getBytes ()));
212
- assert mlSdkAsyncHttpResponseHandler .getResponseBody ().toString (). equals ( "hello world" );
212
+ assertEquals ( "hello world" , mlSdkAsyncHttpResponseHandler .getResponseBody ().toString ());
213
213
}
214
214
215
215
@ Test
@@ -221,7 +221,7 @@ public void test_MLResponseSubscriber_onError() {
221
221
ArgumentCaptor <Exception > captor = ArgumentCaptor .forClass (Exception .class );
222
222
verify (actionListener , times (1 )).onFailure (captor .capture ());
223
223
assert captor .getValue () instanceof OpenSearchStatusException ;
224
- assert captor . getValue (). getMessage (). equals ( "Remote service returned error status 500 with empty body" );
224
+ assertEquals ( "Remote service returned error status 500 with empty body" , captor . getValue (). getMessage () );
225
225
}
226
226
227
227
@ Test
@@ -283,7 +283,7 @@ public void test_onComplete_failed() {
283
283
mlSdkAsyncHttpResponseHandler .onStream (stream );
284
284
ArgumentCaptor <OpenSearchStatusException > captor = ArgumentCaptor .forClass (OpenSearchStatusException .class );
285
285
verify (actionListener , times (1 )).onFailure (captor .capture ());
286
- assert captor . getValue (). getMessage (). equals ( "Error from remote service: Model current status is: FAILED" );
286
+ assertEquals ( "Error from remote service: Model current status is: FAILED" , captor . getValue (). getMessage () );
287
287
assert captor .getValue ().status ().getStatus () == 500 ;
288
288
}
289
289
@@ -302,7 +302,7 @@ public void test_onComplete_empty_response_body() {
302
302
mlSdkAsyncHttpResponseHandler .onStream (stream );
303
303
ArgumentCaptor <OpenSearchStatusException > captor = ArgumentCaptor .forClass (OpenSearchStatusException .class );
304
304
verify (actionListener , times (1 )).onFailure (captor .capture ());
305
- assert captor . getValue (). getMessage (). equals ( "Remote service returned empty response body" );
305
+ assertEquals ( "Remote service returned empty response body" , captor . getValue (). getMessage () );
306
306
}
307
307
308
308
@ Test
@@ -380,14 +380,12 @@ public void test_onComplete_throttle_exception_onFailure() {
380
380
381
381
ArgumentCaptor <OpenSearchStatusException > captor = ArgumentCaptor .forClass (RemoteConnectorThrottlingException .class );
382
382
verify (actionListener , times (1 )).onFailure (captor .capture ());
383
- assert captor
384
- .getValue ()
385
- .getMessage ()
386
- .equals (
387
- "Error from remote service: The request was denied due to remote server throttling. "
388
- + "To change the retry policy and behavior, please update the connector client_config."
389
- );
390
383
assert captor .getValue ().status ().getStatus () == HttpStatusCode .BAD_REQUEST ;
384
+ assertEquals (
385
+ "Error from remote service: The request was denied due to remote server throttling. "
386
+ + "To change the retry policy and behavior, please update the connector client_config." ,
387
+ captor .getValue ().getMessage ()
388
+ );
391
389
}
392
390
393
391
@ Test
@@ -416,8 +414,39 @@ public void test_onComplete_processOutputFail_onFailure() {
416
414
};
417
415
mlSdkAsyncHttpResponseHandler .onStream (stream );
418
416
419
- ArgumentCaptor <MLException > captor = ArgumentCaptor .forClass (MLException .class );
417
+ ArgumentCaptor <IllegalArgumentException > captor = ArgumentCaptor .forClass (IllegalArgumentException .class );
420
418
verify (actionListener , times (1 )).onFailure (captor .capture ());
421
- assert captor .getValue ().getMessage ().equals ("Fail to execute PREDICT in aws connector" );
419
+ assertEquals ("no PREDICT action found" , captor .getValue ().getMessage ());
420
+ }
421
+
422
+ /**
423
+ * Asserts that IllegalArgumentException is propagated where post-processing function fails
424
+ * on response
425
+ */
426
+ @ Test
427
+ public void onComplete_InvalidEmbeddingBedRockPostProcessingOccurs_IllegalArgumentExceptionThrown () {
428
+ String invalidEmbeddingResponse = "{ \" embedding\" : [[1]] }" ;
429
+
430
+ mlSdkAsyncHttpResponseHandler .onHeaders (sdkHttpResponse );
431
+ Publisher <ByteBuffer > stream = s -> {
432
+ try {
433
+ s .onSubscribe (mock (Subscription .class ));
434
+ s .onNext (ByteBuffer .wrap (invalidEmbeddingResponse .getBytes ()));
435
+ s .onComplete ();
436
+ } catch (Throwable e ) {
437
+ s .onError (e );
438
+ }
439
+ };
440
+ mlSdkAsyncHttpResponseHandler .onStream (stream );
441
+
442
+ ArgumentCaptor <IllegalArgumentException > exceptionCaptor = ArgumentCaptor .forClass (IllegalArgumentException .class );
443
+ verify (actionListener , times (1 )).onFailure (exceptionCaptor .capture ());
444
+
445
+ // Error message
446
+ assertEquals (
447
+ "BedrockEmbeddingPostProcessFunction exception message should match" ,
448
+ "The embedding should be a non-empty List containing Float values." ,
449
+ exceptionCaptor .getValue ().getMessage ()
450
+ );
422
451
}
423
452
}
0 commit comments