@@ -581,4 +581,324 @@ public void testUpdateModelDeployStatusAndTriggerOnNodesAction_whenMLTaskManager
581
581
verify (mlTaskManager ).updateMLTask (anyString (), any (), anyMap (), anyLong (), anyBoolean ());
582
582
}
583
583
584
+ public void testDeployRemoteModel_success () {
585
+ MLModel mlModel = mock (MLModel .class );
586
+ when (mlModel .getModelId ()).thenReturn ("test-model-id" );
587
+ when (mlModel .getTenantId ()).thenReturn ("test-tenant" );
588
+ when (mlModel .getModelContentHash ()).thenReturn ("test-hash" );
589
+ when (mlModel .getIsHidden ()).thenReturn (false );
590
+
591
+ MLTask mlTask = mock (MLTask .class );
592
+ when (mlTask .getTaskId ()).thenReturn ("test-task-id" );
593
+
594
+ DiscoveryNode node = mock (DiscoveryNode .class );
595
+ when (node .getId ()).thenReturn ("node1" );
596
+ List <DiscoveryNode > nodes = List .of (node );
597
+
598
+ doAnswer (invocation -> {
599
+ ActionListener <UpdateResponse > listener = invocation .getArgument (3 );
600
+ listener .onResponse (mock (UpdateResponse .class ));
601
+ return null ;
602
+ }).when (mlModelManager ).updateModel (anyString (), anyString (), anyMap (), any ());
603
+
604
+ doAnswer (invocation -> {
605
+ ActionListener <MLDeployModelNodesResponse > listener = invocation .getArgument (2 );
606
+ listener .onResponse (mock (MLDeployModelNodesResponse .class ));
607
+ return null ;
608
+ }).when (client ).execute (any (), any (), any ());
609
+
610
+ when (mlTaskManager .contains (anyString ())).thenReturn (true );
611
+
612
+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
613
+ transportDeployModelAction .deployRemoteModel (mlModel , mlTask , "local-node" , nodes , true , listener );
614
+
615
+ verify (listener ).onResponse (any (MLDeployModelResponse .class ));
616
+ }
617
+
618
+ public void testDeployRemoteModel_failure () {
619
+ MLModel mlModel = mock (MLModel .class );
620
+ when (mlModel .getModelId ()).thenReturn ("test-model-id" );
621
+ when (mlModel .getTenantId ()).thenReturn ("test-tenant" );
622
+ when (mlModel .getModelContentHash ()).thenReturn ("test-hash" );
623
+ when (mlModel .getIsHidden ()).thenReturn (false );
624
+
625
+ MLTask mlTask = mock (MLTask .class );
626
+ when (mlTask .getTaskId ()).thenReturn ("test-task-id" );
627
+
628
+ DiscoveryNode node = mock (DiscoveryNode .class );
629
+ when (node .getId ()).thenReturn ("node1" );
630
+ List <DiscoveryNode > nodes = List .of (node );
631
+
632
+ doAnswer (invocation -> {
633
+ ActionListener <UpdateResponse > listener = invocation .getArgument (3 );
634
+ listener .onFailure (new RuntimeException ("Update failed" ));
635
+ return null ;
636
+ }).when (mlModelManager ).updateModel (anyString (), anyString (), anyMap (), any ());
637
+
638
+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
639
+ transportDeployModelAction .deployRemoteModel (mlModel , mlTask , "local-node" , nodes , true , listener );
640
+
641
+ verify (listener ).onFailure (any (RuntimeException .class ));
642
+ }
643
+
644
+ public void testDoExecute_deployToAllNodes_false () {
645
+ MLModel mlModel = mock (MLModel .class );
646
+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
647
+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
648
+ when (mlModel .getIsHidden ()).thenReturn (false );
649
+
650
+ // Use the existing mlDeployModelRequest but override specific nodes
651
+ when (mlDeployModelRequest .getModelNodeIds ()).thenReturn (new String [] { "node1" , "node2" });
652
+
653
+ doAnswer (invocation -> {
654
+ ActionListener <MLModel > listener = invocation .getArgument (4 );
655
+ listener .onResponse (mlModel );
656
+ return null ;
657
+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
658
+
659
+ // Set up multiple eligible nodes
660
+ DiscoveryNode node1 = mock (DiscoveryNode .class );
661
+ DiscoveryNode node2 = mock (DiscoveryNode .class );
662
+ when (node1 .getId ()).thenReturn ("node1" );
663
+ when (node2 .getId ()).thenReturn ("node2" );
664
+ DiscoveryNode [] nodes = { node1 , node2 };
665
+ when (nodeFilter .getEligibleNodes (any ())).thenReturn (nodes );
666
+ when (mlModelManager .getWorkerNodes (anyString (), any ())).thenReturn (null );
667
+
668
+ IndexResponse indexResponse = mock (IndexResponse .class );
669
+ when (indexResponse .getId ()).thenReturn ("task-id" );
670
+ doAnswer (invocation -> {
671
+ ActionListener <IndexResponse > listener = invocation .getArgument (1 );
672
+ listener .onResponse (indexResponse );
673
+ return null ;
674
+ }).when (mlTaskManager ).createMLTask (any (MLTask .class ), any ());
675
+
676
+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
677
+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
678
+
679
+ verify (listener ).onResponse (any (MLDeployModelResponse .class ));
680
+ }
681
+
682
+ public void testDoExecute_workerNodesConflict () {
683
+ MLModel mlModel = mock (MLModel .class );
684
+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
685
+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
686
+ when (mlModel .getIsHidden ()).thenReturn (false );
687
+
688
+ // Use the existing mlDeployModelRequest but override specific nodes
689
+ when (mlDeployModelRequest .getModelNodeIds ()).thenReturn (new String [] { "node1" });
690
+
691
+ doAnswer (invocation -> {
692
+ ActionListener <MLModel > listener = invocation .getArgument (4 );
693
+ listener .onResponse (mlModel );
694
+ return null ;
695
+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
696
+
697
+ // Set up eligible nodes
698
+ DiscoveryNode node1 = mock (DiscoveryNode .class );
699
+ DiscoveryNode node2 = mock (DiscoveryNode .class );
700
+ when (node1 .getId ()).thenReturn ("node1" );
701
+ when (node2 .getId ()).thenReturn ("node2" );
702
+ DiscoveryNode [] nodes = { node1 , node2 };
703
+ when (nodeFilter .getEligibleNodes (any ())).thenReturn (nodes );
704
+
705
+ // Set up worker nodes conflict - model is already deployed on node2 but target is node1
706
+ when (mlModelManager .getWorkerNodes (anyString (), any ())).thenReturn (new String [] { "node2" });
707
+
708
+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
709
+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
710
+
711
+ verify (listener ).onFailure (any (IllegalArgumentException .class ));
712
+ }
713
+
714
+ public void testDoExecute_noEligibleNodes () {
715
+ MLModel mlModel = mock (MLModel .class );
716
+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
717
+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
718
+ when (mlModel .getIsHidden ()).thenReturn (false );
719
+
720
+ // Use the existing mlDeployModelRequest but override to request non-existent node
721
+ when (mlDeployModelRequest .getModelNodeIds ()).thenReturn (new String [] { "non-existent-node" });
722
+
723
+ doAnswer (invocation -> {
724
+ ActionListener <MLModel > listener = invocation .getArgument (4 );
725
+ listener .onResponse (mlModel );
726
+ return null ;
727
+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
728
+
729
+ // Set up eligible nodes that don't match the requested nodes
730
+ DiscoveryNode existingNode = mock (DiscoveryNode .class );
731
+ when (existingNode .getId ()).thenReturn ("existing-node" );
732
+ DiscoveryNode [] nodes = { existingNode };
733
+ when (nodeFilter .getEligibleNodes (any ())).thenReturn (nodes );
734
+
735
+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
736
+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
737
+
738
+ verify (listener ).onFailure (any (IllegalArgumentException .class ));
739
+ }
740
+
741
+ public void testDoExecute_deployToAllNodes_true () {
742
+ MLModel mlModel = mock (MLModel .class );
743
+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
744
+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
745
+ when (mlModel .getIsHidden ()).thenReturn (false );
746
+
747
+ // Use null or empty array to trigger deployToAllNodes = true
748
+ when (mlDeployModelRequest .getModelNodeIds ()).thenReturn (null );
749
+
750
+ doAnswer (invocation -> {
751
+ ActionListener <MLModel > listener = invocation .getArgument (4 );
752
+ listener .onResponse (mlModel );
753
+ return null ;
754
+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
755
+
756
+ // Set up eligible nodes
757
+ DiscoveryNode node1 = mock (DiscoveryNode .class );
758
+ DiscoveryNode node2 = mock (DiscoveryNode .class );
759
+ when (node1 .getId ()).thenReturn ("node1" );
760
+ when (node2 .getId ()).thenReturn ("node2" );
761
+ DiscoveryNode [] nodes = { node1 , node2 };
762
+ when (nodeFilter .getEligibleNodes (any ())).thenReturn (nodes );
763
+
764
+ IndexResponse indexResponse = mock (IndexResponse .class );
765
+ when (indexResponse .getId ()).thenReturn ("task-id" );
766
+ doAnswer (invocation -> {
767
+ ActionListener <IndexResponse > listener = invocation .getArgument (1 );
768
+ listener .onResponse (indexResponse );
769
+ return null ;
770
+ }).when (mlTaskManager ).createMLTask (any (MLTask .class ), any ());
771
+
772
+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
773
+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
774
+
775
+ verify (listener ).onResponse (any (MLDeployModelResponse .class ));
776
+ }
777
+
778
+ public void testDoExecute_accessControlFailure () {
779
+ MLModel mlModel = mock (MLModel .class );
780
+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
781
+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
782
+ when (mlModel .getIsHidden ()).thenReturn (false );
783
+
784
+ doAnswer (invocation -> {
785
+ ActionListener <MLModel > listener = invocation .getArgument (4 );
786
+ listener .onResponse (mlModel );
787
+ return null ;
788
+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
789
+
790
+ // Mock access control to return false (no access)
791
+ doAnswer (invocation -> {
792
+ ActionListener <Boolean > listener = invocation .getArgument (3 );
793
+ listener .onResponse (false );
794
+ return null ;
795
+ }).when (modelAccessControlHelper ).validateModelGroupAccess (any (), anyString (), any (), any ());
796
+
797
+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
798
+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
799
+
800
+ verify (listener ).onFailure (any (OpenSearchStatusException .class ));
801
+ }
802
+
803
+ public void testDoExecute_hiddenModelNonSuperAdmin () {
804
+ MLModel mlModel = mock (MLModel .class );
805
+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
806
+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
807
+ when (mlModel .getIsHidden ()).thenReturn (true );
808
+
809
+ doAnswer (invocation -> {
810
+ ActionListener <MLModel > listener = invocation .getArgument (4 );
811
+ listener .onResponse (mlModel );
812
+ return null ;
813
+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
814
+
815
+ // Mock the isSuperAdminUserWrapper to return false (not super admin)
816
+ TransportDeployModelAction spyAction = spy (transportDeployModelAction );
817
+ doReturn (false ).when (spyAction ).isSuperAdminUserWrapper (any (), any ());
818
+
819
+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
820
+ spyAction .doExecute (null , mlDeployModelRequest , listener );
821
+
822
+ verify (listener ).onFailure (any (OpenSearchStatusException .class ));
823
+ }
824
+
825
+ public void testDoExecute_taskManagerNotContains () {
826
+ MLModel mlModel = mock (MLModel .class );
827
+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
828
+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
829
+ when (mlModel .getIsHidden ()).thenReturn (false );
830
+
831
+ doAnswer (invocation -> {
832
+ ActionListener <MLModel > listener = invocation .getArgument (4 );
833
+ listener .onResponse (mlModel );
834
+ return null ;
835
+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
836
+
837
+ IndexResponse indexResponse = mock (IndexResponse .class );
838
+ when (indexResponse .getId ()).thenReturn ("task-id" );
839
+ doAnswer (invocation -> {
840
+ ActionListener <IndexResponse > listener = invocation .getArgument (1 );
841
+ listener .onResponse (indexResponse );
842
+ return null ;
843
+ }).when (mlTaskManager ).createMLTask (any (MLTask .class ), any ());
844
+
845
+ // Mock mlTaskManager.contains to return false
846
+ when (mlTaskManager .contains (anyString ())).thenReturn (false );
847
+
848
+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
849
+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
850
+
851
+ verify (listener ).onResponse (any (MLDeployModelResponse .class ));
852
+ }
853
+
854
+ public void testDoExecute_customDeploymentNotAllowed () {
855
+ // Override the settings to disable custom deployment plan
856
+ Settings restrictiveSettings = Settings .builder ().put (ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN .getKey (), false ).build ();
857
+ ClusterSettings restrictiveClusterSettings = new ClusterSettings (
858
+ restrictiveSettings ,
859
+ new HashSet <>(Arrays .asList (ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN ))
860
+ );
861
+ when (clusterService .getClusterSettings ()).thenReturn (restrictiveClusterSettings );
862
+ when (clusterService .getSettings ()).thenReturn (restrictiveSettings );
863
+
864
+ // Create a new instance with restrictive settings
865
+ TransportDeployModelAction restrictiveAction = new TransportDeployModelAction (
866
+ transportService ,
867
+ actionFilters ,
868
+ modelHelper ,
869
+ mlTaskManager ,
870
+ clusterService ,
871
+ threadPool ,
872
+ client ,
873
+ sdkClient ,
874
+ namedXContentRegistry ,
875
+ nodeFilter ,
876
+ mlTaskDispatcher ,
877
+ mlModelManager ,
878
+ mlStats ,
879
+ restrictiveSettings ,
880
+ modelAccessControlHelper ,
881
+ mlFeatureEnabledSetting
882
+ );
883
+
884
+ MLModel mlModel = mock (MLModel .class );
885
+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
886
+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
887
+ when (mlModel .getIsHidden ()).thenReturn (false );
888
+
889
+ // Set specific nodes (not deploy to all)
890
+ when (mlDeployModelRequest .getModelNodeIds ()).thenReturn (new String [] { "node1" });
891
+
892
+ doAnswer (invocation -> {
893
+ ActionListener <MLModel > listener = invocation .getArgument (4 );
894
+ listener .onResponse (mlModel );
895
+ return null ;
896
+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
897
+
898
+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
899
+ restrictiveAction .doExecute (null , mlDeployModelRequest , listener );
900
+
901
+ verify (listener ).onFailure (any (IllegalArgumentException .class ));
902
+ }
903
+
584
904
}
0 commit comments