Skip to content

Commit 937f3fe

Browse files
authored
[2.19 backport] refactors undeploy models client with sdkClient bulk op (opensearch-project#4077)
* refactors undeploy models client with sdkClient bulk op Signed-off-by: Brian Flores <[email protected]> * add warn log to undeploy action Signed-off-by: Brian Flores <[email protected]> * apply spotles Signed-off-by: Brian Flores <[email protected]> --------- Signed-off-by: Brian Flores <[email protected]>
1 parent 8fca9cc commit 937f3fe

File tree

2 files changed

+95
-28
lines changed

2 files changed

+95
-28
lines changed

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,21 @@
1010

1111
import java.time.Instant;
1212
import java.util.Arrays;
13+
import java.util.HashMap;
1314
import java.util.List;
1415
import java.util.Map;
1516
import java.util.stream.Collectors;
1617

1718
import org.opensearch.ExceptionsHelper;
1819
import org.opensearch.OpenSearchStatusException;
1920
import org.opensearch.action.ActionRequest;
20-
import org.opensearch.action.bulk.BulkRequest;
21+
import org.opensearch.action.bulk.BulkItemResponse;
2122
import org.opensearch.action.bulk.BulkResponse;
2223
import org.opensearch.action.search.SearchRequest;
2324
import org.opensearch.action.search.SearchResponse;
2425
import org.opensearch.action.support.ActionFilters;
2526
import org.opensearch.action.support.HandledTransportAction;
2627
import org.opensearch.action.support.WriteRequest;
27-
import org.opensearch.action.update.UpdateRequest;
2828
import org.opensearch.client.Client;
2929
import org.opensearch.cluster.service.ClusterService;
3030
import org.opensearch.common.inject.Inject;
@@ -56,8 +56,10 @@
5656
import org.opensearch.ml.task.MLTaskManager;
5757
import org.opensearch.ml.utils.RestActionUtils;
5858
import org.opensearch.ml.utils.TenantAwareHelper;
59+
import org.opensearch.remote.metadata.client.BulkDataObjectRequest;
5960
import org.opensearch.remote.metadata.client.SdkClient;
6061
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
62+
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
6163
import org.opensearch.remote.metadata.common.SdkClientUtils;
6264
import org.opensearch.search.SearchHit;
6365
import org.opensearch.search.builder.SearchSourceBuilder;
@@ -66,7 +68,6 @@
6668
import org.opensearch.transport.TransportService;
6769

6870
import com.google.common.annotations.VisibleForTesting;
69-
import com.google.common.collect.ImmutableMap;
7071

7172
import lombok.extern.log4j.Log4j2;
7273

@@ -213,7 +214,13 @@ private void undeployModels(
213214
return modelCacheMissForModelIds;
214215
});
215216
if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) {
216-
bulkSetModelIndexToUndeploy(modelIds, listener, response);
217+
log
218+
.warn(
219+
"Model undeployment fallback: No active nodes found for models {}."
220+
+ " Proceeding with manual index update to UNDEPLOY state.",
221+
Arrays.toString(modelIds)
222+
);
223+
bulkSetModelIndexToUndeploy(modelIds, tenantId, listener, response);
217224
return;
218225
}
219226
listener.onResponse(new MLUndeployModelsResponse(response));
@@ -222,34 +229,39 @@ private void undeployModels(
222229

223230
private void bulkSetModelIndexToUndeploy(
224231
String[] modelIds,
232+
String tenantId,
225233
ActionListener<MLUndeployModelsResponse> listener,
226-
MLUndeployModelNodesResponse response
234+
MLUndeployModelNodesResponse mlUndeployModelNodesResponse
227235
) {
228-
BulkRequest bulkUpdateRequest = new BulkRequest();
236+
BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().globalIndex(ML_MODEL_INDEX).build();
237+
229238
for (String modelId : modelIds) {
230-
UpdateRequest updateRequest = new UpdateRequest();
231239

232-
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
233-
builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
240+
Map<String, Object> updateDocument = new HashMap<>();
234241

235-
builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
236-
builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
242+
updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
243+
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
244+
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
245+
updateDocument.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
246+
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
237247

238-
builder.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
239-
builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
240-
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
241-
bulkUpdateRequest.add(updateRequest);
248+
UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest
249+
.builder()
250+
.id(modelId)
251+
.tenantId(tenantId)
252+
.dataObject(updateDocument)
253+
.build();
254+
bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
242255
}
243256

244-
bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
245257
log.info("No nodes running these models: {}", Arrays.toString(modelIds));
246258

247259
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
248260
ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener
249261
.runBefore(listener, () -> threadContext.restore());
262+
250263
ActionListener<BulkResponse> bulkResponseListener = ActionListener.wrap(br -> {
251-
log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds));
252-
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response));
264+
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(mlUndeployModelNodesResponse));
253265
}, e -> {
254266
String modelsNotFoundMessage = String
255267
.format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", Arrays.toString(modelIds));
@@ -262,7 +274,40 @@ private void bulkSetModelIndexToUndeploy(
262274
listenerWithContextRestoration.onFailure(exception);
263275
});
264276

265-
client.bulk(bulkUpdateRequest, bulkResponseListener);
277+
sdkClient.bulkDataObjectAsync(bulkRequest).whenComplete((response, exception) -> {
278+
if (exception != null) {
279+
Exception cause = SdkClientUtils.unwrapAndConvertToException(exception, OpenSearchStatusException.class);
280+
bulkResponseListener.onFailure(cause);
281+
return;
282+
}
283+
284+
try {
285+
BulkResponse bulkResponse = BulkResponse.fromXContent(response.parser());
286+
log
287+
.info(
288+
"Executed {} bulk operations with {} failures, Took: {}",
289+
bulkResponse.getItems().length,
290+
bulkResponse.hasFailures()
291+
? Arrays.stream(bulkResponse.getItems()).filter(BulkItemResponse::isFailed).count()
292+
: 0,
293+
bulkResponse.getTook()
294+
);
295+
List<String> unemployedModelIds = Arrays
296+
.stream(bulkResponse.getItems())
297+
.filter(bulkItemResponse -> !bulkItemResponse.isFailed())
298+
.map(BulkItemResponse::getId)
299+
.collect(Collectors.toList());
300+
log
301+
.debug(
302+
"Successfully set the following modelId(s) to UNDEPLOY in index: {}",
303+
Arrays.toString(unemployedModelIds.toArray())
304+
);
305+
306+
bulkResponseListener.onResponse(bulkResponse);
307+
} catch (Exception e) {
308+
bulkResponseListener.onFailure(e);
309+
}
310+
});
266311
} catch (Exception e) {
267312
log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds), e);
268313
listener.onFailure(e);

plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import java.io.IOException;
2525
import java.util.ArrayList;
26+
import java.util.Collections;
2627
import java.util.HashMap;
2728
import java.util.List;
2829
import java.util.Map;
@@ -32,12 +33,17 @@
3233
import org.junit.rules.ExpectedException;
3334
import org.mockito.ArgumentCaptor;
3435
import org.mockito.Mock;
36+
import org.mockito.Mockito;
3537
import org.mockito.MockitoAnnotations;
38+
import org.opensearch.action.DocWriteRequest;
39+
import org.opensearch.action.DocWriteResponse;
3640
import org.opensearch.action.FailedNodeException;
41+
import org.opensearch.action.bulk.BulkItemResponse;
3742
import org.opensearch.action.bulk.BulkRequest;
3843
import org.opensearch.action.bulk.BulkResponse;
3944
import org.opensearch.action.support.ActionFilters;
4045
import org.opensearch.action.update.UpdateRequest;
46+
import org.opensearch.action.update.UpdateResponse;
4147
import org.opensearch.client.Client;
4248
import org.opensearch.cluster.ClusterName;
4349
import org.opensearch.cluster.service.ClusterService;
@@ -46,6 +52,7 @@
4652
import org.opensearch.commons.ConfigConstants;
4753
import org.opensearch.commons.authuser.User;
4854
import org.opensearch.core.action.ActionListener;
55+
import org.opensearch.core.index.shard.ShardId;
4956
import org.opensearch.core.xcontent.NamedXContentRegistry;
5057
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
5158
import org.opensearch.ml.common.FunctionName;
@@ -62,6 +69,7 @@
6269
import org.opensearch.ml.task.MLTaskDispatcher;
6370
import org.opensearch.ml.task.MLTaskManager;
6471
import org.opensearch.remote.metadata.client.SdkClient;
72+
import org.opensearch.remote.metadata.client.impl.SdkClientFactory;
6573
import org.opensearch.tasks.Task;
6674
import org.opensearch.test.OpenSearchTestCase;
6775
import org.opensearch.threadpool.ThreadPool;
@@ -135,6 +143,8 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase {
135143
public void setup() throws IOException {
136144
MockitoAnnotations.openMocks(this);
137145
Settings settings = Settings.builder().build();
146+
sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()));
147+
138148
transportUndeployModelsAction = spy(
139149
new TransportUndeployModelsAction(
140150
transportService,
@@ -217,11 +227,10 @@ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
217227

218228
ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
219229

230+
BulkResponse bulkResponse = getSuccessBulkResponse();
220231
// mock the bulk response that can be captured for inspecting the contents of the write to index
221232
doAnswer(invocation -> {
222233
ActionListener<BulkResponse> listener = invocation.getArgument(1);
223-
BulkResponse bulkResponse = mock(BulkResponse.class);
224-
when(bulkResponse.hasFailures()).thenReturn(false);
225234
listener.onResponse(bulkResponse);
226235
return null;
227236
}).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));
@@ -333,11 +342,10 @@ public void testHiddenModelSuccess() {
333342
return null;
334343
}).when(client).execute(any(), any(), isA(ActionListener.class));
335344

345+
BulkResponse bulkResponse = getSuccessBulkResponse();
336346
// Mock the client.bulk call
337347
doAnswer(invocation -> {
338348
ActionListener<BulkResponse> listener = invocation.getArgument(1);
339-
BulkResponse bulkResponse = mock(BulkResponse.class);
340-
when(bulkResponse.hasFailures()).thenReturn(false);
341349
listener.onResponse(bulkResponse);
342350
return null;
343351
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
@@ -392,9 +400,10 @@ public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() {
392400
return null;
393401
}).when(client).execute(any(), any(), isA(ActionListener.class));
394402

403+
BulkResponse bulkResponse = getSuccessBulkResponse();
395404
doAnswer(invocation -> {
396405
ActionListener<BulkResponse> listener = invocation.getArgument(1);
397-
listener.onResponse(mock(BulkResponse.class));
406+
listener.onResponse(bulkResponse);
398407
return null;
399408
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
400409

@@ -458,17 +467,18 @@ public void testDoExecute() {
458467
listener.onResponse(response);
459468
return null;
460469
}).when(client).execute(any(), any(), isA(ActionListener.class));
461-
// Mock the client.bulk call
470+
471+
BulkResponse bulkResponse = getSuccessBulkResponse();
462472
doAnswer(invocation -> {
463473
ActionListener<BulkResponse> listener = invocation.getArgument(1);
464-
BulkResponse bulkResponse = mock(BulkResponse.class);
465-
when(bulkResponse.hasFailures()).thenReturn(false);
466474
listener.onResponse(bulkResponse);
467475
return null;
468-
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
476+
}).when(client).bulk(any(), any());
469477

470478
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
471479
transportUndeployModelsAction.doExecute(task, request, actionListener);
480+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
481+
472482
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
473483
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
474484
}
@@ -534,4 +544,16 @@ public void testDoExecute_modelIds_moreThan1() {
534544
MLUndeployModelsRequest request = new MLUndeployModelsRequest(new String[] { "modelId1", "modelId2" }, nodeIds, null);
535545
transportUndeployModelsAction.doExecute(task, request, actionListener);
536546
}
547+
548+
private BulkResponse getSuccessBulkResponse() {
549+
return new BulkResponse(
550+
new BulkItemResponse[] {
551+
new BulkItemResponse(
552+
1,
553+
DocWriteRequest.OpType.UPDATE,
554+
new UpdateResponse(new ShardId(ML_MODEL_INDEX, "modelId123", 0), "id1", 1, 1, 1, DocWriteResponse.Result.UPDATED)
555+
) },
556+
100L
557+
);
558+
}
537559
}

0 commit comments

Comments
 (0)