From 393fe0558e34513a81c2a42a4352f85da8a70a67 Mon Sep 17 00:00:00 2001 From: Pat O'Connor Date: Mon, 25 Aug 2025 11:03:31 +0100 Subject: [PATCH 1/2] feat: Add lifecycled networkpolicies options for raycluster hardening Signed-off-by: Pat O'Connor --- helm-chart/kuberay-operator/README.md | 2 + .../kuberay-operator/templates/_helpers.tpl | 14 +- helm-chart/kuberay-operator/values.yaml | 2 + ray-operator/config/manager/manager.yaml | 5 + ray-operator/config/rbac/role.yaml | 14 +- .../ray/networkpolicy_controller.go | 389 +++++++++++ .../ray/networkpolicy_controller_test.go | 422 ++++++++++++ .../ray/networkpolicy_controller_unit_test.go | 607 ++++++++++++++++++ ray-operator/controllers/ray/suite_test.go | 16 + .../controllers/ray/utils/constant.go | 4 + ray-operator/main.go | 9 + ray-operator/pkg/features/features.go | 4 + 12 files changed, 1486 insertions(+), 2 deletions(-) create mode 100644 ray-operator/controllers/ray/networkpolicy_controller.go create mode 100644 ray-operator/controllers/ray/networkpolicy_controller_test.go create mode 100644 ray-operator/controllers/ray/networkpolicy_controller_unit_test.go diff --git a/helm-chart/kuberay-operator/README.md b/helm-chart/kuberay-operator/README.md index 6837698d597..62175b70263 100644 --- a/helm-chart/kuberay-operator/README.md +++ b/helm-chart/kuberay-operator/README.md @@ -165,6 +165,8 @@ spec: | featureGates[0].enabled | bool | `true` | | | featureGates[1].name | string | `"RayJobDeletionPolicy"` | | | featureGates[1].enabled | bool | `false` | | +| featureGates[2].name | string | `"RayClusterNetworkPolicy"` | | +| featureGates[2].enabled | bool | `false` | | | metrics.enabled | bool | `true` | Whether KubeRay operator should emit control plane metrics. | | metrics.serviceMonitor.enabled | bool | `false` | Enable a prometheus ServiceMonitor | | metrics.serviceMonitor.interval | string | `"30s"` | Prometheus ServiceMonitor interval | diff --git a/helm-chart/kuberay-operator/templates/_helpers.tpl b/helm-chart/kuberay-operator/templates/_helpers.tpl index 5d14510a61b..b7fffca21b8 100644 --- a/helm-chart/kuberay-operator/templates/_helpers.tpl +++ b/helm-chart/kuberay-operator/templates/_helpers.tpl @@ -211,7 +211,6 @@ rules: - update - apiGroups: - extensions - - networking.k8s.io resources: - ingresses verbs: @@ -230,6 +229,19 @@ rules: - get - list - watch +- apiGroups: + - networking.k8s.io + resources: + - ingresses + - networkpolicies + verbs: + - create + - delete + - get + - list + - patch + - update + - watch - apiGroups: - ray.io resources: diff --git a/helm-chart/kuberay-operator/values.yaml b/helm-chart/kuberay-operator/values.yaml index 6010d7f2b3e..6ca514330e7 100644 --- a/helm-chart/kuberay-operator/values.yaml +++ b/helm-chart/kuberay-operator/values.yaml @@ -88,6 +88,8 @@ featureGates: enabled: true - name: RayJobDeletionPolicy enabled: false +- name: RayClusterNetworkPolicy + enabled: false # Configurations for KubeRay operator metrics. metrics: diff --git a/ray-operator/config/manager/manager.yaml b/ray-operator/config/manager/manager.yaml index aa9125d5543..5509b5feb9f 100644 --- a/ray-operator/config/manager/manager.yaml +++ b/ray-operator/config/manager/manager.yaml @@ -80,4 +80,9 @@ spec: # environment variable is not set, requeue after the default value (300). 
# - name: RAYCLUSTER_DEFAULT_REQUEUE_SECONDS_ENV # value: "300" + # Required for NetworkPolicy feature when operator is NOT deployed in 'ray-system' namespace + # - name: POD_NAMESPACE + # valueFrom: + # fieldRef: + # fieldPath: metadata.namespace terminationGracePeriodSeconds: 10 diff --git a/ray-operator/config/rbac/role.yaml b/ray-operator/config/rbac/role.yaml index ba840f0c27f..352a6e6ea63 100644 --- a/ray-operator/config/rbac/role.yaml +++ b/ray-operator/config/rbac/role.yaml @@ -96,7 +96,6 @@ rules: - update - apiGroups: - extensions - - networking.k8s.io resources: - ingresses verbs: @@ -115,6 +114,19 @@ rules: - get - list - watch +- apiGroups: + - networking.k8s.io + resources: + - ingresses + - networkpolicies + verbs: + - create + - delete + - get + - list + - patch + - update + - watch - apiGroups: - ray.io resources: diff --git a/ray-operator/controllers/ray/networkpolicy_controller.go b/ray-operator/controllers/ray/networkpolicy_controller.go new file mode 100644 index 00000000000..984c0a9645d --- /dev/null +++ b/ray-operator/controllers/ray/networkpolicy_controller.go @@ -0,0 +1,389 @@ +package ray + +import ( + "context" + "fmt" + "os" + + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/client-go/tools/record" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "sigs.k8s.io/controller-runtime/pkg/manager" + + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" +) + +// NetworkPolicyController is a completely independent controller that watches RayCluster +// resources and manages NetworkPolicies for them. 
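+// For each RayCluster it reconciles two policies, both owned by the cluster so they are
+// garbage collected with it: "<name>-head", which restricts head-pod ingress to cluster
+// members, in-namespace dashboard/client traffic, the operator namespace, known monitoring
+// namespaces, the secured (mTLS) ports, and optionally the RayJob submitter pods; and
+// "<name>-workers", which restricts worker-pod ingress to cluster members only.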
+type NetworkPolicyController struct { + client.Client + Scheme *runtime.Scheme + Recorder record.EventRecorder +} + +// +kubebuilder:rbac:groups=networking.k8s.io,resources=networkpolicies,verbs=get;list;watch;create;update;delete;patch +// +kubebuilder:rbac:groups=ray.io,resources=rayclusters,verbs=get;list;watch + +// NewNetworkPolicyController creates a new independent NetworkPolicy controller +func NewNetworkPolicyController(mgr manager.Manager) *NetworkPolicyController { + return &NetworkPolicyController{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + Recorder: mgr.GetEventRecorderFor("networkpolicy-controller"), + } +} + +// Reconcile handles RayCluster resources and creates/manages NetworkPolicies +func (r *NetworkPolicyController) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + logger := ctrl.LoggerFrom(ctx).WithName("networkpolicy-controller") + + // Fetch the RayCluster instance + instance := &rayv1.RayCluster{} + if err := r.Get(ctx, req.NamespacedName, instance); err != nil { + if errors.IsNotFound(err) { + // RayCluster was deleted - NetworkPolicies will be garbage collected automatically + logger.Info("RayCluster not found, NetworkPolicies will be garbage collected") + return ctrl.Result{}, nil + } + return ctrl.Result{}, err + } + + // Check if RayCluster is being deleted + if instance.DeletionTimestamp != nil { + logger.Info("RayCluster is being deleted, NetworkPolicies will be garbage collected") + return ctrl.Result{}, nil + } + + logger.Info("Reconciling NetworkPolicies for RayCluster", "cluster", instance.Name) + + // Get KubeRay operator namespaces + kubeRayNamespaces := r.getKubeRayNamespaces(ctx) + + // Create or update head NetworkPolicy + headNetworkPolicy := r.buildHeadNetworkPolicy(instance, kubeRayNamespaces) + if err := r.createOrUpdateNetworkPolicy(ctx, instance, headNetworkPolicy); err != nil { + return ctrl.Result{}, err + } + + // Create or update worker NetworkPolicy + workerNetworkPolicy := r.buildWorkerNetworkPolicy(instance) + if err := r.createOrUpdateNetworkPolicy(ctx, instance, workerNetworkPolicy); err != nil { + return ctrl.Result{}, err + } + + logger.Info("Successfully reconciled NetworkPolicies for RayCluster", "cluster", instance.Name) + return ctrl.Result{}, nil +} + +// getKubeRayNamespaces returns the list of KubeRay operator namespaces +func (r *NetworkPolicyController) getKubeRayNamespaces(_ context.Context) []string { + operatorNamespace := os.Getenv("POD_NAMESPACE") + if operatorNamespace == "" { + operatorNamespace = "ray-system" // fallback + } + return []string{operatorNamespace} +} + +// createOrUpdateNetworkPolicy creates or updates a NetworkPolicy +func (r *NetworkPolicyController) createOrUpdateNetworkPolicy(ctx context.Context, instance *rayv1.RayCluster, networkPolicy *networkingv1.NetworkPolicy) error { + logger := ctrl.LoggerFrom(ctx).WithName("networkpolicy-controller") + + // Set owner reference for garbage collection + if err := controllerutil.SetControllerReference(instance, networkPolicy, r.Scheme); err != nil { + return err + } + + // Try to create the NetworkPolicy + if err := r.Create(ctx, networkPolicy); err != nil { + if errors.IsAlreadyExists(err) { + // NetworkPolicy exists, update it + existing := &networkingv1.NetworkPolicy{} + if err := r.Get(ctx, client.ObjectKeyFromObject(networkPolicy), existing); err != nil { + return err + } + + // Update the existing NetworkPolicy + existing.Spec = networkPolicy.Spec + existing.Labels = networkPolicy.Labels + + if err := r.Update(ctx, 
existing); err != nil { + r.Recorder.Eventf(instance, corev1.EventTypeWarning, string(utils.FailedToCreateNetworkPolicy), + "Failed to update NetworkPolicy %s/%s: %v", networkPolicy.Namespace, networkPolicy.Name, err) + return err + } + + logger.Info("Successfully updated NetworkPolicy", "name", networkPolicy.Name) + r.Recorder.Eventf(instance, corev1.EventTypeNormal, string(utils.CreatedNetworkPolicy), + "Updated NetworkPolicy %s/%s", networkPolicy.Namespace, networkPolicy.Name) + } else { + r.Recorder.Eventf(instance, corev1.EventTypeWarning, string(utils.FailedToCreateNetworkPolicy), + "Failed to create NetworkPolicy %s/%s: %v", networkPolicy.Namespace, networkPolicy.Name, err) + return err + } + } else { + logger.Info("Successfully created NetworkPolicy", "name", networkPolicy.Name) + r.Recorder.Eventf(instance, corev1.EventTypeNormal, string(utils.CreatedNetworkPolicy), + "Created NetworkPolicy %s/%s", networkPolicy.Namespace, networkPolicy.Name) + } + + return nil +} + +// buildHeadNetworkPolicy creates a NetworkPolicy for Ray head pods +func (r *NetworkPolicyController) buildHeadNetworkPolicy(instance *rayv1.RayCluster, kubeRayNamespaces []string) *networkingv1.NetworkPolicy { + labels := map[string]string{ + utils.RayClusterLabelKey: instance.Name, + utils.KubernetesApplicationNameLabelKey: utils.ApplicationName, + utils.KubernetesCreatedByLabelKey: utils.ComponentName, + } + + // Build secured ports - mTLS port always included + allSecuredPorts := []networkingv1.NetworkPolicyPort{ + { + Protocol: &[]corev1.Protocol{corev1.ProtocolTCP}[0], + Port: &[]intstr.IntOrString{intstr.FromInt(8443)}[0], + }, + } + + // Check if mTLS is enabled by looking for TLS configuration in RayCluster + if r.isMTLSEnabled(instance) { + // If mTLS is enabled, also secure port 10001 + allSecuredPorts = append(allSecuredPorts, networkingv1.NetworkPolicyPort{ + Protocol: &[]corev1.Protocol{corev1.ProtocolTCP}[0], + Port: &[]intstr.IntOrString{intstr.FromInt(10001)}[0], + }) + } + + // Build ingress rules + ingressRules := []networkingv1.NetworkPolicyIngressRule{ + // Rule 1: Intra-cluster communication - NO PORTS (allows all ports) + { + From: []networkingv1.NetworkPolicyPeer{ + { + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: instance.Name, + }, + }, + }, + }, + // No Ports specified = allow all ports + }, + // Rule 2: External access to dashboard and client ports from any pod in namespace + { + From: []networkingv1.NetworkPolicyPeer{ + { + PodSelector: &metav1.LabelSelector{ + // Empty MatchLabels = any pod in same namespace + }, + }, + }, + Ports: []networkingv1.NetworkPolicyPort{ + { + Protocol: &[]corev1.Protocol{corev1.ProtocolTCP}[0], + Port: &[]intstr.IntOrString{intstr.FromInt(10001)}[0], // Client + }, + { + Protocol: &[]corev1.Protocol{corev1.ProtocolTCP}[0], + Port: &[]intstr.IntOrString{intstr.FromInt(8265)}[0], // Dashboard + }, + }, + }, + // Rule 3: KubeRay operator access + { + From: []networkingv1.NetworkPolicyPeer{ + { + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.KubernetesApplicationNameLabelKey: utils.ApplicationName, + }, + }, + NamespaceSelector: &metav1.LabelSelector{ + MatchExpressions: []metav1.LabelSelectorRequirement{ + { + Key: corev1.LabelMetadataName, + Operator: metav1.LabelSelectorOpIn, + Values: kubeRayNamespaces, + }, + }, + }, + }, + }, + Ports: []networkingv1.NetworkPolicyPort{ + { + Protocol: &[]corev1.Protocol{corev1.ProtocolTCP}[0], + Port: 
&[]intstr.IntOrString{intstr.FromInt(8265)}[0], // Dashboard + }, + { + Protocol: &[]corev1.Protocol{corev1.ProtocolTCP}[0], + Port: &[]intstr.IntOrString{intstr.FromInt(10001)}[0], // Client + }, + }, + }, + // Rule 4: Monitoring access + { + From: []networkingv1.NetworkPolicyPeer{ + { + NamespaceSelector: &metav1.LabelSelector{ + MatchExpressions: []metav1.LabelSelectorRequirement{ + { + Key: corev1.LabelMetadataName, + Operator: metav1.LabelSelectorOpIn, + Values: []string{"openshift-monitoring", "prometheus", "redhat-ods-monitoring"}, + }, + }, + }, + }, + }, + Ports: []networkingv1.NetworkPolicyPort{ + { + Protocol: &[]corev1.Protocol{corev1.ProtocolTCP}[0], + Port: &[]intstr.IntOrString{intstr.FromInt(8080)}[0], // Metrics + }, + }, + }, + // Rule 5: Secured ports - NO FROM (allows all) + { + Ports: allSecuredPorts, + // No From specified = allow from anywhere + }, + } + + // Add RayJob submitter peer if RayCluster is owned by RayJob + if rayJobPeer := r.buildRayJobPeer(instance); rayJobPeer != nil { + ingressRules = append(ingressRules, networkingv1.NetworkPolicyIngressRule{ + From: []networkingv1.NetworkPolicyPeer{*rayJobPeer}, + }) + } + + return &networkingv1.NetworkPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("%s-head", instance.Name), + Namespace: instance.Namespace, + Labels: labels, + }, + Spec: networkingv1.NetworkPolicySpec{ + PodSelector: metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: instance.Name, + utils.RayNodeTypeLabelKey: string(rayv1.HeadNode), + }, + }, + PolicyTypes: []networkingv1.PolicyType{networkingv1.PolicyTypeIngress}, + Ingress: ingressRules, + }, + } +} + +// buildWorkerNetworkPolicy creates a NetworkPolicy for Ray worker pods +func (r *NetworkPolicyController) buildWorkerNetworkPolicy(instance *rayv1.RayCluster) *networkingv1.NetworkPolicy { + labels := map[string]string{ + utils.RayClusterLabelKey: instance.Name, + utils.KubernetesApplicationNameLabelKey: utils.ApplicationName, + utils.KubernetesCreatedByLabelKey: utils.ComponentName, + } + + return &networkingv1.NetworkPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("%s-workers", instance.Name), + Namespace: instance.Namespace, + Labels: labels, + }, + Spec: networkingv1.NetworkPolicySpec{ + PodSelector: metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: instance.Name, + utils.RayNodeTypeLabelKey: string(rayv1.WorkerNode), + }, + }, + PolicyTypes: []networkingv1.PolicyType{networkingv1.PolicyTypeIngress}, + Ingress: []networkingv1.NetworkPolicyIngressRule{ + { + From: []networkingv1.NetworkPolicyPeer{ + { + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: instance.Name, + }, + }, + }, + }, + }, + }, + }, + } +} + +// buildRayJobPeer creates a NetworkPolicy peer for RayJob submitter pods +// Returns nil if RayCluster is not owned by RayJob +func (r *NetworkPolicyController) buildRayJobPeer(instance *rayv1.RayCluster) *networkingv1.NetworkPolicyPeer { + // Check if RayCluster is owned by RayJob + for _, ownerRef := range instance.OwnerReferences { + if ownerRef.Kind == "RayJob" { + // Return peer for RayJob submitter pods + return &networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "batch.kubernetes.io/job-name": ownerRef.Name, + }, + }, + } + } + } + // No RayJob owner = no RayJob submitter pods to allow + return nil +} + +// isMTLSEnabled checks if mTLS is enabled for the RayCluster +// This looks for 
TLS-related environment variables or configuration +func (r *NetworkPolicyController) isMTLSEnabled(instance *rayv1.RayCluster) bool { + // Check head group for TLS environment variables + if r.checkContainersForMTLS(instance.Spec.HeadGroupSpec.Template.Spec.Containers) { + return true + } + + // Check worker groups for TLS environment variables + for _, workerGroup := range instance.Spec.WorkerGroupSpecs { + if r.checkContainersForMTLS(workerGroup.Template.Spec.Containers) { + return true + } + } + + return false +} + +// checkContainersForMTLS checks if any container has mTLS-related environment variables +func (r *NetworkPolicyController) checkContainersForMTLS(containers []corev1.Container) bool { + for _, container := range containers { + for _, env := range container.Env { + // Check for common Ray TLS environment variables + if env.Name == "RAY_USE_TLS" && env.Value == "1" { + return true + } + if env.Name == "RAY_TLS_SERVER_CERT" && env.Value != "" { + return true + } + if env.Name == "RAY_TLS_SERVER_KEY" && env.Value != "" { + return true + } + } + } + return false +} + +// SetupWithManager sets up the controller with the Manager +func (r *NetworkPolicyController) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&rayv1.RayCluster{}). + Owns(&networkingv1.NetworkPolicy{}). + Named("networkpolicy"). + Complete(r) +} diff --git a/ray-operator/controllers/ray/networkpolicy_controller_test.go b/ray-operator/controllers/ray/networkpolicy_controller_test.go new file mode 100644 index 00000000000..12f110fb3d2 --- /dev/null +++ b/ray-operator/controllers/ray/networkpolicy_controller_test.go @@ -0,0 +1,422 @@ +/* + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ray + +import ( + "context" + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" + "github.com/ray-project/kuberay/ray-operator/test/support" +) + +func rayClusterTemplateForNetworkPolicy(name string, namespace string) *rayv1.RayCluster { + return &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + Spec: rayv1.RayClusterSpec{ + RayVersion: support.GetRayVersion(), + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + Image: support.GetRayImage(), + }, + }, + }, + }, + }, + WorkerGroupSpecs: []rayv1.WorkerGroupSpec{ + { + Replicas: ptr.To[int32](1), + MinReplicas: ptr.To[int32](0), + MaxReplicas: ptr.To[int32](2), + GroupName: "small-group", + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-worker", + Image: support.GetRayImage(), + }, + }, + }, + }, + }, + }, + }, + } +} + +var _ = Context("NetworkPolicy Controller Integration Tests", func() { + Describe("Basic NetworkPolicy Creation", Ordered, func() { + ctx := context.Background() + namespace := "default" + rayCluster := rayClusterTemplateForNetworkPolicy("raycluster-networkpolicy", namespace) + + It("Verify RayCluster spec", func() { + Expect(rayCluster.Spec.WorkerGroupSpecs).To(HaveLen(1)) + Expect(rayCluster.Spec.WorkerGroupSpecs[0].Replicas).To(Equal(ptr.To[int32](1))) + }) + + It("Create a RayCluster custom resource", func() { + err := k8sClient.Create(ctx, rayCluster) + Expect(err).NotTo(HaveOccurred(), "Failed to create RayCluster") + Eventually( + getResourceFunc(ctx, client.ObjectKey{Name: rayCluster.Name, Namespace: namespace}, rayCluster), + time.Second*3, time.Millisecond*500).Should(Succeed(), "Should be able to see RayCluster: %v", rayCluster.Name) + }) + + It("Check Head NetworkPolicy is created", func() { + headNetworkPolicy := &networkingv1.NetworkPolicy{} + expectedHeadName := rayCluster.Name + "-head" + headNamespacedName := types.NamespacedName{Namespace: namespace, Name: expectedHeadName} + + Eventually( + getResourceFunc(ctx, headNamespacedName, headNetworkPolicy), + time.Second*10, time.Millisecond*500).Should(Succeed(), "Head NetworkPolicy should be created: %v", expectedHeadName) + }) + + It("Check Worker NetworkPolicy is created", func() { + workerNetworkPolicy := &networkingv1.NetworkPolicy{} + expectedWorkerName := rayCluster.Name + "-workers" + workerNamespacedName := types.NamespacedName{Namespace: namespace, Name: expectedWorkerName} + + Eventually( + getResourceFunc(ctx, workerNamespacedName, workerNetworkPolicy), + time.Second*10, time.Millisecond*500).Should(Succeed(), "Worker NetworkPolicy should be created: %v", expectedWorkerName) + }) + + It("Verify Head NetworkPolicy has correct structure", func() { + headNetworkPolicy := &networkingv1.NetworkPolicy{} + expectedHeadName := rayCluster.Name + "-head" + headNamespacedName := types.NamespacedName{Namespace: namespace, Name: expectedHeadName} + + err := k8sClient.Get(ctx, headNamespacedName, headNetworkPolicy) + Expect(err).NotTo(HaveOccurred(), "Failed to get Head NetworkPolicy") + + // Verify basic properties + 
Expect(headNetworkPolicy.Name).To(Equal(expectedHeadName)) + Expect(headNetworkPolicy.Namespace).To(Equal(namespace)) + + // Verify labels + expectedLabels := map[string]string{ + utils.RayClusterLabelKey: rayCluster.Name, + utils.KubernetesApplicationNameLabelKey: utils.ApplicationName, + utils.KubernetesCreatedByLabelKey: utils.ComponentName, + } + Expect(headNetworkPolicy.Labels).To(Equal(expectedLabels)) + + // Verify owner reference is set + Expect(headNetworkPolicy.OwnerReferences).To(HaveLen(1)) + Expect(headNetworkPolicy.OwnerReferences[0].Name).To(Equal(rayCluster.Name)) + Expect(headNetworkPolicy.OwnerReferences[0].Kind).To(Equal("RayCluster")) + + // Verify policy type + Expect(headNetworkPolicy.Spec.PolicyTypes).To(Equal([]networkingv1.PolicyType{networkingv1.PolicyTypeIngress})) + + // Verify pod selector targets head pods only + expectedPodSelector := metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: rayCluster.Name, + utils.RayNodeTypeLabelKey: string(rayv1.HeadNode), + }, + } + Expect(headNetworkPolicy.Spec.PodSelector).To(Equal(expectedPodSelector)) + + // Verify ingress rules - CodeFlare 5-rule pattern + Expect(len(headNetworkPolicy.Spec.Ingress)).To(BeNumerically(">=", 5), "Should have at least 5 ingress rules") + Expect(len(headNetworkPolicy.Spec.Ingress)).To(BeNumerically("<=", 6), "Should have at most 6 ingress rules (including optional RayJob)") + + // Verify Rule 1: Intra-cluster communication - NO PORTS (allows all) + intraClusterRule := headNetworkPolicy.Spec.Ingress[0] + Expect(intraClusterRule.From).To(HaveLen(1)) + Expect(intraClusterRule.Ports).To(BeEmpty(), "Intra-cluster rule should have NO ports (allows all)") + + expectedIntraClusterPeer := networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: rayCluster.Name, + }, + }, + } + Expect(intraClusterRule.From[0]).To(Equal(expectedIntraClusterPeer)) + + // Verify Rule 2: External access from any pod in namespace + externalRule := headNetworkPolicy.Spec.Ingress[1] + Expect(externalRule.From).To(HaveLen(1)) + Expect(externalRule.Ports).To(HaveLen(2), "External rule should have 2 ports (10001, 8265)") + + // Verify empty pod selector (any pod in namespace) + expectedAnyPodPeer := networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{}, + } + Expect(externalRule.From[0]).To(Equal(expectedAnyPodPeer)) + + // Verify Rule 5: Secured ports - NO FROM (allows all) + securedRule := headNetworkPolicy.Spec.Ingress[4] + Expect(securedRule.From).To(BeEmpty(), "Secured ports rule should have NO from (allows all)") + Expect(securedRule.Ports).ToNot(BeEmpty(), "Secured ports rule should have at least 1 port (8443)") + + // Check for mTLS port 8443 (always present) + portFound8443 := false + for _, port := range securedRule.Ports { + if port.Port.IntVal == 8443 { + portFound8443 = true + } + } + Expect(portFound8443).To(BeTrue(), "Should include mTLS port 8443") + }) + + It("Verify Worker NetworkPolicy has correct structure", func() { + workerNetworkPolicy := &networkingv1.NetworkPolicy{} + expectedWorkerName := rayCluster.Name + "-workers" + workerNamespacedName := types.NamespacedName{Namespace: namespace, Name: expectedWorkerName} + + err := k8sClient.Get(ctx, workerNamespacedName, workerNetworkPolicy) + Expect(err).NotTo(HaveOccurred(), "Failed to get Worker NetworkPolicy") + + // Verify basic properties + Expect(workerNetworkPolicy.Name).To(Equal(expectedWorkerName)) + 
Expect(workerNetworkPolicy.Namespace).To(Equal(namespace)) + + // Verify pod selector targets worker pods only + expectedPodSelector := metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: rayCluster.Name, + utils.RayNodeTypeLabelKey: string(rayv1.WorkerNode), + }, + } + Expect(workerNetworkPolicy.Spec.PodSelector).To(Equal(expectedPodSelector)) + + // Verify ingress rules - workers only allow intra-cluster communication + Expect(workerNetworkPolicy.Spec.Ingress).To(HaveLen(1)) + Expect(workerNetworkPolicy.Spec.Ingress[0].From).To(HaveLen(1)) + + // Verify intra-cluster peer + intraClusterPeer := workerNetworkPolicy.Spec.Ingress[0].From[0] + expectedIntraClusterPeer := networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: rayCluster.Name, + }, + }, + } + Expect(intraClusterPeer).To(Equal(expectedIntraClusterPeer)) + }) + + It("Delete RayCluster should delete NetworkPolicies", func() { + // Delete the RayCluster + err := k8sClient.Delete(ctx, rayCluster) + Expect(err).NotTo(HaveOccurred(), "Failed to delete RayCluster") + + // Note: envtest doesn't run garbage collection automatically like a real cluster + // In a real cluster, the NetworkPolicies would be automatically deleted due to owner reference + // For testing, we manually delete them to simulate garbage collection + + // Clean up head NetworkPolicy + headNetworkPolicy := &networkingv1.NetworkPolicy{} + expectedHeadName := rayCluster.Name + "-head" + headNamespacedName := types.NamespacedName{Namespace: namespace, Name: expectedHeadName} + + err = k8sClient.Get(ctx, headNamespacedName, headNetworkPolicy) + if err == nil { + err = k8sClient.Delete(ctx, headNetworkPolicy) + Expect(err).NotTo(HaveOccurred(), "Failed to manually delete Head NetworkPolicy") + } + + // Clean up worker NetworkPolicy + workerNetworkPolicy := &networkingv1.NetworkPolicy{} + expectedWorkerName := rayCluster.Name + "-workers" + workerNamespacedName := types.NamespacedName{Namespace: namespace, Name: expectedWorkerName} + + err = k8sClient.Get(ctx, workerNamespacedName, workerNetworkPolicy) + if err == nil { + err = k8sClient.Delete(ctx, workerNetworkPolicy) + Expect(err).NotTo(HaveOccurred(), "Failed to manually delete Worker NetworkPolicy") + } + + // Verify both NetworkPolicies are deleted + Eventually( + func() bool { + headErr := k8sClient.Get(ctx, headNamespacedName, headNetworkPolicy) + workerErr := k8sClient.Get(ctx, workerNamespacedName, workerNetworkPolicy) + return (headErr != nil && client.IgnoreNotFound(headErr) == nil) && + (workerErr != nil && client.IgnoreNotFound(workerErr) == nil) + }, + time.Second*5, time.Millisecond*500).Should(BeTrue(), "Both NetworkPolicies should be deleted") + }) + }) + + Describe("RayCluster owned by RayJob", Ordered, func() { + ctx := context.Background() + namespace := "default" + rayCluster := rayClusterTemplateForNetworkPolicy("raycluster-rayjob", namespace) + + // Add RayJob owner reference + rayCluster.OwnerReferences = []metav1.OwnerReference{ + { + APIVersion: "ray.io/v1", + Kind: "RayJob", + Name: "test-rayjob", + UID: "12345", + }, + } + + It("Create a RayCluster with RayJob owner", func() { + err := k8sClient.Create(ctx, rayCluster) + Expect(err).NotTo(HaveOccurred(), "Failed to create RayCluster with RayJob owner") + Eventually( + getResourceFunc(ctx, client.ObjectKey{Name: rayCluster.Name, Namespace: namespace}, rayCluster), + time.Second*3, time.Millisecond*500).Should(Succeed(), "Should be able to 
see RayCluster: %v", rayCluster.Name) + }) + + It("Check Head NetworkPolicy includes RayJob peer", func() { + headNetworkPolicy := &networkingv1.NetworkPolicy{} + expectedHeadName := rayCluster.Name + "-head" + headNamespacedName := types.NamespacedName{Namespace: namespace, Name: expectedHeadName} + + Eventually( + getResourceFunc(ctx, headNamespacedName, headNetworkPolicy), + time.Second*10, time.Millisecond*500).Should(Succeed(), "Head NetworkPolicy should be created") + + // Should have additional RayJob rule (last rule) + Expect(len(headNetworkPolicy.Spec.Ingress)).To(BeNumerically(">=", 5), "Should have additional RayJob ingress rule") + + // Find the RayJob rule (should be the last rule) + rayJobRule := headNetworkPolicy.Spec.Ingress[len(headNetworkPolicy.Spec.Ingress)-1] + Expect(rayJobRule.From).To(HaveLen(1), "RayJob rule should have one peer") + + // Verify RayJob peer + rayJobPeer := rayJobRule.From[0] + expectedRayJobPeer := networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "batch.kubernetes.io/job-name": "test-rayjob", + }, + }, + } + Expect(rayJobPeer).To(Equal(expectedRayJobPeer)) + }) + + It("Clean up RayCluster with RayJob owner", func() { + err := k8sClient.Delete(ctx, rayCluster) + Expect(err).NotTo(HaveOccurred(), "Failed to delete RayCluster") + }) + }) + + Describe("NetworkPolicy Already Exists", Ordered, func() { + ctx := context.Background() + namespace := "default" + rayCluster := rayClusterTemplateForNetworkPolicy("raycluster-existing-np", namespace) + existingHeadNetworkPolicy := &networkingv1.NetworkPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: rayCluster.Name + "-head", + Namespace: namespace, + Labels: map[string]string{ + "test": "existing", + }, + }, + Spec: networkingv1.NetworkPolicySpec{ + PodSelector: metav1.LabelSelector{}, + PolicyTypes: []networkingv1.PolicyType{networkingv1.PolicyTypeIngress}, + }, + } + + It("Create Head NetworkPolicy before RayCluster", func() { + err := k8sClient.Create(ctx, existingHeadNetworkPolicy) + Expect(err).NotTo(HaveOccurred(), "Failed to create existing Head NetworkPolicy") + }) + + It("Create RayCluster should handle existing NetworkPolicy gracefully", func() { + err := k8sClient.Create(ctx, rayCluster) + Expect(err).NotTo(HaveOccurred(), "Failed to create RayCluster") + Eventually( + getResourceFunc(ctx, client.ObjectKey{Name: rayCluster.Name, Namespace: namespace}, rayCluster), + time.Second*3, time.Millisecond*500).Should(Succeed(), "Should be able to see RayCluster: %v", rayCluster.Name) + + // Head NetworkPolicy should be updated by the controller + headNetworkPolicy := &networkingv1.NetworkPolicy{} + headNamespacedName := types.NamespacedName{Namespace: namespace, Name: existingHeadNetworkPolicy.Name} + err = k8sClient.Get(ctx, headNamespacedName, headNetworkPolicy) + Expect(err).NotTo(HaveOccurred(), "Head NetworkPolicy should exist") + + // Worker NetworkPolicy should be created + workerNetworkPolicy := &networkingv1.NetworkPolicy{} + expectedWorkerName := rayCluster.Name + "-workers" + workerNamespacedName := types.NamespacedName{Namespace: namespace, Name: expectedWorkerName} + Eventually( + getResourceFunc(ctx, workerNamespacedName, workerNetworkPolicy), + time.Second*10, time.Millisecond*500).Should(Succeed(), "Worker NetworkPolicy should be created") + }) + + It("Clean up resources", func() { + err := k8sClient.Delete(ctx, rayCluster) + Expect(err).NotTo(HaveOccurred(), "Failed to delete RayCluster") + + err = k8sClient.Delete(ctx, 
existingHeadNetworkPolicy) + // Policy might have been updated by controller, ignore delete errors + _ = err + + // Clean up worker policy if it exists + workerNetworkPolicy := &networkingv1.NetworkPolicy{} + expectedWorkerName := rayCluster.Name + "-workers" + workerNamespacedName := types.NamespacedName{Namespace: namespace, Name: expectedWorkerName} + err = k8sClient.Get(ctx, workerNamespacedName, workerNetworkPolicy) + if err == nil { + err = k8sClient.Delete(ctx, workerNetworkPolicy) + Expect(err).NotTo(HaveOccurred(), "Failed to delete worker NetworkPolicy") + } + }) + }) + + Describe("RayCluster Deletion", Ordered, func() { + ctx := context.Background() + namespace := "default" + rayCluster := rayClusterTemplateForNetworkPolicy("raycluster-deletion", namespace) + + It("Create and immediately delete RayCluster", func() { + err := k8sClient.Create(ctx, rayCluster) + Expect(err).NotTo(HaveOccurred(), "Failed to create RayCluster") + + // Add deletion timestamp by deleting + err = k8sClient.Delete(ctx, rayCluster) + Expect(err).NotTo(HaveOccurred(), "Failed to delete RayCluster") + + // Verify RayCluster is being deleted or deleted + Eventually( + func() bool { + err := k8sClient.Get(ctx, client.ObjectKey{Name: rayCluster.Name, Namespace: namespace}, rayCluster) + return err != nil && client.IgnoreNotFound(err) == nil + }, + time.Second*10, time.Millisecond*500).Should(BeTrue(), "RayCluster should be deleted") + }) + }) +}) diff --git a/ray-operator/controllers/ray/networkpolicy_controller_unit_test.go b/ray-operator/controllers/ray/networkpolicy_controller_unit_test.go new file mode 100644 index 00000000000..b48e7489c1c --- /dev/null +++ b/ray-operator/controllers/ray/networkpolicy_controller_unit_test.go @@ -0,0 +1,607 @@ +package ray + +import ( + "context" + "os" + "testing" + + . 
"github.com/onsi/ginkgo/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + logf "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/log/zap" + + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" +) + +var ( + testNetworkPolicyController *NetworkPolicyController + testRayClusterBasic *rayv1.RayCluster + testRayClusterWithRayJob *rayv1.RayCluster + testRayClusterWithOtherOwner *rayv1.RayCluster +) + +func setupNetworkPolicyTest(_ *testing.T) { + logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + + // Initialize NetworkPolicy controller + testNetworkPolicyController = &NetworkPolicyController{ + Scheme: runtime.NewScheme(), + } + + // Basic RayCluster without owner + testRayClusterBasic = &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + }, + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + Image: "rayproject/ray:latest", + }, + }, + }, + }, + }, + }, + } + + // RayCluster owned by RayJob + testRayClusterWithRayJob = &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-rayjob", + Namespace: "default", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "ray.io/v1", + Kind: "RayJob", + Name: "test-job", + UID: "12345", + }, + }, + }, + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + Image: "rayproject/ray:latest", + }, + }, + }, + }, + }, + }, + } + + // RayCluster owned by something other than RayJob + testRayClusterWithOtherOwner = &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-other", + Namespace: "default", + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-deployment", + UID: "67890", + }, + }, + }, + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + Image: "rayproject/ray:latest", + }, + }, + }, + }, + }, + }, + } +} + +func TestBuildHeadNetworkPolicy_BasicCluster(t *testing.T) { + setupNetworkPolicyTest(t) + + // Set environment for testing + originalEnv := os.Getenv("POD_NAMESPACE") + os.Setenv("POD_NAMESPACE", "ray-system") + defer os.Setenv("POD_NAMESPACE", originalEnv) + + // Test building head NetworkPolicy for basic cluster + kubeRayNamespaces := []string{"ray-system"} + policy := testNetworkPolicyController.buildHeadNetworkPolicy(testRayClusterBasic, kubeRayNamespaces) + + // Verify basic properties + expectedName := testRayClusterBasic.Name + "-head" + assert.Equal(t, expectedName, policy.Name) + assert.Equal(t, testRayClusterBasic.Namespace, policy.Namespace) + + // Verify labels + expectedLabels := map[string]string{ + utils.RayClusterLabelKey: testRayClusterBasic.Name, + utils.KubernetesApplicationNameLabelKey: utils.ApplicationName, + utils.KubernetesCreatedByLabelKey: utils.ComponentName, + } + assert.Equal(t, expectedLabels, policy.Labels) + + // Verify policy type + assert.Equal(t, 
[]networkingv1.PolicyType{networkingv1.PolicyTypeIngress}, policy.Spec.PolicyTypes) + + // Verify pod selector targets head pods only + expectedPodSelector := metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: testRayClusterBasic.Name, + utils.RayNodeTypeLabelKey: string(rayv1.HeadNode), + }, + } + assert.Equal(t, expectedPodSelector, policy.Spec.PodSelector) + + // Verify ingress rules - CodeFlare 5-rule pattern (+ optional RayJob rule) + assert.GreaterOrEqual(t, len(policy.Spec.Ingress), 5, "Should have at least 5 ingress rules") + assert.LessOrEqual(t, len(policy.Spec.Ingress), 6, "Should have at most 6 ingress rules (including optional RayJob)") + + // Verify Rule 1: Intra-cluster communication - NO PORTS (allows all ports) + intraClusterRule := policy.Spec.Ingress[0] + assert.Len(t, intraClusterRule.From, 1, "Intra-cluster rule should have one peer") + assert.Empty(t, intraClusterRule.Ports, "Intra-cluster rule should have NO ports (allows all)") + + expectedIntraClusterPeer := networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: testRayClusterBasic.Name, + }, + }, + } + assert.Equal(t, expectedIntraClusterPeer, intraClusterRule.From[0], "Should allow cluster members") + + // Verify Rule 2: External access to dashboard and client ports from any pod in namespace + externalRule := policy.Spec.Ingress[1] + assert.Len(t, externalRule.From, 1, "External rule should have one peer") + assert.Len(t, externalRule.Ports, 2, "External rule should have 2 ports (10001, 8265)") + + // Verify empty pod selector (any pod in namespace) + expectedAnyPodPeer := networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{ + // Empty MatchLabels = any pod in same namespace + }, + } + assert.Equal(t, expectedAnyPodPeer, externalRule.From[0], "Should allow any pod in namespace") + + // Check ports (10001, 8265) + portFound10001 := false + portFound8265 := false + for _, port := range externalRule.Ports { + switch port.Port.IntVal { + case 10001: + portFound10001 = true + case 8265: + portFound8265 = true + } + } + assert.True(t, portFound10001, "Should include client port 10001") + assert.True(t, portFound8265, "Should include dashboard port 8265") + + // Verify Rule 3: KubeRay operator access + operatorRule := policy.Spec.Ingress[2] + assert.Len(t, operatorRule.From, 1, "Operator rule should have one peer") + assert.Len(t, operatorRule.Ports, 2, "Operator rule should have 2 ports (8265, 10001)") + + // Verify Rule 4: Monitoring access + monitoringRule := policy.Spec.Ingress[3] + assert.Len(t, monitoringRule.From, 1, "Monitoring rule should have one peer") + assert.Len(t, monitoringRule.Ports, 1, "Monitoring rule should have 1 port (8080 only)") + assert.Equal(t, int32(8080), monitoringRule.Ports[0].Port.IntVal, "Should be monitoring port 8080") + + // Verify Rule 5: Secured ports - NO FROM (allows all) + securedRule := policy.Spec.Ingress[4] + assert.Empty(t, securedRule.From, "Secured ports rule should have NO from (allows all)") + assert.GreaterOrEqual(t, len(securedRule.Ports), 1, "Secured ports rule should have at least 1 port (8443)") + + // Check for mTLS port 8443 (always present) + portFound8443 := false + for _, port := range securedRule.Ports { + if port.Port.IntVal == 8443 { + portFound8443 = true + } + } + assert.True(t, portFound8443, "Should include mTLS port 8443") +} + +func TestBuildHeadNetworkPolicy_WithMTLS(t *testing.T) { + setupNetworkPolicyTest(t) + + // Create 
RayCluster with mTLS configuration + rayClusterWithMTLS := &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-mtls", + Namespace: "default", + }, + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + Image: "rayproject/ray:latest", + Env: []corev1.EnvVar{ + {Name: "RAY_USE_TLS", Value: "1"}, + {Name: "RAY_TLS_SERVER_CERT", Value: "/etc/tls/server.crt"}, + {Name: "RAY_TLS_SERVER_KEY", Value: "/etc/tls/server.key"}, + }, + }, + }, + }, + }, + }, + }, + } + + // Test building head NetworkPolicy with mTLS enabled + policy := testNetworkPolicyController.buildHeadNetworkPolicy(rayClusterWithMTLS, []string{"ray-system"}) + + // Verify Rule 5: Secured ports should include both 8443 and 10001 when mTLS is enabled + securedRule := policy.Spec.Ingress[4] + assert.Empty(t, securedRule.From, "Secured ports rule should have NO from (allows all)") + assert.Len(t, securedRule.Ports, 2, "Secured ports rule should have 2 ports when mTLS enabled (8443, 10001)") + + // Check for both mTLS ports + portFound8443 := false + portFound10001 := false + for _, port := range securedRule.Ports { + switch port.Port.IntVal { + case 8443: + portFound8443 = true + case 10001: + portFound10001 = true + } + } + assert.True(t, portFound8443, "Should include mTLS port 8443") + assert.True(t, portFound10001, "Should include client port 10001 when mTLS enabled") +} + +func TestBuildHeadNetworkPolicy_WithoutMTLS(t *testing.T) { + setupNetworkPolicyTest(t) + + // Use basic cluster without mTLS configuration + policy := testNetworkPolicyController.buildHeadNetworkPolicy(testRayClusterBasic, []string{"ray-system"}) + + // Verify Rule 5: Secured ports should only include 8443 when mTLS is disabled + securedRule := policy.Spec.Ingress[4] + assert.Empty(t, securedRule.From, "Secured ports rule should have NO from (allows all)") + assert.Len(t, securedRule.Ports, 1, "Secured ports rule should have 1 port when mTLS disabled (8443 only)") + + // Check for only mTLS port + assert.Equal(t, int32(8443), securedRule.Ports[0].Port.IntVal, "Should only include mTLS port 8443") +} + +func TestBuildWorkerNetworkPolicy_BasicCluster(t *testing.T) { + setupNetworkPolicyTest(t) + + // Test building worker NetworkPolicy for basic cluster + policy := testNetworkPolicyController.buildWorkerNetworkPolicy(testRayClusterBasic) + + // Verify basic properties + expectedName := testRayClusterBasic.Name + "-workers" + assert.Equal(t, expectedName, policy.Name) + assert.Equal(t, testRayClusterBasic.Namespace, policy.Namespace) + + // Verify labels + expectedLabels := map[string]string{ + utils.RayClusterLabelKey: testRayClusterBasic.Name, + utils.KubernetesApplicationNameLabelKey: utils.ApplicationName, + utils.KubernetesCreatedByLabelKey: utils.ComponentName, + } + assert.Equal(t, expectedLabels, policy.Labels) + + // Verify pod selector targets worker pods only + expectedPodSelector := metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: testRayClusterBasic.Name, + utils.RayNodeTypeLabelKey: string(rayv1.WorkerNode), + }, + } + assert.Equal(t, expectedPodSelector, policy.Spec.PodSelector) + + // Verify ingress rules - workers only allow intra-cluster communication + require.Len(t, policy.Spec.Ingress, 1) + require.Len(t, policy.Spec.Ingress[0].From, 1) + + // Verify intra-cluster peer + intraClusterPeer := policy.Spec.Ingress[0].From[0] + expectedIntraClusterPeer := 
networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: testRayClusterBasic.Name, + }, + }, + } + assert.Equal(t, expectedIntraClusterPeer, intraClusterPeer) +} + +func TestBuildHeadNetworkPolicy_ClusterWithRayJob(t *testing.T) { + setupNetworkPolicyTest(t) + + // Set environment for testing + originalEnv := os.Getenv("POD_NAMESPACE") + os.Setenv("POD_NAMESPACE", "ray-system") + defer os.Setenv("POD_NAMESPACE", originalEnv) + + // Test building head NetworkPolicy for cluster owned by RayJob + kubeRayNamespaces := []string{"ray-system"} + policy := testNetworkPolicyController.buildHeadNetworkPolicy(testRayClusterWithRayJob, kubeRayNamespaces) + + // Verify basic properties + expectedName := testRayClusterWithRayJob.Name + "-head" + assert.Equal(t, expectedName, policy.Name) + + // Verify ingress rules - should have additional RayJob rule + assert.Greater(t, len(policy.Spec.Ingress), 4, "Should have additional RayJob ingress rule") + + // Find the RayJob rule (should be the last rule) + rayJobRule := policy.Spec.Ingress[len(policy.Spec.Ingress)-1] + require.Len(t, rayJobRule.From, 1, "RayJob rule should have one peer") + + // Verify RayJob peer + rayJobPeer := rayJobRule.From[0] + expectedRayJobPeer := networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "batch.kubernetes.io/job-name": "test-job", + }, + }, + } + assert.Equal(t, expectedRayJobPeer, rayJobPeer) +} + +func TestBuildHeadNetworkPolicy_MonitoringAccess(t *testing.T) { + setupNetworkPolicyTest(t) + + // Test building head NetworkPolicy with monitoring access + kubeRayNamespaces := []string{"ray-system"} + policy := testNetworkPolicyController.buildHeadNetworkPolicy(testRayClusterBasic, kubeRayNamespaces) + + // Find the monitoring rule (should have port 8080) + var monitoringRule *networkingv1.NetworkPolicyIngressRule + for _, rule := range policy.Spec.Ingress { + for _, port := range rule.Ports { + if port.Port != nil && port.Port.IntVal == 8080 { + monitoringRule = &rule + break + } + } + if monitoringRule != nil { + break + } + } + + require.NotNil(t, monitoringRule, "Should have monitoring rule with port 8080") + assert.Len(t, monitoringRule.Ports, 1, "Monitoring rule should have one port") + assert.Equal(t, int32(8080), monitoringRule.Ports[0].Port.IntVal, "Should be port 8080") + + // Should allow from multiple monitoring sources + assert.Greater(t, len(monitoringRule.From), 1, "Should allow from multiple monitoring sources") + + // Check for OpenShift monitoring namespace + foundOpenShiftMonitoring := false + for _, peer := range monitoringRule.From { + if peer.NamespaceSelector != nil { + for _, req := range peer.NamespaceSelector.MatchExpressions { + if req.Key == "kubernetes.io/metadata.name" && contains(req.Values, "openshift-monitoring") { + foundOpenShiftMonitoring = true + break + } + } + } + } + assert.True(t, foundOpenShiftMonitoring, "Should allow OpenShift monitoring namespace") +} + +func TestBuildHeadNetworkPolicy_SecuredPorts(t *testing.T) { + setupNetworkPolicyTest(t) + + // Test building head NetworkPolicy with secured ports (mTLS) + kubeRayNamespaces := []string{"ray-system"} + policy := testNetworkPolicyController.buildHeadNetworkPolicy(testRayClusterBasic, kubeRayNamespaces) + + // Find the secured ports rule + var securedPortsRule *networkingv1.NetworkPolicyIngressRule + for _, rule := range policy.Spec.Ingress { + for _, port := range rule.Ports { + if port.Port != nil && 
port.Port.IntVal == 8443 { + securedPortsRule = &rule + break + } + } + if securedPortsRule != nil { + break + } + } + + require.NotNil(t, securedPortsRule, "Should have secured ports rule") + assert.Len(t, securedPortsRule.Ports, 2, "Should have 2 secured ports") + + // Check for mTLS ports 8443 and 10001 + portFound8443 := false + portFound10001 := false + for _, port := range securedPortsRule.Ports { + if port.Port.IntVal == 8443 { + portFound8443 = true + } + if port.Port.IntVal == 10001 { + portFound10001 = true + } + } + assert.True(t, portFound8443, "Should include mTLS port 8443") + assert.True(t, portFound10001, "Should include mTLS port 10001") +} + +// Helper function to check if slice contains string +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +func TestGetKubeRayNamespaces_EnvironmentFallback(t *testing.T) { + setupNetworkPolicyTest(t) + + // Test fallback when POD_NAMESPACE is not set + originalEnv := os.Getenv("POD_NAMESPACE") + os.Unsetenv("POD_NAMESPACE") + defer os.Setenv("POD_NAMESPACE", originalEnv) + + namespaces := testNetworkPolicyController.getKubeRayNamespaces(context.Background()) + + // Should fallback to "ray-system" namespace + assert.Equal(t, []string{"ray-system"}, namespaces) +} + +func TestGetKubeRayNamespaces_WithEnvironment(t *testing.T) { + setupNetworkPolicyTest(t) + + // Test with POD_NAMESPACE set + originalEnv := os.Getenv("POD_NAMESPACE") + os.Setenv("POD_NAMESPACE", "custom-ray-system") + defer os.Setenv("POD_NAMESPACE", originalEnv) + + namespaces := testNetworkPolicyController.getKubeRayNamespaces(context.Background()) + + // Should use the custom namespace + assert.Equal(t, []string{"custom-ray-system"}, namespaces) +} + +func TestBuildRayJobPeer_NoOwner(t *testing.T) { + setupNetworkPolicyTest(t) + + // Test RayCluster without owner + peer := testNetworkPolicyController.buildRayJobPeer(testRayClusterBasic) + assert.Nil(t, peer) +} + +func TestBuildRayJobPeer_WithRayJobOwner(t *testing.T) { + setupNetworkPolicyTest(t) + + // Test RayCluster with RayJob owner + peer := testNetworkPolicyController.buildRayJobPeer(testRayClusterWithRayJob) + require.NotNil(t, peer) + + expectedPeer := &networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "batch.kubernetes.io/job-name": "test-job", + }, + }, + } + assert.Equal(t, expectedPeer, peer) +} + +func TestBuildRayJobPeer_WithOtherOwner(t *testing.T) { + setupNetworkPolicyTest(t) + + // Test RayCluster with non-RayJob owner + peer := testNetworkPolicyController.buildRayJobPeer(testRayClusterWithOtherOwner) + assert.Nil(t, peer) +} + +func TestBuildRayJobPeer_MultipleOwners(t *testing.T) { + setupNetworkPolicyTest(t) + + // Create RayCluster with multiple owners, including RayJob + rayCluster := testRayClusterBasic.DeepCopy() + rayCluster.OwnerReferences = []metav1.OwnerReference{ + { + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-deployment", + UID: "67890", + }, + { + APIVersion: "ray.io/v1", + Kind: "RayJob", + Name: "test-job", + UID: "12345", + }, + } + + peer := testNetworkPolicyController.buildRayJobPeer(rayCluster) + require.NotNil(t, peer) + + expectedPeer := &networkingv1.NetworkPolicyPeer{ + PodSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "batch.kubernetes.io/job-name": "test-job", + }, + }, + } + assert.Equal(t, expectedPeer, peer) +} + +func TestBuildHeadNetworkPolicy_DifferentNamespace(t *testing.T) { + 
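+	// The generated policies must live in the RayCluster's namespace, even when the
+	// operator itself (POD_NAMESPACE) runs elsewhere.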
setupNetworkPolicyTest(t) + + // Set custom operator namespace + originalEnv := os.Getenv("POD_NAMESPACE") + os.Setenv("POD_NAMESPACE", "custom-ray-system") + defer os.Setenv("POD_NAMESPACE", originalEnv) + + // Create cluster in different namespace + rayCluster := testRayClusterBasic.DeepCopy() + rayCluster.Namespace = "custom-namespace" + + kubeRayNamespaces := []string{"custom-ray-system"} + headPolicy := testNetworkPolicyController.buildHeadNetworkPolicy(rayCluster, kubeRayNamespaces) + + // Verify NetworkPolicy is created in the same namespace as RayCluster + assert.Equal(t, "custom-namespace", headPolicy.Namespace) + + // Verify head policy name + expectedHeadName := rayCluster.Name + "-head" + assert.Equal(t, expectedHeadName, headPolicy.Name) +} + +func TestBuildHeadNetworkPolicy_LongClusterName(t *testing.T) { + setupNetworkPolicyTest(t) + + // Test with long cluster name + longName := "very-long-cluster-name-that-might-cause-issues" + rayCluster := testRayClusterBasic.DeepCopy() + rayCluster.Name = longName + + kubeRayNamespaces := []string{"ray-system"} + headPolicy := testNetworkPolicyController.buildHeadNetworkPolicy(rayCluster, kubeRayNamespaces) + + // Verify name is constructed correctly + expectedHeadName := longName + "-head" + assert.Equal(t, expectedHeadName, headPolicy.Name) + + // Verify pod selector uses correct cluster name and targets head pods + expectedPodSelector := metav1.LabelSelector{ + MatchLabels: map[string]string{ + utils.RayClusterLabelKey: longName, + utils.RayNodeTypeLabelKey: string(rayv1.HeadNode), + }, + } + assert.Equal(t, expectedPodSelector, headPolicy.Spec.PodSelector) +} diff --git a/ray-operator/controllers/ray/suite_test.go b/ray-operator/controllers/ray/suite_test.go index 85c913e7bd6..e62237c09d9 100644 --- a/ray-operator/controllers/ray/suite_test.go +++ b/ray-operator/controllers/ray/suite_test.go @@ -23,6 +23,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" ctrl "sigs.k8s.io/controller-runtime" @@ -36,6 +37,7 @@ import ( rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils/dashboardclient" + "github.com/ray-project/kuberay/ray-operator/pkg/features" ) // These tests use Ginkgo (BDD-style Go testing framework). 
Refer to @@ -73,6 +75,9 @@ func TestAPIs(t *testing.T) { var _ = BeforeSuite(func(ctx SpecContext) { logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + // Enable NetworkPolicy feature gate for integration tests + features.SetFeatureGateDuringTest(GinkgoTB(), features.RayClusterNetworkPolicy, true) + By("bootstrapping test environment") testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{filepath.Join("..", "..", "config", "crd", "bases")}, @@ -87,6 +92,10 @@ var _ = BeforeSuite(func(ctx SpecContext) { err = rayv1.AddToScheme(scheme.Scheme) Expect(err).NotTo(HaveOccurred()) + // Add networking scheme for NetworkPolicy resources + err = networkingv1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) + // +kubebuilder:scaffold:scheme k8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme}) @@ -128,6 +137,13 @@ var _ = BeforeSuite(func(ctx SpecContext) { err = NewRayJobReconciler(ctx, mgr, rayJobOptions, testClientProvider).SetupWithManager(mgr, 1) Expect(err).NotTo(HaveOccurred(), "failed to setup RayJob controller") + // NetworkPolicy controller (only registered if feature flag is enabled) + if features.Enabled(features.RayClusterNetworkPolicy) { + networkPolicyController := NewNetworkPolicyController(mgr) + err = networkPolicyController.SetupWithManager(mgr) + Expect(err).NotTo(HaveOccurred(), "failed to setup NetworkPolicy controller") + } + go func() { err = mgr.Start(ctrl.SetupSignalHandler()) Expect(err).ToNot(HaveOccurred()) diff --git a/ray-operator/controllers/ray/utils/constant.go b/ray-operator/controllers/ray/utils/constant.go index 545e696d503..6b999f1868b 100644 --- a/ray-operator/controllers/ray/utils/constant.go +++ b/ray-operator/controllers/ray/utils/constant.go @@ -322,6 +322,10 @@ const ( FailedToUpdateHeadPodServeLabel K8sEventType = "FailedToUpdateHeadPodServeLabel" FailedToUpdateServeApplications K8sEventType = "FailedToUpdateServeApplications" + // NetworkPolicy event list + CreatedNetworkPolicy K8sEventType = "CreatedNetworkPolicy" + FailedToCreateNetworkPolicy K8sEventType = "FailedToCreateNetworkPolicy" + // Generic Pod event list DeletedPod K8sEventType = "DeletedPod" FailedToDeletePod K8sEventType = "FailedToDeletePod" diff --git a/ray-operator/main.go b/ray-operator/main.go index 9acf7b5883e..294112a0cfb 100644 --- a/ray-operator/main.go +++ b/ray-operator/main.go @@ -281,6 +281,15 @@ func main() { exitOnError(ray.NewRayJobReconciler(ctx, mgr, rayJobOptions, config).SetupWithManager(mgr, config.ReconcileConcurrency), "unable to create controller", "controller", "RayJob") + // NetworkPolicy controller (only registered if feature flag is enabled) + if features.Enabled(features.RayClusterNetworkPolicy) { + exitOnError(ray.NewNetworkPolicyController(mgr).SetupWithManager(mgr), + "unable to create controller", "controller", "NetworkPolicy") + setupLog.Info("NetworkPolicy controller enabled") + } else { + setupLog.Info("NetworkPolicy controller disabled via feature flag") + } + if os.Getenv("ENABLE_WEBHOOKS") == "true" { exitOnError(webhooks.SetupRayClusterWebhookWithManager(mgr), "unable to create webhook", "webhook", "RayCluster") diff --git a/ray-operator/pkg/features/features.go b/ray-operator/pkg/features/features.go index 2abea2ffbbb..f1231786735 100644 --- a/ray-operator/pkg/features/features.go +++ b/ray-operator/pkg/features/features.go @@ -24,6 +24,9 @@ const ( // // Enables new deletion policy API in RayJob RayJobDeletionPolicy featuregate.Feature = "RayJobDeletionPolicy" + + // Might be overkill to 
have a feature gate for this but for the sake of argument... + RayClusterNetworkPolicy featuregate.Feature = "RayClusterNetworkPolicy" ) func init() { @@ -33,6 +36,7 @@ func init() { var defaultFeatureGates = map[featuregate.Feature]featuregate.FeatureSpec{ RayClusterStatusConditions: {Default: true, PreRelease: featuregate.Beta}, RayJobDeletionPolicy: {Default: false, PreRelease: featuregate.Alpha}, + RayClusterNetworkPolicy: {Default: false, PreRelease: featuregate.Alpha}, } // SetFeatureGateDuringTest is a helper method to override feature gates in tests. From 2e896f9989a2e9a4e4c91afb4c57a32effd258ca Mon Sep 17 00:00:00 2001 From: Pat O'Connor Date: Thu, 25 Sep 2025 16:57:38 +0100 Subject: [PATCH 2/2] updated tests re monitoring namespaces + mtls Signed-off-by: Pat O'Connor --- .../ray/networkpolicy_controller_unit_test.go | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/ray-operator/controllers/ray/networkpolicy_controller_unit_test.go b/ray-operator/controllers/ray/networkpolicy_controller_unit_test.go index b48e7489c1c..f64f15bbed0 100644 --- a/ray-operator/controllers/ray/networkpolicy_controller_unit_test.go +++ b/ray-operator/controllers/ray/networkpolicy_controller_unit_test.go @@ -397,22 +397,28 @@ func TestBuildHeadNetworkPolicy_MonitoringAccess(t *testing.T) { assert.Len(t, monitoringRule.Ports, 1, "Monitoring rule should have one port") assert.Equal(t, int32(8080), monitoringRule.Ports[0].Port.IntVal, "Should be port 8080") - // Should allow from multiple monitoring sources - assert.Greater(t, len(monitoringRule.From), 1, "Should allow from multiple monitoring sources") + // Should allow from monitoring sources (single From with multiple namespaces) + assert.Len(t, monitoringRule.From, 1, "Should have one monitoring peer with multiple namespaces") - // Check for OpenShift monitoring namespace + // Check for both OpenShift monitoring and Prometheus namespaces foundOpenShiftMonitoring := false + foundPrometheus := false for _, peer := range monitoringRule.From { if peer.NamespaceSelector != nil { for _, req := range peer.NamespaceSelector.MatchExpressions { - if req.Key == "kubernetes.io/metadata.name" && contains(req.Values, "openshift-monitoring") { - foundOpenShiftMonitoring = true - break + if req.Key == "kubernetes.io/metadata.name" { + if contains(req.Values, "openshift-monitoring") { + foundOpenShiftMonitoring = true + } + if contains(req.Values, "prometheus") { + foundPrometheus = true + } } } } } assert.True(t, foundOpenShiftMonitoring, "Should allow OpenShift monitoring namespace") + assert.True(t, foundPrometheus, "Should allow Prometheus namespace") } func TestBuildHeadNetworkPolicy_SecuredPorts(t *testing.T) { @@ -437,21 +443,16 @@ func TestBuildHeadNetworkPolicy_SecuredPorts(t *testing.T) { } require.NotNil(t, securedPortsRule, "Should have secured ports rule") - assert.Len(t, securedPortsRule.Ports, 2, "Should have 2 secured ports") + assert.Len(t, securedPortsRule.Ports, 1, "Should have 1 secured port (8443 only, no mTLS)") - // Check for mTLS ports 8443 and 10001 + // Check for mTLS port 8443 (always present) portFound8443 := false - portFound10001 := false for _, port := range securedPortsRule.Ports { if port.Port.IntVal == 8443 { portFound8443 = true } - if port.Port.IntVal == 10001 { - portFound10001 = true - } } assert.True(t, portFound8443, "Should include mTLS port 8443") - assert.True(t, portFound10001, "Should include mTLS port 10001") } // Helper function to check if slice contains string
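
A quick way to exercise the feature end to end after enabling the RayClusterNetworkPolicy gate (for example via the chart's featureGates list above, plus the POD_NAMESPACE env var when the operator is not in ray-system): fetch the two policies the controller derives from the RayCluster name. The snippet below is an illustrative sketch, not part of this patch; "raycluster-sample" and "default" are placeholder values, and the client setup is ordinary controller-runtime usage.

package main

import (
	"context"
	"fmt"

	networkingv1 "k8s.io/api/networking/v1"
	"k8s.io/apimachinery/pkg/types"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/client/config"
)

func main() {
	// Load kubeconfig / in-cluster config the same way controller-runtime tooling does.
	cfg, err := config.GetConfig()
	if err != nil {
		panic(err)
	}
	c, err := client.New(cfg, client.Options{})
	if err != nil {
		panic(err)
	}

	// Names follow the controller's fmt.Sprintf("%s-head") / fmt.Sprintf("%s-workers") convention.
	for _, name := range []string{"raycluster-sample-head", "raycluster-sample-workers"} {
		np := &networkingv1.NetworkPolicy{}
		key := types.NamespacedName{Namespace: "default", Name: name}
		if err := c.Get(context.Background(), key, np); err != nil {
			fmt.Printf("%s: not found yet (%v)\n", name, err)
			continue
		}
		fmt.Printf("%s: podSelector=%v, %d ingress rule(s)\n",
			name, np.Spec.PodSelector.MatchLabels, len(np.Spec.Ingress))
	}
}

With the gate enabled, the head policy should report five ingress rules (six when the RayCluster is owned by a RayJob) and the workers policy exactly one, matching the unit tests above.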