diff --git a/driver/controller.go b/driver/controller.go index 35997eada..e96c69f5b 100644 --- a/driver/controller.go +++ b/driver/controller.go @@ -23,6 +23,7 @@ import ( "net/http" "strconv" "strings" + "sync" "time" "github.com/container-storage-interface/spec/lib/go/csi" @@ -78,9 +79,30 @@ var ( } ) +type Controller struct { + // publishInfoVolumeName is used to pass the volume name from + // `ControllerPublishVolume` to `NodeStageVolume or `NodePublishVolume` + publishInfoVolumeName string + region string + doTag string + defaultVolumesPageSize uint + + storage godo.StorageService + storageActions godo.StorageActionsService + droplets godo.DropletsService + snapshots godo.SnapshotsService + account godo.AccountService + tags godo.TagsService + + healthChecker *HealthChecker + log *logrus.Entry + + readyMu sync.Mutex +} + // CreateVolume creates a new volume from the given request. The function is // idempotent. -func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { +func (d *Controller) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { if req.Name == "" { return nil, status.Error(codes.InvalidArgument, "CreateVolume Name must be provided") } @@ -230,7 +252,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } // DeleteVolume deletes the given volume. The function is idempotent. -func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { +func (d *Controller) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { if req.VolumeId == "" { return nil, status.Error(codes.InvalidArgument, "DeleteVolume Volume ID must be provided") } @@ -259,7 +281,7 @@ func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) } // ControllerPublishVolume attaches the given volume to the node -func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { +func (d *Controller) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { if req.VolumeId == "" { return nil, status.Error(codes.InvalidArgument, "ControllerPublishVolume Volume ID must be provided") } @@ -389,7 +411,7 @@ func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.Controlle } // ControllerUnpublishVolume deattaches the given volume from the node -func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { +func (d *Controller) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { if req.VolumeId == "" { return nil, status.Error(codes.InvalidArgument, "ControllerUnpublishVolume Volume ID must be provided") } @@ -475,7 +497,7 @@ func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.Control // ValidateVolumeCapabilities checks whether the volume capabilities requested // are supported. -func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { +func (d *Controller) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { if req.VolumeId == "" { return nil, status.Error(codes.InvalidArgument, "ValidateVolumeCapabilities Volume ID must be provided") } @@ -517,7 +539,7 @@ func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.Valida } // ListVolumes returns a list of all requested volumes -func (d *Driver) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) { +func (d *Controller) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) { maxEntries := req.MaxEntries if maxEntries == 0 && d.defaultVolumesPageSize > 0 { maxEntries = int32(d.defaultVolumesPageSize) @@ -596,7 +618,7 @@ func (d *Driver) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) ( } // GetCapacity returns the capacity of the storage pool -func (d *Driver) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) { +func (d *Controller) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) { // TODO(arslan): check if we can provide this information somehow d.log.WithFields(logrus.Fields{ "params": req.Parameters, @@ -606,7 +628,7 @@ func (d *Driver) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) ( } // ControllerGetCapabilities returns the capabilities of the controller service. -func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) { +func (d *Controller) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) { newCap := func(cap csi.ControllerServiceCapability_RPC_Type) *csi.ControllerServiceCapability { return &csi.ControllerServiceCapability{ Type: &csi.ControllerServiceCapability_Rpc{ @@ -643,7 +665,7 @@ func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.Control // CreateSnapshot will be called by the CO to create a new snapshot from a // source volume on behalf of a user. -func (d *Driver) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) { +func (d *Controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) { if req.GetName() == "" { return nil, status.Error(codes.InvalidArgument, "CreateSnapshot Name must be provided") } @@ -739,7 +761,7 @@ func (d *Driver) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequ } // DeleteSnapshot will be called by the CO to delete a snapshot. -func (d *Driver) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) { +func (d *Controller) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) { log := d.log.WithFields(logrus.Fields{ "req_snapshot_id": req.GetSnapshotId(), "method": "delete_snapshot", @@ -772,7 +794,7 @@ func (d *Driver) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequ // system within the given parameters regardless of how they were created. // ListSnapshots shold not list a snapshot that is being created but has not // been cut successfully yet. -func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { +func (d *Controller) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { listResp := &csi.ListSnapshotsResponse{} log := d.log.WithFields(logrus.Fields{ "snapshot_id": req.SnapshotId, @@ -862,7 +884,7 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques } // ControllerExpandVolume is called from the resizer to increase the volume size. -func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { +func (d *Controller) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { volID := req.GetVolumeId() if len(volID) == 0 { @@ -928,7 +950,7 @@ func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.Controller // The call is used for the CSI health check feature // (https://github.com/kubernetes/enhancements/pull/1077) which we do not // support yet. -func (d *Driver) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) { +func (d *Controller) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) { return nil, status.Error(codes.Unimplemented, "") } @@ -936,7 +958,7 @@ func (d *Driver) ControllerGetVolume(ctx context.Context, req *csi.ControllerGet // range. If the capacity range is not satisfied it returns the default volume // size. If the capacity range is above supported sizes, it returns an // error. If the capacity range is below supported size, it returns the minimum supported size -func (d *Driver) extractStorage(capRange *csi.CapacityRange) (int64, error) { +func (d *Controller) extractStorage(capRange *csi.CapacityRange) (int64, error) { if capRange == nil { return defaultVolumeSizeInBytes, nil } @@ -1016,7 +1038,7 @@ func formatBytes(inputBytes int64) string { } // waitAction waits until the given action for the volume has completed. -func (d *Driver) waitAction(ctx context.Context, log *logrus.Entry, volumeID string, actionID int) error { +func (d *Controller) waitAction(ctx context.Context, log *logrus.Entry, volumeID string, actionID int) error { err := wait.PollUntil(1*time.Second, func() (done bool, err error) { action, _, err := d.storageActions.Get(ctx, volumeID, actionID) if err != nil { @@ -1057,7 +1079,7 @@ type limitDetails struct { } // checkLimit checks whether the user hit their account volume limit. -func (d *Driver) checkLimit(ctx context.Context) (*limitDetails, error) { +func (d *Controller) checkLimit(ctx context.Context) (*limitDetails, error) { // only one provisioner runs, we can make sure to prevent burst creation d.readyMu.Lock() defer d.readyMu.Unlock() @@ -1144,7 +1166,7 @@ func validateCapabilities(caps []*csi.VolumeCapability) []string { return violations.List() } -func (d *Driver) tagVolume(parentCtx context.Context, vol *godo.Volume) error { +func (d *Controller) tagVolume(parentCtx context.Context, vol *godo.Volume) error { for _, tag := range vol.Tags { if tag == d.doTag { return nil diff --git a/driver/driver.go b/driver/driver.go index 904683fc9..6b9ed7c34 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -57,37 +57,22 @@ var ( // csi.NodeServer // type Driver struct { - name string - // publishInfoVolumeName is used to pass the volume name from - // `ControllerPublishVolume` to `NodeStageVolume or `NodePublishVolume` - publishInfoVolumeName string - - endpoint string - debugAddr string - hostID func() string - region string - doTag string - isController bool - defaultVolumesPageSize uint + name string + endpoint string + debugAddr string + isController bool srv *grpc.Server httpSrv *http.Server log *logrus.Entry - mounter Mounter - - storage godo.StorageService - storageActions godo.StorageActionsService - droplets godo.DropletsService - snapshots godo.SnapshotsService - account godo.AccountService - tags godo.TagsService - - healthChecker *HealthChecker // ready defines whether the driver is ready to function. This value will // be used by the `Identity` service via the `Probe()` method. readyMu sync.Mutex // protects ready ready bool + + csi.NodeServer + csi.ControllerServer } // NewDriverParams defines the parameters that can be passed to NewDriver. @@ -131,52 +116,75 @@ func NewDriver(p NewDriverParams) (*Driver, error) { } hostID := strconv.Itoa(hostIDInt) - var opts []godo.ClientOpt - opts = append(opts, godo.SetBaseURL(p.URL)) - - if version == "" { - version = "dev" - } - opts = append(opts, godo.SetUserAgent("csi-digitalocean/"+version)) - - doClient, err := godo.New(oauthClient, opts...) - if err != nil { - return nil, fmt.Errorf("couldn't initialize DigitalOcean client: %s", err) - } - - healthChecker := NewHealthChecker(&doHealthChecker{account: doClient.Account}) - log := logrus.New().WithFields(logrus.Fields{ "region": region, "host_id": hostID, "version": version, }) - return &Driver{ - name: driverName, - publishInfoVolumeName: driverName + "/volume-name", - - doTag: p.DOTag, - endpoint: p.Endpoint, - debugAddr: p.DebugAddr, - defaultVolumesPageSize: p.DefaultVolumesPageSize, - - hostID: func() string { return hostID }, - region: region, - mounter: newMounter(log), - log: log, - // we're assuming only the controller has a non-empty token. - isController: p.Token != "", - - storage: doClient.Storage, - storageActions: doClient.StorageActions, - droplets: doClient.Droplets, - snapshots: doClient.Snapshots, - account: doClient.Account, - tags: doClient.Tags, - - healthChecker: healthChecker, - }, nil + var driver *Driver + // we're assuming only the controller has a non-empty token. + if p.Token != "" { + var opts []godo.ClientOpt + opts = append(opts, godo.SetBaseURL(p.URL)) + + if version == "" { + version = "dev" + } + opts = append(opts, godo.SetUserAgent("csi-digitalocean/"+version)) + + doClient, err := godo.New(oauthClient, opts...) + if err != nil { + return nil, fmt.Errorf("couldn't initialize DigitalOcean client: %s", err) + } + + healthChecker := NewHealthChecker(&doHealthChecker{account: doClient.Account}) + + controller := &Controller{ + publishInfoVolumeName: driverName + "/volume-name", + region: region, + doTag: p.DOTag, + defaultVolumesPageSize: p.DefaultVolumesPageSize, + + storage: doClient.Storage, + storageActions: doClient.StorageActions, + droplets: doClient.Droplets, + snapshots: doClient.Snapshots, + account: doClient.Account, + tags: doClient.Tags, + + healthChecker: healthChecker, + log: log, + } + + driver = &Driver{ + name: driverName, + endpoint: p.Endpoint, + debugAddr: p.DebugAddr, + isController: p.Token != "", + + ControllerServer: controller, + } + } else { + node := &Node{ + publishInfoVolumeName: driverName + "/volume-name", + region: region, + hostID: func() string { return hostID }, + log: log, + mounter: newMounter(log), + } + + driver = &Driver{ + name: driverName, + endpoint: p.Endpoint, + debugAddr: p.DebugAddr, + isController: p.Token != "", + + NodeServer: node, + } + } + + return driver, nil } // Run starts the CSI plugin by communication over the given endpoint @@ -217,11 +225,15 @@ func (d *Driver) Run(ctx context.Context) error { return resp, err } + d.srv = grpc.NewServer(grpc.UnaryInterceptor(errHandler)) + csi.RegisterIdentityServer(d.srv, d) + // warn the user, it'll not propagate to the user but at least we see if // something is wrong in the logs. Only check if the driver is running with // a token (i.e: controller) if d.isController { - details, err := d.checkLimit(context.Background()) + controller := d.ControllerServer.(*Controller) + details, err := controller.checkLimit(context.Background()) if err != nil { return fmt.Errorf("failed to check volumes limits on startup: %s", err) } @@ -235,7 +247,7 @@ func (d *Driver) Run(ctx context.Context) error { if d.debugAddr != "" { mux := http.NewServeMux() mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - err := d.healthChecker.Check(r.Context()) + err := controller.healthChecker.Check(r.Context()) if err != nil { d.log.WithError(err).Error("executing health check") http.Error(w, err.Error(), http.StatusInternalServerError) @@ -248,13 +260,11 @@ func (d *Driver) Run(ctx context.Context) error { Handler: mux, } } + csi.RegisterControllerServer(d.srv, controller) + } else { + csi.RegisterNodeServer(d.srv, d.NodeServer.(*Node)) } - d.srv = grpc.NewServer(grpc.UnaryInterceptor(errHandler)) - csi.RegisterIdentityServer(d.srv, d) - csi.RegisterControllerServer(d.srv, d) - csi.RegisterNodeServer(d.srv, d) - d.ready = true // we're now ready to go! d.log.WithFields(logrus.Fields{ "grpc_addr": grpcAddr, diff --git a/driver/node.go b/driver/node.go index 758b8a2e0..640977e5a 100644 --- a/driver/node.go +++ b/driver/node.go @@ -61,11 +61,22 @@ var ( } ) +type Node struct { + // publishInfoVolumeName is used to pass the volume name from + // `ControllerPublishVolume` to `NodeStageVolume or `NodePublishVolume` + publishInfoVolumeName string + region string + hostID func() string + + log *logrus.Entry + mounter Mounter +} + // NodeStageVolume mounts the volume to a staging path on the node. This is // called by the CO before NodePublishVolume and is used to temporary mount the // volume to a staging path. Once mounted, NodePublishVolume will make sure to // mount it to the appropriate path -func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { +func (d *Node) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { if req.VolumeId == "" { return nil, status.Error(codes.InvalidArgument, "NodeStageVolume Volume ID must be provided") } @@ -165,7 +176,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe } // NodeUnstageVolume unstages the volume from the staging path -func (d *Driver) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { +func (d *Node) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { if req.VolumeId == "" { return nil, status.Error(codes.InvalidArgument, "NodeUnstageVolume Volume ID must be provided") } @@ -201,7 +212,7 @@ func (d *Driver) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolu } // NodePublishVolume mounts the volume mounted to the staging path to the target path -func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { +func (d *Node) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { if req.VolumeId == "" { return nil, status.Error(codes.InvalidArgument, "NodePublishVolume Volume ID must be provided") } @@ -250,7 +261,7 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu } // NodeUnpublishVolume unmounts the volume from the target path -func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { +func (d *Node) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { if req.VolumeId == "" { return nil, status.Error(codes.InvalidArgument, "NodeUnpublishVolume Volume ID must be provided") } @@ -276,7 +287,7 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish } // NodeGetCapabilities returns the supported capabilities of the node server -func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { +func (d *Node) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { nscaps := []*csi.NodeServiceCapability{ &csi.NodeServiceCapability{ Type: &csi.NodeServiceCapability_Rpc{ @@ -314,7 +325,7 @@ func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabi // should eventually return the droplet ID if possible. This is used so the CO // knows where to place the workload. The result of this function will be used // by the CO in ControllerPublishVolume. -func (d *Driver) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { +func (d *Node) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { d.log.WithField("method", "node_get_info").Info("node get info called") return &csi.NodeGetInfoResponse{ NodeId: d.hostID(), @@ -331,7 +342,7 @@ func (d *Driver) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) ( // NodeGetVolumeStats returns the volume capacity statistics available for the // the given volume. -func (d *Driver) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) { +func (d *Node) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) { if req.VolumeId == "" { return nil, status.Error(codes.InvalidArgument, "NodeGetVolumeStats Volume ID must be provided") } @@ -412,7 +423,7 @@ func (d *Driver) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeS }, nil } -func (d *Driver) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { +func (d *Node) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { volumeID := req.GetVolumeId() if len(volumeID) == 0 { return nil, status.Error(codes.InvalidArgument, "NodeExpandVolume volume ID not provided") @@ -470,7 +481,7 @@ func (d *Driver) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolume return &csi.NodeExpandVolumeResponse{}, nil } -func (d *Driver) nodePublishVolumeForFileSystem(req *csi.NodePublishVolumeRequest, mountOptions []string, log *logrus.Entry) error { +func (d *Node) nodePublishVolumeForFileSystem(req *csi.NodePublishVolumeRequest, mountOptions []string, log *logrus.Entry) error { source := req.StagingTargetPath target := req.TargetPath @@ -508,7 +519,7 @@ func (d *Driver) nodePublishVolumeForFileSystem(req *csi.NodePublishVolumeReques return nil } -func (d *Driver) nodePublishVolumeForBlock(req *csi.NodePublishVolumeRequest, mountOptions []string, log *logrus.Entry) error { +func (d *Node) nodePublishVolumeForBlock(req *csi.NodePublishVolumeRequest, mountOptions []string, log *logrus.Entry) error { volumeName, ok := req.GetPublishContext()[d.publishInfoVolumeName] if !ok { return status.Error(codes.InvalidArgument, fmt.Sprintf("Could not find the volume name from the publish context %q", d.publishInfoVolumeName))