Skip to content

Commit d8c6208

Browse files
authored
adding more unit tests (opensearch-project#4126)
Signed-off-by: Dhrubo Saha <[email protected]>
1 parent 6059afb commit d8c6208

File tree

3 files changed

+460
-2
lines changed

3 files changed

+460
-2
lines changed

plugin/build.gradle

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,6 @@ jacocoTestReport {
334334

335335
List<String> jacocoExclusions = [
336336
// TODO: add more unit test to meet the minimal test coverage.
337-
'org.opensearch.ml.profile.MLPredictRequestStats',
338-
'org.opensearch.ml.action.deploy.TransportDeployModelAction',
339337
'org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction',
340338
'org.opensearch.ml.action.undeploy.TransportUndeployModelsAction',
341339
'org.opensearch.ml.action.prediction.TransportPredictionTaskAction',

plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,4 +581,324 @@ public void testUpdateModelDeployStatusAndTriggerOnNodesAction_whenMLTaskManager
581581
verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean());
582582
}
583583

584+
public void testDeployRemoteModel_success() {
585+
MLModel mlModel = mock(MLModel.class);
586+
when(mlModel.getModelId()).thenReturn("test-model-id");
587+
when(mlModel.getTenantId()).thenReturn("test-tenant");
588+
when(mlModel.getModelContentHash()).thenReturn("test-hash");
589+
when(mlModel.getIsHidden()).thenReturn(false);
590+
591+
MLTask mlTask = mock(MLTask.class);
592+
when(mlTask.getTaskId()).thenReturn("test-task-id");
593+
594+
DiscoveryNode node = mock(DiscoveryNode.class);
595+
when(node.getId()).thenReturn("node1");
596+
List<DiscoveryNode> nodes = List.of(node);
597+
598+
doAnswer(invocation -> {
599+
ActionListener<UpdateResponse> listener = invocation.getArgument(3);
600+
listener.onResponse(mock(UpdateResponse.class));
601+
return null;
602+
}).when(mlModelManager).updateModel(anyString(), anyString(), anyMap(), any());
603+
604+
doAnswer(invocation -> {
605+
ActionListener<MLDeployModelNodesResponse> listener = invocation.getArgument(2);
606+
listener.onResponse(mock(MLDeployModelNodesResponse.class));
607+
return null;
608+
}).when(client).execute(any(), any(), any());
609+
610+
when(mlTaskManager.contains(anyString())).thenReturn(true);
611+
612+
ActionListener<MLDeployModelResponse> listener = mock(ActionListener.class);
613+
transportDeployModelAction.deployRemoteModel(mlModel, mlTask, "local-node", nodes, true, listener);
614+
615+
verify(listener).onResponse(any(MLDeployModelResponse.class));
616+
}
617+
618+
public void testDeployRemoteModel_failure() {
619+
MLModel mlModel = mock(MLModel.class);
620+
when(mlModel.getModelId()).thenReturn("test-model-id");
621+
when(mlModel.getTenantId()).thenReturn("test-tenant");
622+
when(mlModel.getModelContentHash()).thenReturn("test-hash");
623+
when(mlModel.getIsHidden()).thenReturn(false);
624+
625+
MLTask mlTask = mock(MLTask.class);
626+
when(mlTask.getTaskId()).thenReturn("test-task-id");
627+
628+
DiscoveryNode node = mock(DiscoveryNode.class);
629+
when(node.getId()).thenReturn("node1");
630+
List<DiscoveryNode> nodes = List.of(node);
631+
632+
doAnswer(invocation -> {
633+
ActionListener<UpdateResponse> listener = invocation.getArgument(3);
634+
listener.onFailure(new RuntimeException("Update failed"));
635+
return null;
636+
}).when(mlModelManager).updateModel(anyString(), anyString(), anyMap(), any());
637+
638+
ActionListener<MLDeployModelResponse> listener = mock(ActionListener.class);
639+
transportDeployModelAction.deployRemoteModel(mlModel, mlTask, "local-node", nodes, true, listener);
640+
641+
verify(listener).onFailure(any(RuntimeException.class));
642+
}
643+
644+
public void testDoExecute_deployToAllNodes_false() {
645+
MLModel mlModel = mock(MLModel.class);
646+
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
647+
when(mlModel.getModelGroupId()).thenReturn("test-group-id");
648+
when(mlModel.getIsHidden()).thenReturn(false);
649+
650+
// Use the existing mlDeployModelRequest but override specific nodes
651+
when(mlDeployModelRequest.getModelNodeIds()).thenReturn(new String[] { "node1", "node2" });
652+
653+
doAnswer(invocation -> {
654+
ActionListener<MLModel> listener = invocation.getArgument(4);
655+
listener.onResponse(mlModel);
656+
return null;
657+
}).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), any());
658+
659+
// Set up multiple eligible nodes
660+
DiscoveryNode node1 = mock(DiscoveryNode.class);
661+
DiscoveryNode node2 = mock(DiscoveryNode.class);
662+
when(node1.getId()).thenReturn("node1");
663+
when(node2.getId()).thenReturn("node2");
664+
DiscoveryNode[] nodes = { node1, node2 };
665+
when(nodeFilter.getEligibleNodes(any())).thenReturn(nodes);
666+
when(mlModelManager.getWorkerNodes(anyString(), any())).thenReturn(null);
667+
668+
IndexResponse indexResponse = mock(IndexResponse.class);
669+
when(indexResponse.getId()).thenReturn("task-id");
670+
doAnswer(invocation -> {
671+
ActionListener<IndexResponse> listener = invocation.getArgument(1);
672+
listener.onResponse(indexResponse);
673+
return null;
674+
}).when(mlTaskManager).createMLTask(any(MLTask.class), any());
675+
676+
ActionListener<MLDeployModelResponse> listener = mock(ActionListener.class);
677+
transportDeployModelAction.doExecute(null, mlDeployModelRequest, listener);
678+
679+
verify(listener).onResponse(any(MLDeployModelResponse.class));
680+
}
681+
682+
public void testDoExecute_workerNodesConflict() {
683+
MLModel mlModel = mock(MLModel.class);
684+
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
685+
when(mlModel.getModelGroupId()).thenReturn("test-group-id");
686+
when(mlModel.getIsHidden()).thenReturn(false);
687+
688+
// Use the existing mlDeployModelRequest but override specific nodes
689+
when(mlDeployModelRequest.getModelNodeIds()).thenReturn(new String[] { "node1" });
690+
691+
doAnswer(invocation -> {
692+
ActionListener<MLModel> listener = invocation.getArgument(4);
693+
listener.onResponse(mlModel);
694+
return null;
695+
}).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), any());
696+
697+
// Set up eligible nodes
698+
DiscoveryNode node1 = mock(DiscoveryNode.class);
699+
DiscoveryNode node2 = mock(DiscoveryNode.class);
700+
when(node1.getId()).thenReturn("node1");
701+
when(node2.getId()).thenReturn("node2");
702+
DiscoveryNode[] nodes = { node1, node2 };
703+
when(nodeFilter.getEligibleNodes(any())).thenReturn(nodes);
704+
705+
// Set up worker nodes conflict - model is already deployed on node2 but target is node1
706+
when(mlModelManager.getWorkerNodes(anyString(), any())).thenReturn(new String[] { "node2" });
707+
708+
ActionListener<MLDeployModelResponse> listener = mock(ActionListener.class);
709+
transportDeployModelAction.doExecute(null, mlDeployModelRequest, listener);
710+
711+
verify(listener).onFailure(any(IllegalArgumentException.class));
712+
}
713+
714+
public void testDoExecute_noEligibleNodes() {
715+
MLModel mlModel = mock(MLModel.class);
716+
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
717+
when(mlModel.getModelGroupId()).thenReturn("test-group-id");
718+
when(mlModel.getIsHidden()).thenReturn(false);
719+
720+
// Use the existing mlDeployModelRequest but override to request non-existent node
721+
when(mlDeployModelRequest.getModelNodeIds()).thenReturn(new String[] { "non-existent-node" });
722+
723+
doAnswer(invocation -> {
724+
ActionListener<MLModel> listener = invocation.getArgument(4);
725+
listener.onResponse(mlModel);
726+
return null;
727+
}).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), any());
728+
729+
// Set up eligible nodes that don't match the requested nodes
730+
DiscoveryNode existingNode = mock(DiscoveryNode.class);
731+
when(existingNode.getId()).thenReturn("existing-node");
732+
DiscoveryNode[] nodes = { existingNode };
733+
when(nodeFilter.getEligibleNodes(any())).thenReturn(nodes);
734+
735+
ActionListener<MLDeployModelResponse> listener = mock(ActionListener.class);
736+
transportDeployModelAction.doExecute(null, mlDeployModelRequest, listener);
737+
738+
verify(listener).onFailure(any(IllegalArgumentException.class));
739+
}
740+
741+
public void testDoExecute_deployToAllNodes_true() {
742+
MLModel mlModel = mock(MLModel.class);
743+
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
744+
when(mlModel.getModelGroupId()).thenReturn("test-group-id");
745+
when(mlModel.getIsHidden()).thenReturn(false);
746+
747+
// Use null or empty array to trigger deployToAllNodes = true
748+
when(mlDeployModelRequest.getModelNodeIds()).thenReturn(null);
749+
750+
doAnswer(invocation -> {
751+
ActionListener<MLModel> listener = invocation.getArgument(4);
752+
listener.onResponse(mlModel);
753+
return null;
754+
}).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), any());
755+
756+
// Set up eligible nodes
757+
DiscoveryNode node1 = mock(DiscoveryNode.class);
758+
DiscoveryNode node2 = mock(DiscoveryNode.class);
759+
when(node1.getId()).thenReturn("node1");
760+
when(node2.getId()).thenReturn("node2");
761+
DiscoveryNode[] nodes = { node1, node2 };
762+
when(nodeFilter.getEligibleNodes(any())).thenReturn(nodes);
763+
764+
IndexResponse indexResponse = mock(IndexResponse.class);
765+
when(indexResponse.getId()).thenReturn("task-id");
766+
doAnswer(invocation -> {
767+
ActionListener<IndexResponse> listener = invocation.getArgument(1);
768+
listener.onResponse(indexResponse);
769+
return null;
770+
}).when(mlTaskManager).createMLTask(any(MLTask.class), any());
771+
772+
ActionListener<MLDeployModelResponse> listener = mock(ActionListener.class);
773+
transportDeployModelAction.doExecute(null, mlDeployModelRequest, listener);
774+
775+
verify(listener).onResponse(any(MLDeployModelResponse.class));
776+
}
777+
778+
public void testDoExecute_accessControlFailure() {
779+
MLModel mlModel = mock(MLModel.class);
780+
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
781+
when(mlModel.getModelGroupId()).thenReturn("test-group-id");
782+
when(mlModel.getIsHidden()).thenReturn(false);
783+
784+
doAnswer(invocation -> {
785+
ActionListener<MLModel> listener = invocation.getArgument(4);
786+
listener.onResponse(mlModel);
787+
return null;
788+
}).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), any());
789+
790+
// Mock access control to return false (no access)
791+
doAnswer(invocation -> {
792+
ActionListener<Boolean> listener = invocation.getArgument(3);
793+
listener.onResponse(false);
794+
return null;
795+
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), anyString(), any(), any());
796+
797+
ActionListener<MLDeployModelResponse> listener = mock(ActionListener.class);
798+
transportDeployModelAction.doExecute(null, mlDeployModelRequest, listener);
799+
800+
verify(listener).onFailure(any(OpenSearchStatusException.class));
801+
}
802+
803+
public void testDoExecute_hiddenModelNonSuperAdmin() {
804+
MLModel mlModel = mock(MLModel.class);
805+
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
806+
when(mlModel.getModelGroupId()).thenReturn("test-group-id");
807+
when(mlModel.getIsHidden()).thenReturn(true);
808+
809+
doAnswer(invocation -> {
810+
ActionListener<MLModel> listener = invocation.getArgument(4);
811+
listener.onResponse(mlModel);
812+
return null;
813+
}).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), any());
814+
815+
// Mock the isSuperAdminUserWrapper to return false (not super admin)
816+
TransportDeployModelAction spyAction = spy(transportDeployModelAction);
817+
doReturn(false).when(spyAction).isSuperAdminUserWrapper(any(), any());
818+
819+
ActionListener<MLDeployModelResponse> listener = mock(ActionListener.class);
820+
spyAction.doExecute(null, mlDeployModelRequest, listener);
821+
822+
verify(listener).onFailure(any(OpenSearchStatusException.class));
823+
}
824+
825+
public void testDoExecute_taskManagerNotContains() {
826+
MLModel mlModel = mock(MLModel.class);
827+
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
828+
when(mlModel.getModelGroupId()).thenReturn("test-group-id");
829+
when(mlModel.getIsHidden()).thenReturn(false);
830+
831+
doAnswer(invocation -> {
832+
ActionListener<MLModel> listener = invocation.getArgument(4);
833+
listener.onResponse(mlModel);
834+
return null;
835+
}).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), any());
836+
837+
IndexResponse indexResponse = mock(IndexResponse.class);
838+
when(indexResponse.getId()).thenReturn("task-id");
839+
doAnswer(invocation -> {
840+
ActionListener<IndexResponse> listener = invocation.getArgument(1);
841+
listener.onResponse(indexResponse);
842+
return null;
843+
}).when(mlTaskManager).createMLTask(any(MLTask.class), any());
844+
845+
// Mock mlTaskManager.contains to return false
846+
when(mlTaskManager.contains(anyString())).thenReturn(false);
847+
848+
ActionListener<MLDeployModelResponse> listener = mock(ActionListener.class);
849+
transportDeployModelAction.doExecute(null, mlDeployModelRequest, listener);
850+
851+
verify(listener).onResponse(any(MLDeployModelResponse.class));
852+
}
853+
854+
public void testDoExecute_customDeploymentNotAllowed() {
855+
// Override the settings to disable custom deployment plan
856+
Settings restrictiveSettings = Settings.builder().put(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.getKey(), false).build();
857+
ClusterSettings restrictiveClusterSettings = new ClusterSettings(
858+
restrictiveSettings,
859+
new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN))
860+
);
861+
when(clusterService.getClusterSettings()).thenReturn(restrictiveClusterSettings);
862+
when(clusterService.getSettings()).thenReturn(restrictiveSettings);
863+
864+
// Create a new instance with restrictive settings
865+
TransportDeployModelAction restrictiveAction = new TransportDeployModelAction(
866+
transportService,
867+
actionFilters,
868+
modelHelper,
869+
mlTaskManager,
870+
clusterService,
871+
threadPool,
872+
client,
873+
sdkClient,
874+
namedXContentRegistry,
875+
nodeFilter,
876+
mlTaskDispatcher,
877+
mlModelManager,
878+
mlStats,
879+
restrictiveSettings,
880+
modelAccessControlHelper,
881+
mlFeatureEnabledSetting
882+
);
883+
884+
MLModel mlModel = mock(MLModel.class);
885+
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
886+
when(mlModel.getModelGroupId()).thenReturn("test-group-id");
887+
when(mlModel.getIsHidden()).thenReturn(false);
888+
889+
// Set specific nodes (not deploy to all)
890+
when(mlDeployModelRequest.getModelNodeIds()).thenReturn(new String[] { "node1" });
891+
892+
doAnswer(invocation -> {
893+
ActionListener<MLModel> listener = invocation.getArgument(4);
894+
listener.onResponse(mlModel);
895+
return null;
896+
}).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), any());
897+
898+
ActionListener<MLDeployModelResponse> listener = mock(ActionListener.class);
899+
restrictiveAction.doExecute(null, mlDeployModelRequest, listener);
900+
901+
verify(listener).onFailure(any(IllegalArgumentException.class));
902+
}
903+
584904
}

0 commit comments

Comments
 (0)