diff --git a/pkg/service/controller.go b/pkg/service/controller.go
index 3e6496fe..cb25e817 100644
--- a/pkg/service/controller.go
+++ b/pkg/service/controller.go
@@ -1,6 +1,7 @@
 package service
 
 import (
+	"fmt"
 	"strconv"
 	"time"
 
@@ -66,7 +67,10 @@ func (c *ControllerService) validateCreateVolumeRequest(req *csi.CreateVolumeReq
 		return false, status.Error(codes.InvalidArgument, "volume capabilities missing in request")
 	}
 
-	isBlock, isRWX := getAccessMode(caps)
+	isBlock, isRWX, err := getAccessMode(caps)
+	if err != nil {
+		return false, err
+	}
 
 	if isRWX && !isBlock {
 		return false, status.Error(codes.InvalidArgument, "non-block volume with RWX access mode is not supported")
@@ -91,10 +95,7 @@ func (c *ControllerService) validateCreateVolumeRequest(req *csi.CreateVolumeReq
 	return isRWX, nil
 }
 
-func getAccessMode(caps []*csi.VolumeCapability) (bool, bool) {
-	isBlock := false
-	isRWX := false
-
+func getAccessMode(caps []*csi.VolumeCapability) (isBlock, isRWX bool, err error) {
 	for _, capability := range caps {
 		if capability != nil {
 			if capability.GetBlock() != nil {
@@ -102,14 +103,44 @@ func getAccessMode(caps []*csi.VolumeCapability) (bool, bool) {
 			}
 
 			if am := capability.GetAccessMode(); am != nil {
-				if am.Mode == csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER {
-					isRWX = true
-				}
+				// Fail fast on an unknown mode, and keep isRWX sticky so a
+				// later capability cannot mask an earlier RWX result.
+				var rwx bool
+				if rwx, err = hasRWXCapabilityAccessMode(am); err != nil {
+					return isBlock, isRWX, err
+				}
+				isRWX = isRWX || rwx
 			}
 		}
 	}
 
-	return isBlock, isRWX
+	return isBlock, isRWX, nil
+}
+
+// hasRWXCapabilityAccessMode reports whether an access mode describes an RWX
+// (multi-node) volume. It returns an error if the access mode is unknown.
+//
+// Parameters:
+//   - am: Volume Capability AccessMode.
+//
+// Returns:
+//   - bool: True if the capability represents an RWX volume.
+//   - error: Non-nil if the access mode is not a recognized CSI mode.
+func hasRWXCapabilityAccessMode(am *csi.VolumeCapability_AccessMode) (bool, error) {
+	switch am.GetMode() {
+	case
+		csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
+		csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY,
+		csi.VolumeCapability_AccessMode_SINGLE_NODE_SINGLE_WRITER,
+		csi.VolumeCapability_AccessMode_SINGLE_NODE_MULTI_WRITER:
+		return false, nil
+	case
+		csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY,
+		csi.VolumeCapability_AccessMode_MULTI_NODE_SINGLE_WRITER,
+		csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER:
+		return true, nil
+	}
+	return false, fmt.Errorf("unknown volume capability access mode: %v", am.GetMode())
 }
 
 // CreateVolume Create a new DataVolume.
@@ -326,7 +357,31 @@ func (c *ControllerService) ControllerPublishVolume(
 	if err := c.validateControllerPublishVolumeRequest(req); err != nil {
 		return nil, err
 	}
+	dvName := req.GetVolumeId()
+
+	// Get VM name from node ID which is a namespace/name
+	_, vmName, err := cache.SplitMetaNamespaceKey(req.NodeId)
+	if err != nil {
+		klog.Error("failed getting VM Name for node ID " + req.NodeId)
+		return nil, err
+	}
+
+	// Check if the volume is RWO and, if so, make sure it is not already attached to a different Virtual Machine Instance.
+	isRWX, err := hasRWXCapabilityAccessMode(req.GetVolumeCapability().GetAccessMode())
+	if err != nil {
+		return nil, status.Errorf(codes.InvalidArgument, "error checking access mode: %v", err)
+	}
+	if !isRWX {
+		alreadyAttached, err := c.IsVolumeAttachedToOtherVMI(ctx, dvName, c.infraClusterNamespace, vmName)
+		if err != nil {
+			return nil, status.Errorf(codes.FailedPrecondition, "failed to check if volume is already attached: %s", err)
+		}
+		if alreadyAttached {
+			return nil, status.Errorf(codes.FailedPrecondition, "volume is attached to another VM")
+		}
+	}
+
 	if _, err := c.virtClient.GetDataVolume(ctx, c.infraClusterNamespace, dvName); errors.IsNotFound(err) {
 		return nil, status.Errorf(codes.NotFound, "volume %s not found", req.GetVolumeId())
 	} else if err != nil {
 		return nil, err
@@ -335,12 +390,6 @@ func (c *ControllerService) ControllerPublishVolume(
 
 	klog.V(3).Infof("Attaching DataVolume %s to Node ID %s", dvName, req.NodeId)
 
-	// Get VM name from node ID which is a namespace/name
-	_, vmName, err := cache.SplitMetaNamespaceKey(req.NodeId)
-	if err != nil {
-		klog.Error("failed getting VM Name for node ID " + req.NodeId)
-		return nil, err
-	}
 	_, err = c.virtClient.GetWorkloadManagingVirtualMachine(ctx, c.infraClusterNamespace, vmName)
 	if err != nil {
 		if !errors.IsNotFound(err) {
@@ -850,3 +899,55 @@ func (c *ControllerService) ControllerGetCapabilities(context.Context, *csi.Cont
 func (c *ControllerService) ControllerGetVolume(_ context.Context, _ *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
 	return nil, status.Error(codes.Unimplemented, "")
 }
+
+// IsVolumeAttachedToOtherVMI checks if a PVC is actively
+// used by any VirtualMachineInstance other than the current one.
+//
+// NOTE: This function uses vmi.Status.VolumeStatus as the source of truth for
+// what is currently attached. It directly compares the volume name in the status
+// with the target PVC name.
+//
+// Parameters:
+//   - ctx: The context for cancellation.
+//   - dvName: The name of the DataVolume (and its matching PVC) to check for.
+//   - infraNamespace: The namespace of the PersistentVolumeClaim.
+//   - currentVMIName: The name of the VMI for the current ControllerPublishVolume
+//     request. We want to ignore this VMI in our check.
+//
+// Returns:
+//   - bool: True if the volume is attached to another VMI.
+//   - error: An error if listing VMIs fails.
+func (c *ControllerService) IsVolumeAttachedToOtherVMI(
+	ctx context.Context,
+	dvName string,
+	infraNamespace string,
+	currentVMIName string,
+) (bool, error) {
+	vmis, err := c.virtClient.ListVirtualMachines(ctx, infraNamespace)
+	if err != nil {
+		return false, fmt.Errorf("failed to list Virtual Machine Instances in namespace %s: %w", infraNamespace, err)
+	}
+
+	for _, vmi := range vmis {
+		// Skip the VMI that the volume is intended for.
+		if vmi.Name == currentVMIName {
+			continue
+		}
+
+		// The source of truth is the VMI's status. We iterate through the volumes
+		// that are reported as active in the status.
+		for _, volumeStatus := range vmi.Status.VolumeStatus {
+			// If the name in the status matches our PVC name, it means the volume
+			// is actively attached to this other VMI.
+			if volumeStatus.Name == dvName {
+				klog.Infof(
+					"CONFLICT: PVC %s/%s is in use by VMI %s/%s",
+					infraNamespace, dvName, vmi.Namespace, vmi.Name,
+				)
+				return true, nil
+			}
+		}
+	}
+
+	return false, nil
+}
diff --git a/pkg/service/controller_test.go b/pkg/service/controller_test.go
index 0de57598..378aa7db 100644
--- a/pkg/service/controller_test.go
+++ b/pkg/service/controller_test.go
@@ -381,6 +381,112 @@ var _ = Describe("PublishUnPublish", func() {
 		Expect(err).ToNot(HaveOccurred())
 		Expect(capturingClient.hotunplugForVMIOccured).To(BeTrue())
 	})
+
+	It("should not publish an RWO volume that is not yet released by another VMI", func() {
+		// Create the DataVolume we will use.
+		dv, err := client.CreateDataVolume(context.TODO(), controller.infraClusterNamespace, &cdiv1.DataVolume{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:   testVolumeName,
+				Labels: testInfraLabels,
+			},
+			Spec: cdiv1.DataVolumeSpec{
+				Storage: &cdiv1.StorageSpec{
+					StorageClassName: &testInfraStorageClassName,
+					Resources: corev1.ResourceRequirements{
+						Requests: corev1.ResourceList{
+							corev1.ResourceStorage: resource.MustParse("3Gi"),
+						},
+					},
+				},
+			},
+		})
+		Expect(err).ToNot(HaveOccurred())
+		// Attach the volume to VM 1.
+		client.datavolumes = make(map[string]*cdiv1.DataVolume)
+		client.datavolumes[getKey(testInfraNamespace, testVolumeName)] = dv
+		_, err = controller.ControllerPublishVolume(context.TODO(), getPublishVolumeRequest())
+		Expect(err).ToNot(HaveOccurred())
+
+		// Attempt to attach the volume to VM 2; the RWO volume is still attached to VM 1.
+		client.ListVirtualMachineWithStatus = true
+		_, err = controller.ControllerPublishVolume(context.TODO(), genPublishVolumeRequest(
+			testVolumeName,
+			getKey(testInfraNamespace, testVMName2),
+			&csi.VolumeCapability{
+				AccessMode: &csi.VolumeCapability_AccessMode{
+					Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
+				},
+			},
+		))
+		Expect(err).To(HaveOccurred())
+	})
+
+	It("should publish an RWX volume that is not yet released by another VMI", func() {
+		// Create the DataVolume we will use.
+		dv, err := client.CreateDataVolume(context.TODO(), controller.infraClusterNamespace, &cdiv1.DataVolume{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:   testVolumeName,
+				Labels: testInfraLabels,
+			},
+			Spec: cdiv1.DataVolumeSpec{
+				Storage: &cdiv1.StorageSpec{
+					StorageClassName: &testInfraStorageClassName,
+					Resources: corev1.ResourceRequirements{
+						Requests: corev1.ResourceList{
+							corev1.ResourceStorage: resource.MustParse("3Gi"),
+						},
+					},
+				},
+			},
+		})
+		Expect(err).ToNot(HaveOccurred())
+		// Attach the volume to VM 1.
+		client.datavolumes = make(map[string]*cdiv1.DataVolume)
+		client.datavolumes[getKey(testInfraNamespace, testVolumeName)] = dv
+		_, err = controller.ControllerPublishVolume(context.TODO(), getPublishVolumeRequest())
+		Expect(err).ToNot(HaveOccurred())
+
+		// Attempt to attach the volume to VM 2; RWX volumes may be shared across VMs.
+		client.ListVirtualMachineWithStatus = true
+		_, err = controller.ControllerPublishVolume(context.TODO(), genPublishVolumeRequest(
+			testVolumeName,
+			getKey(testInfraNamespace, testVMName2),
+			&csi.VolumeCapability{
+				AccessMode: &csi.VolumeCapability_AccessMode{
+					Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER,
+				},
+			},
+		))
+		Expect(err).ToNot(HaveOccurred())
+	})
+
+	It("should not publish a volume with an unknown access mode", func() {
+		dv, err := client.CreateDataVolume(context.TODO(), controller.infraClusterNamespace, &cdiv1.DataVolume{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:   testVolumeName,
+				Labels: testInfraLabels,
+			},
+			Spec: cdiv1.DataVolumeSpec{
+				Storage: &cdiv1.StorageSpec{
+					StorageClassName: &testInfraStorageClassName,
+					Resources: corev1.ResourceRequirements{
+						Requests: corev1.ResourceList{
+							corev1.ResourceStorage: resource.MustParse("3Gi"),
+						},
+					},
+				},
+			},
+		})
+		Expect(err).ToNot(HaveOccurred())
+		client.datavolumes = make(map[string]*cdiv1.DataVolume)
+		client.datavolumes[getKey(testInfraNamespace, testVolumeName)] = dv
+		_, err = controller.ControllerPublishVolume(context.TODO(), genPublishVolumeRequest(
+			testVolumeName,
+			getKey(testInfraNamespace, testVMName),
+			&csi.VolumeCapability{},
+		))
+		Expect(err).To(HaveOccurred())
+	})
 })
 
 var _ = Describe("Snapshots", func() {
@@ -732,6 +838,7 @@
 	testInfraNamespace = "tenant-cluster-2"
 	testNodeID         = getKey(testInfraNamespace, testVMName)
 	testVMName         = "test-vm"
+	testVMName2        = "test-vm2"
 	testDataVolumeUID  = "2d0111d5-494f-4731-8f67-122b27d3c366"
 	testBusType        *kubevirtv1.DiskBus = nil // nil==do not pass bus type
 	testInfraLabels    = map[string]string{"infra-label-name": "infra-label-value"}
@@ -796,18 +903,26 @@ func getDeleteVolumeRequest() *csi.DeleteVolumeRequest {
 	return &csi.DeleteVolumeRequest{VolumeId: testVolumeName}
 }
 
-func getPublishVolumeRequest() *csi.ControllerPublishVolumeRequest {
+func genPublishVolumeRequest(volumeName, nodeID string, capability *csi.VolumeCapability) *csi.ControllerPublishVolumeRequest {
 	return &csi.ControllerPublishVolumeRequest{
-		VolumeId: testVolumeName,
-		NodeId:   testNodeID,
+		VolumeId: volumeName,
+		NodeId:   nodeID,
 		VolumeContext: map[string]string{
 			busParameter:    string(getBusType()),
 			serialParameter: testDataVolumeUID,
 		},
-		VolumeCapability: &csi.VolumeCapability{},
+		VolumeCapability: capability,
 	}
 }
 
+func getPublishVolumeRequest() *csi.ControllerPublishVolumeRequest {
+	return genPublishVolumeRequest(testVolumeName, testNodeID, &csi.VolumeCapability{
+		AccessMode: &csi.VolumeCapability_AccessMode{
+			Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
+		},
+	})
+}
+
 func getUnpublishVolumeRequest() *csi.ControllerUnpublishVolumeRequest {
 	return &csi.ControllerUnpublishVolumeRequest{
 		VolumeId: testVolumeName,
@@ -816,23 +931,24 @@ func getUnpublishVolumeRequest() *csi.ControllerUnpublishVolumeRequest {
 }
 
 type ControllerClientMock struct {
-	FailListVirtualMachines bool
-	FailDeleteDataVolume    bool
-	FailCreateDataVolume    bool
-	FailGetDataVolume       bool
-	FailAddVolumeToVM       bool
-	FailRemoveVolumeFromVM  bool
-	FailGetSnapshot         bool
-	FailCreateSnapshot      bool
-	FailDeleteSnapshot      bool
-	FailListSnapshots       bool
-	ShouldReturnVMNotFound  bool
-	ExpansionOccured        bool
-	ExpansionVerified       bool
-	virtualMachineStatus    kubevirtv1.VirtualMachineInstanceStatus
-	vmVolumes               []kubevirtv1.Volume
-	snapshots               map[string]*snapshotv1.VolumeSnapshot
-	datavolumes             map[string]*cdiv1.DataVolume
+	FailListVirtualMachines      bool
+	ListVirtualMachineWithStatus bool
+	FailDeleteDataVolume         bool
+	FailCreateDataVolume         bool
+	FailGetDataVolume            bool
+	FailAddVolumeToVM            bool
+	FailRemoveVolumeFromVM       bool
+	FailGetSnapshot              bool
+	FailCreateSnapshot           bool
+	FailDeleteSnapshot           bool
+	FailListSnapshots            bool
+	ShouldReturnVMNotFound       bool
+	ExpansionOccured             bool
+	ExpansionVerified            bool
+	virtualMachineStatus         kubevirtv1.VirtualMachineInstanceStatus
+	vmVolumes                    []kubevirtv1.Volume
+	snapshots                    map[string]*snapshotv1.VolumeSnapshot
+	datavolumes                  map[string]*cdiv1.DataVolume
 }
 
 func (c *ControllerClientMock) Ping(ctx context.Context) error {
@@ -852,6 +968,24 @@ func (c *ControllerClientMock) ListVirtualMachines(_ context.Context, namespace
 		return nil, errors.New("ListVirtualMachines failed")
 	}
 
+	if c.ListVirtualMachineWithStatus {
+		return []kubevirtv1.VirtualMachineInstance{
+			{
+				ObjectMeta: metav1.ObjectMeta{
+					Name:      testVMName,
+					Namespace: namespace,
+				},
+				Status: kubevirtv1.VirtualMachineInstanceStatus{
+					VolumeStatus: []kubevirtv1.VolumeStatus{
+						{
+							Name: testVolumeName,
+						},
+					},
+				},
+			},
+		}, nil
+	}
+
 	return []kubevirtv1.VirtualMachineInstance{
 		{
 			ObjectMeta: metav1.ObjectMeta{
@@ -966,7 +1100,6 @@ func (c *ControllerClientMock) AddVolumeToVM(_ context.Context, namespace string
 	}
 
 	// Test input
-	Expect(testVMName).To(Equal(vmName))
 	Expect(testVolumeName).To(Equal(addVolumeOptions.Name))
 	Expect(testVolumeName).To(Equal(addVolumeOptions.VolumeSource.DataVolume.Name))
 	Expect(getBusType()).To(Equal(addVolumeOptions.Disk.DiskDevice.Disk.Bus))
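
Reviewer note on the semantics this diff pins down: the old getAccessMode treated every mode except MULTI_NODE_MULTI_WRITER as non-RWX, so a zero-value (UNKNOWN) access mode silently validated as an RWO volume. hasRWXCapabilityAccessMode makes unknown modes fail closed instead. Below is a minimal, self-contained sketch of that classification; the helper body mirrors the hunk above, while the main function and its sample inputs are illustrative only and not part of the driver.

package main

import (
	"fmt"

	"github.com/container-storage-interface/spec/lib/go/csi"
)

// hasRWXCapabilityAccessMode mirrors the helper in the diff above: single-node
// modes classify as RWO (false), multi-node modes as RWX (true), and anything
// else is rejected rather than silently treated as RWO.
func hasRWXCapabilityAccessMode(am *csi.VolumeCapability_AccessMode) (bool, error) {
	switch am.GetMode() {
	case
		csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
		csi.VolumeCapability_AccessMode_SINGLE_NODE_READER_ONLY,
		csi.VolumeCapability_AccessMode_SINGLE_NODE_SINGLE_WRITER,
		csi.VolumeCapability_AccessMode_SINGLE_NODE_MULTI_WRITER:
		return false, nil
	case
		csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY,
		csi.VolumeCapability_AccessMode_MULTI_NODE_SINGLE_WRITER,
		csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER:
		return true, nil
	}
	return false, fmt.Errorf("unknown volume capability access mode: %v", am.GetMode())
}

func main() {
	// Illustrative sample: one RWO mode, one RWX mode, and the zero-value
	// UNKNOWN mode that the pre-PR code would have accepted as RWO.
	modes := []csi.VolumeCapability_AccessMode_Mode{
		csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
		csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER,
		csi.VolumeCapability_AccessMode_UNKNOWN,
	}
	for _, m := range modes {
		isRWX, err := hasRWXCapabilityAccessMode(&csi.VolumeCapability_AccessMode{Mode: m})
		fmt.Printf("%-26v isRWX=%-5v err=%v\n", m, isRWX, err)
	}
}

Note that when a request carries several capabilities, the sticky isRWX = isRWX || rwx in the edited getAccessMode loop keeps a later RWO capability from masking an earlier RWX one; the sketch above only exercises the per-mode classification.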