diff --git a/cmd/initializers.go b/cmd/initializers.go index 7d794df..4f5ea4e 100644 --- a/cmd/initializers.go +++ b/cmd/initializers.go @@ -16,6 +16,7 @@ import ( "waverless/pkg/interfaces" "waverless/pkg/logger" "waverless/pkg/provider" + "waverless/pkg/provider/gmi" "waverless/pkg/provider/k8s" "waverless/pkg/provider/novita" "waverless/pkg/resource" @@ -204,9 +205,17 @@ func (app *Application) initServices() error { } } + // Get GMI deployment provider for status sync + var gmiDeployProvider *gmi.GMIDeploymentProvider + if app.config.GMI.Enabled { + if gmiProv, ok := app.deploymentProvider.(*gmi.GMIDeploymentProvider); ok { + gmiDeployProvider = gmiProv + } + } + // Setup Lifecycle Manager for unified watcher management // This replaces all the individual setup*Watcher methods - if err := app.setupLifecycleManager(k8sDeployProvider, novitaDeployProvider); err != nil { + if err := app.setupLifecycleManager(k8sDeployProvider, novitaDeployProvider, gmiDeployProvider); err != nil { logger.WarnCtx(app.ctx, "Failed to setup lifecycle manager: %v (non-critical, continuing)", err) } @@ -237,7 +246,7 @@ func (app *Application) initServices() error { // setupLifecycleManager sets up the unified lifecycle manager for all providers // This replaces the individual setup*Watcher methods with a centralized approach -func (app *Application) setupLifecycleManager(k8sProvider *k8s.K8sDeploymentProvider, novitaProvider *novita.NovitaDeploymentProvider) error { +func (app *Application) setupLifecycleManager(k8sProvider *k8s.K8sDeploymentProvider, novitaProvider *novita.NovitaDeploymentProvider, gmiProvider *gmi.GMIDeploymentProvider) error { // Create lifecycle manager app.lifecycleManager = lifecycle.NewManager( app.ctx, @@ -276,6 +285,14 @@ func (app *Application) setupLifecycleManager(k8sProvider *k8s.K8sDeploymentProv } } + // Register GMI provider if enabled + if gmiProvider != nil { + logger.InfoCtx(app.ctx, "Registering GMI provider with lifecycle manager...") + if err := app.lifecycleManager.RegisterGMIProvider(gmiProvider); err != nil { + logger.WarnCtx(app.ctx, "Failed to register GMI provider: %v", err) + } + } + logger.InfoCtx(app.ctx, "Lifecycle manager setup completed with providers: %v", app.lifecycleManager.GetRegisteredProviders()) return nil } @@ -288,8 +305,8 @@ func (app *Application) initHandlers() error { app.statisticsHandler = handler.NewStatisticsHandler(app.statisticsService, app.workerService) app.monitoringHandler = handler.NewMonitoringHandler(app.monitoringService) - // Initialize Endpoint Handler (for K8s or Novita) - if app.config.K8s.Enabled || app.config.Novita.Enabled { + // Initialize Endpoint Handler (for K8s, Novita or GMI) + if app.config.K8s.Enabled || app.config.Novita.Enabled || app.config.GMI.Enabled { if app.deploymentProvider == nil { logger.ErrorCtx(app.ctx, "Deployment provider is enabled but provider is nil") } else { @@ -300,6 +317,9 @@ func (app *Application) initHandlers() error { if app.config.Novita.Enabled { logger.InfoCtx(app.ctx, "Endpoint handler initialized for Novita") } + if app.config.GMI.Enabled { + logger.InfoCtx(app.ctx, "Endpoint handler initialized for GMI") + } } } @@ -347,8 +367,7 @@ func (app *Application) initAutoScaler() error { logger.InfoCtx(app.ctx, "Spec service injected into SpecManager - specs will be read from database first") } } else { - logger.WarnCtx(app.ctx, "AutoScaler requires K8s deployment provider, skipping initialization") - return nil + logger.WarnCtx(app.ctx, "K8s deployment provider not available, autoscaler will 
run without resource limits") } autoscalerConfig := &autoscaler.Config{ diff --git a/config/config.example.yaml b/config/config.example.yaml index 9bc864b..5811325 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -95,3 +95,12 @@ novita: base_url: "https://api.novita.ai" # Novita API base URL config_dir: "./config" # Configuration directory (contains specs.yaml and templates/) poll_interval: 10 # Poll interval for status updates (seconds, default: 10) + +# GMI Serverless Configuration +gmi: + enabled: false # Enable GMI serverless provider + api_key: "xxx" # Your GMI API key (Bearer token) + base_url: "http://34.105.52.179:31818" # GMI API base URL + callback_url: "http://waverless-svc" # External URL of waverless server for worker callbacks + config_dir: "./config" # Configuration directory (contains specs.yaml and templates/) + poll_interval: 10 # Poll interval for status updates (seconds, default: 10) \ No newline at end of file diff --git a/internal/service/lifecycle/manager.go b/internal/service/lifecycle/manager.go index 80552e5..dfaa087 100644 --- a/internal/service/lifecycle/manager.go +++ b/internal/service/lifecycle/manager.go @@ -11,6 +11,7 @@ import ( "waverless/pkg/interfaces" "waverless/pkg/logger" "waverless/pkg/provider" + "waverless/pkg/provider/gmi" "waverless/pkg/provider/k8s" "waverless/pkg/provider/novita" "waverless/pkg/status" @@ -151,6 +152,38 @@ func (m *Manager) RegisterNovitaProvider(provider *novita.NovitaDeploymentProvid return nil } +// RegisterGMIProvider registers GMI Provider and starts its watchers +func (m *Manager) RegisterGMIProvider(provider *gmi.GMIDeploymentProvider) error { + if provider == nil { + return fmt.Errorf("gmi provider is nil") + } + + m.mu.Lock() + defer m.mu.Unlock() + + name := "gmi" + if _, exists := m.providers[name]; exists { + logger.WarnCtx(m.ctx, "Provider %s already registered, skipping", name) + return nil + } + + lifecycle := provider.GetLifecycle() + if lifecycle == nil { + return fmt.Errorf("gmi provider lifecycle is nil") + } + + callbacks := m.createGMICallbacks() + + if err := lifecycle.RegisterWatchers(callbacks); err != nil { + return fmt.Errorf("failed to register watchers for gmi provider: %w", err) + } + + m.providers[name] = lifecycle + logger.InfoCtx(m.ctx, "GMI Provider registered successfully") + + return nil +} + // createK8sCallbacks creates K8s callback functions func (m *Manager) createK8sCallbacks(k8sProvider *k8s.K8sDeploymentProvider) *k8s.K8sLifecycleCallbacks { return &k8s.K8sLifecycleCallbacks{ @@ -262,6 +295,41 @@ func (m *Manager) createNovitaCallbacks() *novita.NovitaLifecycleCallbacks { } } +// createGMICallbacks creates GMI callback functions +func (m *Manager) createGMICallbacks() *gmi.GMILifecycleCallbacks { + return &gmi.GMILifecycleCallbacks{ + OnWorkerStatusChange: func(workerID, endpoint string, podInfo *interfaces.PodInfo) { + m.callbackHandler.HandleWorkerStatusChange(&provider.WorkerStatusEvent{ + WorkerID: workerID, + Endpoint: endpoint, + PodInfo: podInfo, + }) + }, + OnWorkerDelete: func(workerID, endpoint string) { + m.callbackHandler.HandleWorkerDelete(&provider.WorkerDeleteEvent{ + WorkerID: workerID, + Endpoint: endpoint, + }) + }, + OnWorkerFailure: func(workerID, endpoint string, failureInfo *interfaces.WorkerFailureInfo) { + m.callbackHandler.HandleWorkerFailure(&provider.WorkerFailureEvent{ + WorkerID: workerID, + Endpoint: endpoint, + FailureInfo: failureInfo, + }) + }, + OnEndpointStatusChange: func(endpoint, status string, desiredReplicas, 
readyReplicas, availableReplicas int) { + m.callbackHandler.HandleEndpointStatusChange(&provider.EndpointStatusEvent{ + Endpoint: endpoint, + Status: status, + DesiredReplicas: desiredReplicas, + ReadyReplicas: readyReplicas, + AvailableReplicas: availableReplicas, + }) + }, + } +} + // UnregisterProvider unregisters a Provider and stops its watchers func (m *Manager) UnregisterProvider(name string) error { m.mu.Lock() diff --git a/pkg/autoscaler/resource_calculator.go b/pkg/autoscaler/resource_calculator.go index 13292dd..0690068 100644 --- a/pkg/autoscaler/resource_calculator.go +++ b/pkg/autoscaler/resource_calculator.go @@ -42,7 +42,12 @@ func (c *ResourceCalculator) CalculateEndpointResource(ctx context.Context, endp specName = meta.SpecName } - // Get spec + // Get spec - if specManager is nil (e.g., GMI/Novita provider), skip resource calculation + if c.specManager == nil { + logger.DebugCtx(ctx, "specManager is nil, skipping resource calculation for endpoint %s (non-K8s provider)", endpoint.Name) + return &Resources{}, nil + } + spec, err := c.specManager.GetSpec(specName) if err != nil { return nil, fmt.Errorf("failed to get spec %s: %w", specName, err) diff --git a/pkg/config/config.go b/pkg/config/config.go index c805d84..d734a46 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -25,6 +25,7 @@ type Config struct { Notification NotificationConfig `yaml:"notification"` // Notification configuration Providers *ProvidersConfig `yaml:"providers,omitempty"` // Providers configuration (optional) Novita NovitaConfig `yaml:"novita"` // Novita serverless configuration + GMI GMIConfig `yaml:"gmi"` // GMI serverless configuration ImageValidation ImageValidationConfig `yaml:"imageValidation"` // Image validation configuration ResourceReleaser ResourceReleaserConfig `yaml:"resourceReleaser"` // Resource releaser configuration } @@ -217,6 +218,16 @@ type NovitaConfig struct { PollInterval int `yaml:"poll_interval"` // Poll interval for status updates (seconds, default: 10) } +// GMIConfig GMI (Generic Model Interface) serverless configuration +type GMIConfig struct { + Enabled bool `yaml:"enabled"` // Whether to enable GMI provider + APIKey string `yaml:"api_key"` // GMI API key (Bearer token) + BaseURL string `yaml:"base_url"` // API base URL + CallbackURL string `yaml:"callback_url"` // External URL of waverless server for worker callbacks (e.g., http://35.189.190.93:8087) + ConfigDir string `yaml:"config_dir"` // Configuration directory (specs.yaml and templates) + PollInterval int `yaml:"poll_interval"` // Poll interval for status updates (seconds, default: 10) +} + // Init initializes configuration func Init() error { configPath := os.Getenv("CONFIG_PATH") diff --git a/pkg/interfaces/deployment.go b/pkg/interfaces/deployment.go index 6435b05..b0f9dda 100644 --- a/pkg/interfaces/deployment.go +++ b/pkg/interfaces/deployment.go @@ -206,8 +206,8 @@ type UpdateDeploymentRequest struct { VolumeMounts *[]VolumeMount `json:"volumeMounts,omitempty"` // New volume mounts (optional, use pointer to distinguish empty from unset) ShmSize *string `json:"shmSize,omitempty"` // New shared memory size (optional, use pointer to distinguish empty from unset) EnablePtrace *bool `json:"enablePtrace,omitempty"` // Enable SYS_PTRACE capability (optional, use pointer to distinguish false from unset) - Env *map[string]string `json:"env,omitempty"` // New environment variables (optional, use pointer to distinguish empty from unset) - TaskTimeout *int `json:"taskTimeout,omitempty"` // New task timeout 
(optional) + Env *map[string]string `json:"env,omitempty"` // New environment variables (optional, use pointer to distinguish empty from unset) + TaskTimeout *int `json:"taskTimeout,omitempty"` // New task timeout (optional) } // UpdateEndpointConfigRequest update Endpoint configuration request (metadata + autoscaling configuration) diff --git a/pkg/provider/factory.go b/pkg/provider/factory.go index 9e06b9d..c90e17a 100644 --- a/pkg/provider/factory.go +++ b/pkg/provider/factory.go @@ -8,6 +8,7 @@ import ( "waverless/pkg/config" "waverless/pkg/interfaces" "waverless/pkg/provider/docker" + "waverless/pkg/provider/gmi" "waverless/pkg/provider/k8s" "waverless/pkg/provider/novita" ) @@ -80,6 +81,7 @@ func init() { RegisterDeploymentProvider("kubernetes", k8s.NewK8sDeploymentProvider) RegisterDeploymentProvider("docker", docker.NewDockerDeploymentProvider) RegisterDeploymentProvider("novita", novita.NewNovitaDeploymentProvider) + RegisterDeploymentProvider("gmi", gmi.NewGMIDeploymentProvider) } func (f *ProviderFactory) CreateDeploymentProvider(providerType string) (interfaces.DeploymentProvider, error) { diff --git a/pkg/provider/gmi/client.go b/pkg/provider/gmi/client.go new file mode 100644 index 0000000..2e70cbd --- /dev/null +++ b/pkg/provider/gmi/client.go @@ -0,0 +1,74 @@ +package gmi + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "waverless/pkg/logger" +) + +// doRequest executes an HTTP request with Bearer token auth +func (p *GMIDeploymentProvider) doRequest(ctx context.Context, method, url string, body interface{}) ([]byte, error) { + var reqBody io.Reader + if body != nil { + jsonData, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + logger.Debugf("GMI API request: %s %s body=%s", method, url, string(jsonData)) + reqBody = bytes.NewBuffer(jsonData) + } + + httpReq, err := http.NewRequestWithContext(ctx, method, url, reqBody) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + // gmiless uses "Authorization: Bearer " for /api/v1 routes + httpReq.Header.Set("Authorization", "Bearer "+p.token) + if body != nil { + httpReq.Header.Set("Content-Type", "application/json") + } + httpReq.Header.Set("Accept", "application/json") + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("failed to call GMI API: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("GMI API returned error status %d: %s", resp.StatusCode, string(respBody)) + } + + return respBody, nil +} + +// getEndpointID looks up the gmiless endpoint ID by name (with cache) +func (p *GMIDeploymentProvider) getEndpointID(ctx context.Context, endpoint string) (string, error) { + // Check cache + if id, ok := p.endpointCache.Load(endpoint); ok { + return id.(string), nil + } + + // Not cached, query list API to populate cache + _, err := p.ListApps(ctx) + if err != nil { + return "", fmt.Errorf("failed to list endpoints: %w", err) + } + + if id, ok := p.endpointCache.Load(endpoint); ok { + return id.(string), nil + } + + return "", fmt.Errorf("endpoint %s not found in GMI", endpoint) +} diff --git a/pkg/provider/gmi/lifecycle.go b/pkg/provider/gmi/lifecycle.go new file mode 100644 index 0000000..d5e389d --- /dev/null +++ b/pkg/provider/gmi/lifecycle.go @@ -0,0 +1,149 @@ 
+package gmi + +import ( + "context" + "sync" + + "waverless/pkg/interfaces" + "waverless/pkg/logger" +) + +// GMILifecycleCallbacks defines GMI lifecycle callback functions +type GMILifecycleCallbacks struct { + OnWorkerStatusChange func(workerID, endpoint string, podInfo *interfaces.PodInfo) + OnWorkerDelete func(workerID, endpoint string) + OnWorkerFailure func(workerID, endpoint string, failureInfo *interfaces.WorkerFailureInfo) + OnEndpointStatusChange func(endpoint, status string, desiredReplicas, readyReplicas, availableReplicas int) +} + +// GMIProviderLifecycle manages GMI provider lifecycle +type GMIProviderLifecycle struct { + provider *GMIDeploymentProvider + ctx context.Context + cancel context.CancelFunc + callbacks *GMILifecycleCallbacks + mu sync.Mutex + started bool +} + +// NewGMIProviderLifecycle creates a new GMI provider lifecycle manager +func NewGMIProviderLifecycle(p *GMIDeploymentProvider) *GMIProviderLifecycle { + return &GMIProviderLifecycle{ + provider: p, + } +} + +// GetProviderName returns the provider name +func (l *GMIProviderLifecycle) GetProviderName() string { + return "gmi" +} + +// RegisterWatchers registers all GMI watchers +func (l *GMIProviderLifecycle) RegisterWatchers(callbacks *GMILifecycleCallbacks) error { + l.mu.Lock() + defer l.mu.Unlock() + + if l.started { + logger.InfoCtx(l.ctx, "GMI lifecycle watchers already started") + return nil + } + + l.callbacks = callbacks + l.ctx, l.cancel = context.WithCancel(context.Background()) + + // Register worker status change watcher + if err := l.registerWorkerStatusWatcher(); err != nil { + logger.WarnCtx(l.ctx, "Failed to register GMI worker status watcher: %v", err) + } + + // Register worker delete watcher + if err := l.registerWorkerDeleteWatcher(); err != nil { + logger.WarnCtx(l.ctx, "Failed to register GMI worker delete watcher: %v", err) + } + + // Register endpoint status change watcher (via WatchReplicas) + if err := l.registerEndpointStatusWatcher(); err != nil { + logger.WarnCtx(l.ctx, "Failed to register GMI endpoint status watcher: %v", err) + } + + l.started = true + logger.InfoCtx(l.ctx, "GMI lifecycle watchers registered successfully") + return nil +} + +// StopWatchers stops all watchers +func (l *GMIProviderLifecycle) StopWatchers() error { + l.mu.Lock() + defer l.mu.Unlock() + + if !l.started { + return nil + } + + if l.cancel != nil { + l.cancel() + } + + l.provider.StopWatcher() + + l.started = false + logger.InfoCtx(context.Background(), "GMI lifecycle watchers stopped") + return nil +} + +// registerWorkerStatusWatcher registers worker status change watcher +func (l *GMIProviderLifecycle) registerWorkerStatusWatcher() error { + if l.callbacks == nil || l.callbacks.OnWorkerStatusChange == nil { + return nil + } + + // Create failure detector + failureDetector := NewGMIWorkerStatusMonitor(nil) + + return l.provider.WatchPodStatusChange(l.ctx, func(workerID, endpoint string, info *interfaces.PodInfo) { + // 1. Trigger status change callback + l.callbacks.OnWorkerStatusChange(workerID, endpoint, info) + + // 2. 
Detect failure and trigger failure callback + if l.callbacks.OnWorkerFailure != nil { + if failureInfo := failureDetector.DetectFailure(info); failureInfo != nil { + l.callbacks.OnWorkerFailure(workerID, endpoint, failureInfo) + } + } + }) +} + +// registerWorkerDeleteWatcher registers worker delete watcher +func (l *GMIProviderLifecycle) registerWorkerDeleteWatcher() error { + if l.callbacks == nil || l.callbacks.OnWorkerDelete == nil { + return nil + } + + return l.provider.WatchPodDelete(l.ctx, func(workerID, endpoint string) { + l.callbacks.OnWorkerDelete(workerID, endpoint) + }) +} + +// registerEndpointStatusWatcher registers endpoint status change watcher +func (l *GMIProviderLifecycle) registerEndpointStatusWatcher() error { + if l.callbacks == nil || l.callbacks.OnEndpointStatusChange == nil { + return nil + } + + return l.provider.WatchReplicas(l.ctx, func(event interfaces.ReplicaEvent) { + status := "Pending" + if event.AvailableReplicas == event.DesiredReplicas && event.DesiredReplicas > 0 { + status = "Running" + } else if event.DesiredReplicas == 0 { + status = "Stopped" + } + + l.callbacks.OnEndpointStatusChange( + event.Name, + status, + event.DesiredReplicas, + event.ReadyReplicas, + event.AvailableReplicas, + ) + }) +} diff --git a/pkg/provider/gmi/mapper.go b/pkg/provider/gmi/mapper.go new file mode 100644 index 0000000..a288436 --- /dev/null +++ b/pkg/provider/gmi/mapper.go @@ -0,0 +1,182 @@ +package gmi + +import ( + "strings" + + "waverless/pkg/interfaces" +) + +// ======================================== +// Spec name → GPU type mapping +// ======================================== + +// specToGPUTypeMap maps waverless spec names to GPU type IDs used by gmiless +var specToGPUTypeMap = map[string]string{ + "h100-single-hbm3": "NVIDIA-H100-80GB-HBM3", + "h100-single": "NVIDIA-H100-80GB-HBM3", + "h100-pcie-single": "NVIDIA-H100-PCIe", + "a100-single": "NVIDIA-A100-80GB-PCIe", + "5090-single": "NVIDIA-GeForce-RTX-5090", + "4090-single": "NVIDIA-GeForce-RTX-4090", + "a6000-single": "NVIDIA-RTX-A6000", + "l40-single": "NVIDIA-L40", +} + +// specNameToGPUType converts a waverless spec name to a gmiless GPU type ID. +// If no mapping exists, returns the spec name as-is (assumes it's already a GPU type). 
+func specNameToGPUType(specName string) string { + if gpuType, ok := specToGPUTypeMap[specName]; ok { + return gpuType + } + return specName +} + +// ======================================== +// Response → AppInfo conversion +// ======================================== + +// convertToAppInfo converts gmiless endpoint response to waverless AppInfo +func convertToAppInfo(resp *gmiEndpointResponse) *interfaces.AppInfo { + status := resp.Status + if status == "" { + if resp.Replicas == 0 { + status = "Stopped" + } else { + status = "Pending" + } + } + + // Count ready workers + var readyReplicas, availableReplicas int32 + for _, w := range resp.Workers { + if strings.EqualFold(w.DesiredStatus, "ONLINE") || strings.EqualFold(w.DesiredStatus, "BUSY") { + availableReplicas++ + readyReplicas++ + } + } + + image := resp.Image + if image == "" && resp.Template != nil { + image = resp.Template.ImageName + } + + return &interfaces.AppInfo{ + Name: resp.Name, + Status: status, + Replicas: int32(resp.Replicas), + ReadyReplicas: readyReplicas, + AvailableReplicas: availableReplicas, + Image: image, + Labels: map[string]string{}, + CreatedAt: resp.CreatedAt, + } +} + +// ======================================== +// Response → PodInfo / PodDetail conversion +// ======================================== + +// convertPodInfoFromGMI converts gmiPodInfo to interfaces.PodInfo +func convertPodInfoFromGMI(pod *gmiPodInfo) *interfaces.PodInfo { + return &interfaces.PodInfo{ + Name: pod.Name, + Phase: pod.Phase, + Status: pod.Status, + Reason: pod.Reason, + Message: pod.Message, + IP: pod.IP, + NodeName: pod.NodeName, + CreatedAt: pod.CreatedAt, + StartedAt: pod.StartedAt, + Labels: pod.Labels, + } +} + +// convertPodDetailFromGMI converts gmiPodInfo to interfaces.PodDetail +func convertPodDetailFromGMI(pod *gmiPodInfo) *interfaces.PodDetail { + // Convert containers + containers := make([]interfaces.ContainerInfo, len(pod.Containers)) + for i, c := range pod.Containers { + env := make([]interfaces.EnvVar, 0, len(c.Env)) + for _, e := range c.Env { + env = append(env, interfaces.EnvVar{Name: e.Name, Value: e.Value}) + } + + resources := make(map[string]interface{}) + if len(c.Resources.Requests) > 0 || len(c.Resources.Limits) > 0 { + resources["requests"] = c.Resources.Requests + resources["limits"] = c.Resources.Limits + } + + state := "Running" + if !c.Ready { + state = "Waiting" + } + + containers[i] = interfaces.ContainerInfo{ + Name: c.Name, + Image: c.Image, + State: state, + Ready: c.Ready, + Resources: resources, + Env: env, + } + } + + // Convert conditions + conditions := make([]interfaces.PodCondition, len(pod.Conditions)) + for i, c := range pod.Conditions { + conditions[i] = interfaces.PodCondition{ + Type: c.Type, + Status: c.Status, + LastTransitionTime: c.LastTransitionTime, + } + } + + // Convert volumes + volumes := make([]interfaces.VolumeInfo, len(pod.Volumes)) + for i, v := range pod.Volumes { + volumes[i] = interfaces.VolumeInfo{Name: v.Name, Type: v.Type} + } + + return &interfaces.PodDetail{ + PodInfo: &interfaces.PodInfo{ + Name: pod.Name, + Phase: pod.Phase, + Status: pod.Status, + Reason: pod.Reason, + Message: pod.Message, + IP: pod.IP, + NodeName: pod.NodeName, + CreatedAt: pod.CreatedAt, + StartedAt: pod.StartedAt, + Labels: pod.Labels, + RestartCount: int32(pod.RestartCount), + }, + Namespace: pod.Namespace, + UID: pod.UID, + Annotations: pod.Annotations, + Containers: containers, + Conditions: conditions, + Events: []interfaces.PodEvent{}, + Volumes: volumes, + } +} + +// 
convertWorkerToPodInfo converts gmiWorkerResponse to interfaces.PodInfo (for status sync) +func convertWorkerToPodInfo(worker *gmiWorkerResponse, endpoint string) *interfaces.PodInfo { + status := worker.DesiredStatus + phase := "Running" + if strings.EqualFold(status, "STARTING") { + phase = "Pending" + } + + return &interfaces.PodInfo{ + Name: worker.Name, + Phase: phase, + Status: status, + CreatedAt: worker.LastStartedAt, + StartedAt: worker.LastStartedAt, + Labels: map[string]string{"app": endpoint}, + } +} diff --git a/pkg/provider/gmi/provider.go b/pkg/provider/gmi/provider.go new file mode 100644 index 0000000..279956c --- /dev/null +++ b/pkg/provider/gmi/provider.go @@ -0,0 +1,879 @@ +package gmi + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "waverless/pkg/config" + "waverless/pkg/interfaces" + "waverless/pkg/logger" +) + +// GMIDeploymentProvider implements the DeploymentProvider interface for GMI. +// It calls the gmiless API (/api/v1/endpoints) for deployment operations. +type GMIDeploymentProvider struct { + baseURL string + token string + client *http.Client + cfg *config.Config + gmiConfig *config.GMIConfig + pollInterval time.Duration + + // Endpoint name → ID cache + endpointCache sync.Map + + // Worker state tracking for polling-based sync + workerStates sync.Map // workerID → *gmiWorkerState + watcherRunning sync.Once + watcherCtx context.Context + watcherCancel context.CancelFunc + + // Worker status change callbacks + workerStatusCallbacks map[uint64]WorkerStatusChangeCallback + workerStatusCallbacksLock sync.RWMutex + workerDeleteCallbacks map[uint64]WorkerDeleteCallback + workerDeleteCallbacksLock sync.RWMutex + nextCallbackID uint64 +} + +// NewGMIDeploymentProvider creates a new GMI deployment provider. 
+func NewGMIDeploymentProvider(cfg *config.Config) (interfaces.DeploymentProvider, error) { + if !cfg.GMI.Enabled { + return nil, fmt.Errorf("gmi provider is not enabled in config") + } + + baseURL := cfg.GMI.BaseURL + if baseURL == "" { + return nil, fmt.Errorf("gmi base_url is required") + } + baseURL = strings.TrimRight(baseURL, "/") + + token := cfg.GMI.APIKey + + pollInterval := 10 * time.Second + if cfg.GMI.PollInterval > 0 { + pollInterval = time.Duration(cfg.GMI.PollInterval) * time.Second + } + + return &GMIDeploymentProvider{ + baseURL: baseURL, + token: token, + client: &http.Client{ + Timeout: 60 * time.Second, + }, + cfg: cfg, + gmiConfig: &cfg.GMI, + pollInterval: pollInterval, + workerStatusCallbacks: make(map[uint64]WorkerStatusChangeCallback), + workerDeleteCallbacks: make(map[uint64]WorkerDeleteCallback), + }, nil +} + +// ======================================== +// CoreDeploymentProvider (required) +// ======================================== + +func (p *GMIDeploymentProvider) Deploy(ctx context.Context, req *interfaces.DeployRequest) (*interfaces.DeployResponse, error) { + logger.Infof("GMI Deploy: endpoint=%s, image=%s, replicas=%d, spec=%s, gpuCount=%d", + req.Endpoint, req.Image, req.Replicas, req.SpecName, req.GpuCount) + + computeType := "GPU" + endpointType := "QB" + + // Build template with default env vars merged + mergedEnv := p.mergeEnv(p.buildDefaultEnv(req.Endpoint), req.Env) + + // Pass registry credential as environment variables for private image pulling + if req.RegistryCredential != nil { + if req.RegistryCredential.Registry != "" { + mergedEnv["REGISTRY_SERVER"] = req.RegistryCredential.Registry + } + if req.RegistryCredential.Username != "" { + mergedEnv["REGISTRY_USERNAME"] = req.RegistryCredential.Username + } + if req.RegistryCredential.Password != "" { + mergedEnv["REGISTRY_PASSWORD"] = req.RegistryCredential.Password + } + } + + template := &gmiTemplateData{ + ImageName: &req.Image, + Env: mergedEnv, + } + if req.ShmSize != "" { + template.ShmSize = &req.ShmSize + } + + // Build request matching gmiless EndpointRequest + defaultRegions := []string{"us-west1"} + gmiReq := &gmiEndpointRequest{ + Name: &req.Endpoint, + Replicas: &req.Replicas, + GpuCount: &req.GpuCount, + ComputeType: &computeType, + Type: &endpointType, + Template: template, + WorkersMin: &req.Replicas, + WorkersMax: &req.Replicas, + DataCenterIds: &defaultRegions, + } + + // Map spec name to GPU type ID + if req.SpecName != "" { + gpuType := specNameToGPUType(req.SpecName) + gmiReq.GpuTypeIds = &[]string{gpuType} + } + + // Convert TaskTimeout (seconds) to ExecutionTimeoutMs (milliseconds) + if req.TaskTimeout > 0 { + timeoutMs := int64(req.TaskTimeout) * 1000 + gmiReq.ExecutionTimeoutMs = &timeoutMs + } + + url := p.baseURL + "/api/v1/endpoints" + body, err := p.doRequest(ctx, "POST", url, gmiReq) + if err != nil { + return nil, fmt.Errorf("failed to deploy via GMI API: %w", err) + } + + // Parse response to get endpoint ID + var resp gmiEndpointResponse + if err := json.Unmarshal(body, &resp); err != nil { + logger.Warnf("GMI Deploy: failed to parse response: %v, body=%s", err, string(body)) + } else if resp.Id != "" { + p.endpointCache.Store(req.Endpoint, resp.Id) + } + + logger.Infof("GMI Deploy: endpoint=%s, id=%s, SUCCESS", req.Endpoint, resp.Id) + + return &interfaces.DeployResponse{ + Endpoint: req.Endpoint, + Message: "Successfully deployed via GMI API", + CreatedAt: time.Now().Format(time.RFC3339), + }, nil +} + +func (p *GMIDeploymentProvider) GetApp(ctx 
context.Context, endpoint string) (*interfaces.AppInfo, error) { + endpointID, err := p.getEndpointID(ctx, endpoint) + if err != nil { + return nil, err + } + + url := fmt.Sprintf("%s/api/v1/endpoints/%s", p.baseURL, endpointID) + body, err := p.doRequest(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + var resp gmiEndpointResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + p.endpointCache.Store(resp.Name, resp.Id) + return convertToAppInfo(&resp), nil +} + +func (p *GMIDeploymentProvider) ListApps(ctx context.Context) ([]*interfaces.AppInfo, error) { + url := p.baseURL + "/api/v1/endpoints" + body, err := p.doRequest(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + var respList []gmiEndpointResponse + if err := json.Unmarshal(body, &respList); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + apps := make([]*interfaces.AppInfo, len(respList)) + for i, resp := range respList { + if resp.Id != "" && resp.Name != "" { + p.endpointCache.Store(resp.Name, resp.Id) + } + apps[i] = convertToAppInfo(&resp) + } + + return apps, nil +} + +func (p *GMIDeploymentProvider) DeleteApp(ctx context.Context, endpoint string) error { + logger.Infof("GMI DeleteApp: endpoint=%s", endpoint) + + endpointID, err := p.getEndpointID(ctx, endpoint) + if err != nil { + return err + } + + url := fmt.Sprintf("%s/api/v1/endpoints/%s", p.baseURL, endpointID) + _, err = p.doRequest(ctx, "DELETE", url, nil) + if err != nil { + return fmt.Errorf("failed to delete via GMI API: %w", err) + } + + p.endpointCache.Delete(endpoint) + logger.Infof("GMI DeleteApp: endpoint=%s, SUCCESS", endpoint) + return nil +} + +func (p *GMIDeploymentProvider) ScaleApp(ctx context.Context, endpoint string, replicas int) error { + logger.Infof("GMI ScaleApp: endpoint=%s, target replicas=%d", endpoint, replicas) + + endpointID, err := p.getEndpointID(ctx, endpoint) + if err != nil { + return err + } + + // Only send replicas fields, no template or env changes + gmiReq := &gmiEndpointRequest{ + Replicas: &replicas, + WorkersMin: &replicas, + WorkersMax: &replicas, + } + + url := fmt.Sprintf("%s/api/v1/endpoints/%s", p.baseURL, endpointID) + _, err = p.doRequest(ctx, "PATCH", url, gmiReq) + if err != nil { + return fmt.Errorf("failed to scale via GMI API: %w", err) + } + + logger.Infof("GMI ScaleApp: endpoint=%s, replicas=%d, SUCCESS", endpoint, replicas) + return nil +} + +func (p *GMIDeploymentProvider) GetAppStatus(ctx context.Context, endpoint string) (*interfaces.AppStatus, error) { + app, err := p.GetApp(ctx, endpoint) + if err != nil { + return nil, err + } + + return &interfaces.AppStatus{ + Endpoint: app.Name, + Status: app.Status, + ReadyReplicas: app.ReadyReplicas, + AvailableReplicas: app.AvailableReplicas, + TotalReplicas: app.Replicas, + }, nil +} + +func (p *GMIDeploymentProvider) UpdateDeployment(ctx context.Context, req *interfaces.UpdateDeploymentRequest) (*interfaces.DeployResponse, error) { + logger.Infof("GMI UpdateDeployment: endpoint=%s, image=%s, replicas=%v, env=%v, shmSize=%v, taskTimeout=%v", + req.Endpoint, req.Image, req.Replicas, req.Env != nil, req.ShmSize, req.TaskTimeout) + + endpointID, err := p.getEndpointID(ctx, req.Endpoint) + if err != nil { + return nil, err + } + + gmiReq := &gmiEndpointRequest{} + + if req.Replicas != nil { + gmiReq.Replicas = req.Replicas + gmiReq.WorkersMin = req.Replicas + gmiReq.WorkersMax = req.Replicas + } + + // Only build and send template 
when template-related fields are being updated + needTemplate := req.Image != "" || req.Env != nil || (req.ShmSize != nil && *req.ShmSize != "") + if needTemplate { + template := &gmiTemplateData{} + if req.Image != "" { + template.ImageName = &req.Image + } + if req.ShmSize != nil && *req.ShmSize != "" { + template.ShmSize = req.ShmSize + } + + // Build env: start with defaults, then merge existing user env vars, then apply new env + defaultEnv := p.buildDefaultEnv(req.Endpoint) + existingEnv := p.getExistingEnv(ctx, endpointID) + + // Preserve existing non-default env vars (user-set vars from previous deploys) + if existingEnv != nil { + for k, v := range existingEnv { + if _, isDefault := defaultEnv[k]; !isDefault { + defaultEnv[k] = v + } + } + } + + // If user explicitly provided new env, override with it + if req.Env != nil { + for k, v := range *req.Env { + defaultEnv[k] = v + } + } + + template.Env = defaultEnv + gmiReq.Template = template + } + + if req.TaskTimeout != nil && *req.TaskTimeout > 0 { + timeoutMs := int64(*req.TaskTimeout) * 1000 + gmiReq.ExecutionTimeoutMs = &timeoutMs + } + + // Log the full request for debugging + if reqJSON, err := json.Marshal(gmiReq); err == nil { + logger.Infof("GMI UpdateDeployment: endpoint=%s, endpointID=%s, request=%s", req.Endpoint, endpointID, string(reqJSON)) + } + + url := fmt.Sprintf("%s/api/v1/endpoints/%s", p.baseURL, endpointID) + body, err := p.doRequest(ctx, "PATCH", url, gmiReq) + if err != nil { + return nil, fmt.Errorf("failed to update via GMI API: %w", err) + } + + logger.Infof("GMI UpdateDeployment: endpoint=%s, SUCCESS, response=%s", req.Endpoint, string(body)) + + return &interfaces.DeployResponse{ + Endpoint: req.Endpoint, + Message: "Successfully updated via GMI API", + CreatedAt: time.Now().Format(time.RFC3339), + }, nil +} + +// ======================================== +// SpecProvider (optional) +// ======================================== + +func (p *GMIDeploymentProvider) ListSpecs(ctx context.Context) ([]*interfaces.SpecInfo, error) { + return nil, fmt.Errorf("GMI provider: ListSpecs - use database spec service instead") +} + +func (p *GMIDeploymentProvider) GetSpec(ctx context.Context, specName string) (*interfaces.SpecInfo, error) { + return nil, fmt.Errorf("GMI provider: GetSpec - use database spec service instead") +} + +// ======================================== +// LogProvider (optional) +// ======================================== + +func (p *GMIDeploymentProvider) GetAppLogs(ctx context.Context, endpoint string, lines int, podName ...string) (string, error) { + endpointID, err := p.getEndpointID(ctx, endpoint) + if err != nil { + return "", err + } + + url := fmt.Sprintf("%s/api/v1/endpoints/%s/logs?lines=%d", p.baseURL, endpointID, lines) + if len(podName) > 0 && podName[0] != "" { + url += "&pod_name=" + podName[0] + } + + body, err := p.doRequest(ctx, "GET", url, nil) + if err != nil { + return "", err + } + + return string(body), nil +} + +// ======================================== +// PodProvider (optional) +// ======================================== + +func (p *GMIDeploymentProvider) GetPods(ctx context.Context, endpoint string) ([]*interfaces.PodInfo, error) { + if endpoint == "" { + return p.getAllPods(ctx) + } + + endpointID, err := p.getEndpointID(ctx, endpoint) + if err != nil { + return nil, err + } + + url := fmt.Sprintf("%s/api/v1/endpoints/%s/workers", p.baseURL, endpointID) + body, err := p.doRequest(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + // gmiless may return 
workers as direct array or as part of endpoint response + var podList []gmiPodInfo + if err := json.Unmarshal(body, &podList); err != nil { + // Try parsing as endpoint response with workers + var endpointResp gmiEndpointResponse + if err2 := json.Unmarshal(body, &endpointResp); err2 == nil { + pods := make([]*interfaces.PodInfo, len(endpointResp.Workers)) + for i, w := range endpointResp.Workers { + pods[i] = convertWorkerToPodInfo(&w, endpoint) + } + return pods, nil + } + return nil, fmt.Errorf("failed to parse workers response: %w", err) + } + + pods := make([]*interfaces.PodInfo, len(podList)) + for i := range podList { + pods[i] = convertPodInfoFromGMI(&podList[i]) + } + + return pods, nil +} + +func (p *GMIDeploymentProvider) getAllPods(ctx context.Context) ([]*interfaces.PodInfo, error) { + apps, err := p.ListApps(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list apps: %w", err) + } + + var allPods []*interfaces.PodInfo + for _, app := range apps { + pods, err := p.GetPods(ctx, app.Name) + if err != nil { + logger.Warnf("GMI: failed to get pods for %s: %v", app.Name, err) + continue + } + for _, pod := range pods { + if pod.Labels == nil { + pod.Labels = make(map[string]string) + } + if pod.Labels["app"] == "" { + pod.Labels["app"] = app.Name + } + } + allPods = append(allPods, pods...) + } + + return allPods, nil +} + +func (p *GMIDeploymentProvider) DescribePod(ctx context.Context, endpoint string, podName string) (*interfaces.PodDetail, error) { + endpointID, err := p.getEndpointID(ctx, endpoint) + if err != nil { + return nil, err + } + + url := fmt.Sprintf("%s/api/v1/endpoints/%s/workers/%s/describe", p.baseURL, endpointID, podName) + body, err := p.doRequest(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + var pod gmiPodInfo + if err := json.Unmarshal(body, &pod); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return convertPodDetailFromGMI(&pod), nil +} + +func (p *GMIDeploymentProvider) GetPodYAML(ctx context.Context, endpoint string, podName string) (string, error) { + return "", fmt.Errorf("GMI provider: GetPodYAML not supported") +} + +func (p *GMIDeploymentProvider) IsPodTerminating(ctx context.Context, podName string) (bool, error) { + if podName == "" { + return false, nil + } + + apps, err := p.ListApps(ctx) + if err != nil { + return false, nil + } + + for _, app := range apps { + if !strings.HasPrefix(podName, app.Name+"-") { + continue + } + pods, err := p.GetPods(ctx, app.Name) + if err != nil { + continue + } + for _, pod := range pods { + if pod.Name == podName { + return pod.DeletionTimestamp != "" || pod.Status == "Terminating", nil + } + } + } + + return false, nil +} + +// ======================================== +// StorageProvider (optional) +// ======================================== + +func (p *GMIDeploymentProvider) ListPVCs(ctx context.Context) ([]*interfaces.PVCInfo, error) { + return nil, nil +} + +// ======================================== +// ConfigProvider (optional) +// ======================================== + +func (p *GMIDeploymentProvider) GetDefaultEnv(ctx context.Context) (map[string]string, error) { + return map[string]string{ + "PLATFORM": "waverless", + }, nil +} + +// buildDefaultEnv builds the default RUNPOD/WAVERLESS environment variables for a given endpoint. +// These are automatically injected during Deploy and Update, not exposed to the frontend. 
+func (p *GMIDeploymentProvider) buildDefaultEnv(endpoint string) map[string]string { + callbackURL := strings.TrimRight(p.gmiConfig.CallbackURL, "/") + if callbackURL == "" { + callbackURL = p.baseURL + } + + serverAPIKey := p.cfg.Server.APIKey + heartbeatMs := "10000" + + return map[string]string{ + "PLATFORM": "waverless", + "RUNPOD_AI_API_KEY": serverAPIKey, + "RUNPOD_ENDPOINT_ID": endpoint, + "RUNPOD_PING_INTERVAL": heartbeatMs, + "RUNPOD_WEBHOOK_GET_JOB": fmt.Sprintf("%s/v2/%s/job-take/$ID?", callbackURL, endpoint), + "RUNPOD_WEBHOOK_PING": fmt.Sprintf("%s/v2/%s/ping/$RUNPOD_POD_ID", callbackURL, endpoint), + "RUNPOD_WEBHOOK_POST_OUTPUT": fmt.Sprintf("%s/v2/%s/job-done/$RUNPOD_POD_ID/$ID?", callbackURL, endpoint), + "RUNPOD_WEBHOOK_JOB_STREAM": fmt.Sprintf("%s/v2/%s/job-stream/$RUNPOD_POD_ID/$ID?", callbackURL, endpoint), + "WAVERLESS_ENDPOINT_ID": endpoint, + "WAVERLESS_PING_INTERVAL": heartbeatMs, + "WAVERLESS_WEBHOOK_GET_JOB": fmt.Sprintf("%s/v2/%s/job-take/$ID?", callbackURL, endpoint), + "WAVERLESS_WEBHOOK_PING": fmt.Sprintf("%s/v2/%s/ping/$WAVERLESS_POD_ID", callbackURL, endpoint), + "WAVERLESS_WEBHOOK_POST_OUTPUT": fmt.Sprintf("%s/v2/%s/job-done/$WAVERLESS_POD_ID/$ID?", callbackURL, endpoint), + "WAVERLESS_WEBHOOK_POST_STREAM": fmt.Sprintf("%s/v2/%s/job-stream/$WAVERLESS_POD_ID/$ID?", callbackURL, endpoint), + } +} + +// getExistingEnv fetches the current env vars from the remote GMI endpoint. +// Returns nil if the endpoint cannot be fetched (non-fatal). +func (p *GMIDeploymentProvider) getExistingEnv(ctx context.Context, endpointID string) map[string]string { + url := fmt.Sprintf("%s/api/v1/endpoints/%s", p.baseURL, endpointID) + body, err := p.doRequest(ctx, "GET", url, nil) + if err != nil { + return nil + } + var resp gmiEndpointResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil + } + return resp.Env +} + +// mergeEnv merges default env with user env. User-provided values take precedence. 
+func (p *GMIDeploymentProvider) mergeEnv(defaults map[string]string, userEnv map[string]string) map[string]string { + merged := make(map[string]string, len(defaults)+len(userEnv)) + for k, v := range defaults { + merged[k] = v + } + for k, v := range userEnv { + merged[k] = v + } + return merged +} + +// ======================================== +// PreviewProvider (optional) +// ======================================== + +func (p *GMIDeploymentProvider) PreviewDeploymentYAML(ctx context.Context, req *interfaces.DeployRequest) (string, error) { + gpuCount := req.GpuCount + if gpuCount <= 0 { + gpuCount = 1 + } + + yaml := fmt.Sprintf(`# GMI Deployment Preview +# Endpoint: %s +# Image: %s +# Replicas: %d +# GPU Count: %d +# +# This endpoint will be deployed via GMI (gmiless) API at: +# %s/api/v1/endpoints +`, + req.Endpoint, req.Image, req.Replicas, gpuCount, p.baseURL, + ) + + return yaml, nil +} + +// ======================================== +// WatchProvider (optional) +// ======================================== + +func (p *GMIDeploymentProvider) WatchReplicas(ctx context.Context, callback interfaces.ReplicaCallback) error { + go func() { + ticker := time.NewTicker(p.pollInterval) + defer ticker.Stop() + + previousState := make(map[string]interfaces.ReplicaEvent) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + apps, err := p.ListApps(ctx) + if err != nil { + continue + } + + for _, app := range apps { + event := interfaces.ReplicaEvent{ + Name: app.Name, + DesiredReplicas: int(app.Replicas), + ReadyReplicas: int(app.ReadyReplicas), + AvailableReplicas: int(app.AvailableReplicas), + } + + prev, exists := previousState[app.Name] + if !exists || + prev.DesiredReplicas != event.DesiredReplicas || + prev.ReadyReplicas != event.ReadyReplicas || + prev.AvailableReplicas != event.AvailableReplicas { + + if exists { + logger.Infof("GMI WatchReplicas: endpoint=%s, replica change: desired %d->%d, ready %d->%d", + app.Name, prev.DesiredReplicas, event.DesiredReplicas, prev.ReadyReplicas, event.ReadyReplicas) + } + callback(event) + previousState[app.Name] = event + } + } + } + } + }() + + return nil +} + +// ======================================== +// Worker Status Sync (polling-based) +// ======================================== + +// WatchPodStatusChange registers a callback for worker status changes. +func (p *GMIDeploymentProvider) WatchPodStatusChange(ctx context.Context, callback WorkerStatusChangeCallback) error { + if callback == nil { + return fmt.Errorf("worker status change callback is nil") + } + + p.workerStatusCallbacksLock.Lock() + callbackID := atomic.AddUint64(&p.nextCallbackID, 1) + p.workerStatusCallbacks[callbackID] = callback + p.workerStatusCallbacksLock.Unlock() + + logger.Infof("GMI: registered worker status change callback (ID: %d)", callbackID) + + p.watcherRunning.Do(func() { + p.watcherCtx, p.watcherCancel = context.WithCancel(ctx) + logger.Infof("GMI: starting worker status watcher (poll interval: %v)", p.pollInterval) + go p.runWorkerStatusWatcher(p.watcherCtx) + }) + + go func() { + <-ctx.Done() + p.workerStatusCallbacksLock.Lock() + delete(p.workerStatusCallbacks, callbackID) + p.workerStatusCallbacksLock.Unlock() + }() + + return nil +} + +// WatchPodDelete registers a callback for worker deletions. 
+func (p *GMIDeploymentProvider) WatchPodDelete(ctx context.Context, callback WorkerDeleteCallback) error { + if callback == nil { + return fmt.Errorf("worker delete callback is nil") + } + + p.workerDeleteCallbacksLock.Lock() + callbackID := atomic.AddUint64(&p.nextCallbackID, 1) + p.workerDeleteCallbacks[callbackID] = callback + p.workerDeleteCallbacksLock.Unlock() + + logger.Infof("GMI: registered worker delete callback (ID: %d)", callbackID) + + p.watcherRunning.Do(func() { + p.watcherCtx, p.watcherCancel = context.WithCancel(ctx) + logger.Infof("GMI: starting worker status watcher (poll interval: %v)", p.pollInterval) + go p.runWorkerStatusWatcher(p.watcherCtx) + }) + + go func() { + <-ctx.Done() + p.workerDeleteCallbacksLock.Lock() + delete(p.workerDeleteCallbacks, callbackID) + p.workerDeleteCallbacksLock.Unlock() + }() + + return nil +} + +// StopWatcher stops the worker status watcher +func (p *GMIDeploymentProvider) StopWatcher() { + if p.watcherCancel != nil { + p.watcherCancel() + } +} + +// runWorkerStatusWatcher runs the polling loop +func (p *GMIDeploymentProvider) runWorkerStatusWatcher(ctx context.Context) { + ticker := time.NewTicker(p.pollInterval) + defer ticker.Stop() + + logger.Infof("GMI: worker status watcher started") + + for { + select { + case <-ctx.Done(): + logger.Infof("GMI: worker status watcher stopped") + return + case <-ticker.C: + p.pollWorkerStates(ctx) + } + } +} + +// pollWorkerStates polls gmiless for all endpoints and workers +func (p *GMIDeploymentProvider) pollWorkerStates(ctx context.Context) { + url := p.baseURL + "/api/v1/endpoints" + body, err := p.doRequest(ctx, "GET", url, nil) + if err != nil { + logger.Warnf("GMI: failed to poll endpoints: %v", err) + return + } + + var endpoints []gmiEndpointResponse + if err := json.Unmarshal(body, &endpoints); err != nil { + logger.Warnf("GMI: failed to parse endpoints response: %v", err) + return + } + + currentWorkerIDs := make(map[string]bool) + totalWorkers := 0 + + for _, ep := range endpoints { + if ep.Id != "" && ep.Name != "" { + p.endpointCache.Store(ep.Name, ep.Id) + } + + // The list endpoints API does not include workers inline, + // so we need to fetch workers separately for each endpoint. 
+ workers := ep.Workers + if len(workers) == 0 && ep.Id != "" { + workersURL := p.baseURL + "/api/v1/endpoints/" + ep.Id + "/workers" + wBody, wErr := p.doRequest(ctx, "GET", workersURL, nil) + if wErr != nil { + logger.Warnf("GMI: failed to poll workers for endpoint %s: %v", ep.Name, wErr) + } else { + var fetchedWorkers []gmiWorkerResponse + if jErr := json.Unmarshal(wBody, &fetchedWorkers); jErr != nil { + logger.Warnf("GMI: failed to parse workers response for endpoint %s: %v", ep.Name, jErr) + } else { + workers = fetchedWorkers + } + } + } + + totalWorkers += len(workers) + for i := range workers { + worker := &workers[i] + workerID := worker.Id + if workerID == "" { + workerID = worker.Name + } + currentWorkerIDs[workerID] = true + p.processWorkerStateChange(ep.Name, workerID, worker) + } + } + + if len(endpoints) > 0 || totalWorkers > 0 { + logger.Infof("GMI poll: found %d endpoints, %d workers", len(endpoints), totalWorkers) + } else { + logger.Debugf("GMI poll: no endpoints or workers found") + } + + p.detectDeletedWorkers(currentWorkerIDs) +} + +// processWorkerStateChange detects worker state changes and triggers callbacks +func (p *GMIDeploymentProvider) processWorkerStateChange(endpoint, workerID string, worker *gmiWorkerResponse) { + prevInterface, exists := p.workerStates.Load(workerID) + + podInfo := convertWorkerToPodInfo(worker, endpoint) + + currentState := &gmiWorkerState{ + ID: workerID, + Endpoint: endpoint, + Status: worker.DesiredStatus, + CreatedAt: worker.LastStartedAt, + StartedAt: worker.LastStartedAt, + } + + if !exists { + p.workerStates.Store(workerID, currentState) + logger.Infof("GMI: new worker detected: %s (endpoint: %s, status: %s)", workerID, endpoint, worker.DesiredStatus) + p.notifyWorkerStatusChange(workerID, endpoint, podInfo) + return + } + + prev := prevInterface.(*gmiWorkerState) + if prev.Status != currentState.Status { + p.workerStates.Store(workerID, currentState) + logger.Infof("GMI: worker state changed: %s (endpoint: %s, %s -> %s)", workerID, endpoint, prev.Status, currentState.Status) + p.notifyWorkerStatusChange(workerID, endpoint, podInfo) + } +} + +// detectDeletedWorkers finds workers that disappeared and triggers delete callbacks +func (p *GMIDeploymentProvider) detectDeletedWorkers(currentWorkerIDs map[string]bool) { + p.workerStates.Range(func(key, value interface{}) bool { + workerID := key.(string) + if !currentWorkerIDs[workerID] { + state := value.(*gmiWorkerState) + logger.Infof("GMI: worker deleted: %s (endpoint: %s)", workerID, state.Endpoint) + p.workerStates.Delete(workerID) + p.notifyWorkerDelete(workerID, state.Endpoint) + } + return true + }) +} + +// notifyWorkerStatusChange triggers all registered status change callbacks +func (p *GMIDeploymentProvider) notifyWorkerStatusChange(workerID, endpoint string, info *interfaces.PodInfo) { + p.workerStatusCallbacksLock.RLock() + defer p.workerStatusCallbacksLock.RUnlock() + + for _, cb := range p.workerStatusCallbacks { + cb := cb + go func() { + defer func() { + if r := recover(); r != nil { + logger.Errorf("GMI: panic in worker status callback: %v", r) + } + }() + cb(workerID, endpoint, info) + }() + } +} + +// notifyWorkerDelete triggers all registered delete callbacks +func (p *GMIDeploymentProvider) notifyWorkerDelete(workerID, endpoint string) { + p.workerDeleteCallbacksLock.RLock() + defer p.workerDeleteCallbacksLock.RUnlock() + + for _, cb := range p.workerDeleteCallbacks { + cb := cb + go func() { + defer func() { + if r := recover(); r != nil { + logger.Errorf("GMI: 
panic in worker delete callback: %v", r) + } + }() + cb(workerID, endpoint) + }() + } +} + +// GetLifecycle returns the GMI lifecycle manager +func (p *GMIDeploymentProvider) GetLifecycle() *GMIProviderLifecycle { + return NewGMIProviderLifecycle(p) +} diff --git a/pkg/provider/gmi/types.go b/pkg/provider/gmi/types.go new file mode 100644 index 0000000..7a57003 --- /dev/null +++ b/pkg/provider/gmi/types.go @@ -0,0 +1,154 @@ +package gmi + +import "waverless/pkg/interfaces" + +// ======================================== +// Request types - matches gmiless interfaces.EndpointRequest +// ======================================== + +// gmiEndpointRequest matches gmiless interfaces.EndpointRequest +type gmiEndpointRequest struct { + // Core + Name *string `json:"name,omitempty"` + Replicas *int `json:"replicas,omitempty"` + + // Hardware + ComputeType *string `json:"computeType,omitempty"` // GPU, CPU + GpuCount *int `json:"gpuCount,omitempty"` + GpuTypeIds *[]string `json:"gpuTypeIds,omitempty"` + VcpuCount *int `json:"vcpuCount,omitempty"` + + // Template + Template *gmiTemplateData `json:"template,omitempty"` + TemplateId *string `json:"templateId,omitempty"` + + // Networking / Storage + DataCenterIds *[]string `json:"dataCenterIds,omitempty"` + NetworkVolumeId *string `json:"networkVolumeId,omitempty"` + + // Endpoint type + Type *string `json:"type,omitempty"` // LB, QB + UseContainerResource *bool `json:"useContainerResource,omitempty"` + + // Autoscaling + ExecutionTimeoutMs *int64 `json:"executionTimeoutMs,omitempty"` + IdleTimeout *int `json:"idleTimeout,omitempty"` + WorkersMin *int `json:"workersMin,omitempty"` + WorkersMax *int `json:"workersMax,omitempty"` + ScalerType *string `json:"scalerType,omitempty"` + ScalerValue *int `json:"scalerValue,omitempty"` + ScaleDownIdleTime *int `json:"scaleDownIdleTime,omitempty"` + ScaleUpCooldown *int `json:"scaleUpCooldown,omitempty"` + ScaleDownCooldown *int `json:"scaleDownCooldown,omitempty"` +} + +// gmiTemplateData matches gmiless interfaces.TemplateData +type gmiTemplateData struct { + ImageName *string `json:"imageName,omitempty"` + Env map[string]string `json:"env,omitempty"` + DockerEntrypoint []string `json:"dockerEntrypoint,omitempty"` + DockerStartCmd []string `json:"dockerStartCmd,omitempty"` + Ports []string `json:"ports,omitempty"` + ShmSize *string `json:"shmSize,omitempty"` + VolumeMountPath *string `json:"volumeMountPath,omitempty"` +} + +// ======================================== +// Response types - from gmiless API +// ======================================== + +// gmiEndpointResponse is the response from gmiless endpoint APIs +type gmiEndpointResponse struct { + Id string `json:"id"` + Name string `json:"name"` + Image string `json:"image"` + Replicas int `json:"replicas"` + Status string `json:"status"` + GpuCount int `json:"gpuCount"` + GpuTypeIds []string `json:"gpuTypeIds"` + CreatedAt string `json:"createdAt"` + Env map[string]string `json:"env"` + Workers []gmiWorkerResponse `json:"workers"` + WorkersMin int `json:"workersMin"` + WorkersMax int `json:"workersMax"` + Template *gmiTemplateResp `json:"template,omitempty"` + AccessURL string `json:"accessUrl,omitempty"` +} + +type gmiTemplateResp struct { + ImageName string `json:"imageName,omitempty"` +} + +type gmiWorkerResponse struct { + Id string `json:"id"` + Name string `json:"name"` + Image string `json:"image"` + DesiredStatus string `json:"desiredStatus"` + LastStartedAt string `json:"lastStartedAt,omitempty"` +} + +// gmiPodInfo represents pod data returned 
by gmiless /workers and /describe +type gmiPodInfo struct { + Name string `json:"name"` + Namespace string `json:"namespace"` + UID string `json:"uid"` + Phase string `json:"phase"` + Status string `json:"status"` + Reason string `json:"reason"` + Message string `json:"message"` + IP string `json:"ip"` + NodeName string `json:"nodeName"` + CreatedAt string `json:"createdAt"` + StartedAt string `json:"startedAt"` + Labels map[string]string `json:"labels"` + Annotations map[string]string `json:"annotations"` + RestartCount int `json:"restartCount"` + Ready bool `json:"ready"` + Containers []struct { + Name string `json:"name"` + Image string `json:"image"` + Env []struct { + Name string `json:"name"` + Value string `json:"value,omitempty"` + ValueFrom *struct { + FieldRef *struct { + FieldPath string `json:"fieldPath"` + } `json:"fieldRef,omitempty"` + } `json:"valueFrom,omitempty"` + } `json:"env"` + Ready bool `json:"ready"` + Resources struct { + Limits map[string]string `json:"limits"` + Requests map[string]string `json:"requests"` + } `json:"resources"` + } `json:"containers"` + Conditions []struct { + Type string `json:"type"` + Status string `json:"status"` + LastTransitionTime string `json:"lastTransitionTime"` + } `json:"conditions"` + Volumes []struct { + Name string `json:"name"` + Type string `json:"type"` + } `json:"volumes"` + DeletionTimestamp string `json:"deletionTimestamp,omitempty"` +} + +// ======================================== +// Internal types +// ======================================== + +// WorkerStatusChangeCallback is called when a worker's status changes +type WorkerStatusChangeCallback func(workerID, endpoint string, info *interfaces.PodInfo) + +// WorkerDeleteCallback is called when a worker is deleted +type WorkerDeleteCallback func(workerID, endpoint string) + +// gmiWorkerState tracks a worker's last known state +type gmiWorkerState struct { + ID string + Endpoint string + Status string + CreatedAt string + StartedAt string +} diff --git a/pkg/provider/gmi/worker_status_monitor.go b/pkg/provider/gmi/worker_status_monitor.go new file mode 100644 index 0000000..ef6555e --- /dev/null +++ b/pkg/provider/gmi/worker_status_monitor.go @@ -0,0 +1,159 @@ +// Package gmi provides GMI deployment provider implementation. +// This file implements the GMI Worker Status Monitor for tracking worker failures. +package gmi + +import ( + "encoding/json" + "strings" + "sync" + "time" + + "waverless/pkg/interfaces" + "waverless/pkg/logger" + "waverless/pkg/status" + "waverless/pkg/store/mysql" +) + +// GMIWorkerStatusMonitor monitors GMI worker status changes and detects failures. +// It uses gmiless API polling to detect status changes (same approach as Novita). +type GMIWorkerStatusMonitor struct { + workerRepo *mysql.WorkerRepository + sanitizer *status.StatusSanitizer + + // workerStates tracks the last known state of each worker + // key: workerID, value: *monitorWorkerState + workerStates sync.Map +} + +// monitorWorkerState stores the last known state of a worker for the status monitor +type monitorWorkerState struct { + Status string // Worker status: "ONLINE", "OFFLINE", "STARTING", etc. + Phase string // Pod phase + Reason string // Reason if failed + Message string // Status message + UpdatedAt time.Time // Last update time +} + +// NewGMIWorkerStatusMonitor creates a new GMI worker status monitor. 
+func NewGMIWorkerStatusMonitor(workerRepo *mysql.WorkerRepository) *GMIWorkerStatusMonitor { + return &GMIWorkerStatusMonitor{ + workerRepo: workerRepo, + sanitizer: status.NewStatusSanitizer(), + } +} + +// DetectFailure detects if a worker is in a failed state from PodInfo. +// Returns nil if the worker is not in a failed state. +func (m *GMIWorkerStatusMonitor) DetectFailure(info *interfaces.PodInfo) *interfaces.WorkerFailureInfo { + if info == nil { + return nil + } + + phaseLower := strings.ToLower(info.Phase) + statusLower := strings.ToLower(info.Status) + reasonLower := strings.ToLower(info.Reason) + messageLower := strings.ToLower(info.Message) + + // Check for failure indicators + isFailed := phaseLower == "failed" || phaseLower == "error" || + statusLower == "failed" || statusLower == "error" || + strings.Contains(phaseLower, "fail") || + strings.Contains(statusLower, "fail") || + (reasonLower != "" && reasonLower != "ready") || + strings.Contains(messageLower, "error") || + strings.Contains(messageLower, "fail") + + if !isFailed { + return nil + } + + failureType := m.classifyFailure(info.Phase, info.Status, info.Reason, info.Message) + return m.createFailureInfo(failureType, info.Phase, info.Reason, info.Message) +} + +// classifyFailure converts GMI worker status to generic FailureType. +func (m *GMIWorkerStatusMonitor) classifyFailure(phase, status, reason, message string) interfaces.FailureType { + allLower := strings.ToLower(phase + " " + status + " " + reason + " " + message) + + // Image pull failures + if containsAny(allLower, "image", "pull", "registry", "manifest", "repository", "not found") { + return interfaces.FailureTypeImagePull + } + + // Container crash failures + if containsAny(allLower, "crash", "exit", "oom", "killed", "container error") { + return interfaces.FailureTypeContainerCrash + } + + // Resource limit failures + if containsAny(allLower, "resource", "memory", "cpu", "gpu", "quota", "limit", "insufficient", "unavailable") { + return interfaces.FailureTypeResourceLimit + } + + // Timeout failures + if containsAny(allLower, "timeout", "deadline", "timed out") { + return interfaces.FailureTypeTimeout + } + + return interfaces.FailureTypeUnknown +} + +// createFailureInfo creates a WorkerFailureInfo from state information. +func (m *GMIWorkerStatusMonitor) createFailureInfo(failureType interfaces.FailureType, state, reason, message string) *interfaces.WorkerFailureInfo { + r := state + if reason != "" { + r = reason + } + + sanitizedMsg := "" + if m.sanitizer != nil { + sanitized := m.sanitizer.Sanitize(failureType, r, message) + if sanitized != nil { + sanitizedMsg = sanitized.UserMessage + if sanitized.Suggestion != "" { + sanitizedMsg += ". " + sanitized.Suggestion + } + } + } + + return &interfaces.WorkerFailureInfo{ + Type: failureType, + Reason: r, + Message: message, + SanitizedMsg: sanitizedMsg, + OccurredAt: time.Now(), + } +} + +// UpdateWorkerFailure updates the worker record with failure information in the database. 
+func (m *GMIWorkerStatusMonitor) UpdateWorkerFailure(workerID, endpoint string, info *interfaces.WorkerFailureInfo) error { + if m.workerRepo == nil || info == nil { + return nil + } + + details := map[string]any{ + "type": string(info.Type), + "reason": info.Reason, + "message": info.Message, + "sanitizedMsg": info.SanitizedMsg, + "occurredAt": info.OccurredAt.Format(time.RFC3339), + "provider": "gmi", + } + detailsJSON, err := json.Marshal(details) + if err != nil { + logger.Warnf("GMI: failed to marshal failure details: %v", err) + detailsJSON = []byte("{}") + } + + return m.workerRepo.UpdateWorkerFailure(nil, workerID, string(info.Type), info.SanitizedMsg, string(detailsJSON), info.OccurredAt) +} + +// containsAny checks if the string contains any of the given substrings. +func containsAny(s string, substrs ...string) bool { + for _, substr := range substrs { + if strings.Contains(s, substr) { + return true + } + } + return false +} diff --git a/pkg/store/mysql/monitoring_repository.go b/pkg/store/mysql/monitoring_repository.go index a357ed3..f55ee69 100644 --- a/pkg/store/mysql/monitoring_repository.go +++ b/pkg/store/mysql/monitoring_repository.go @@ -252,8 +252,8 @@ func (r *MonitoringRepository) AggregateMinuteStats(ctx context.Context, endpoin } r.ds.DB(ctx).Raw(` SELECT - COALESCE(SUM(LEAST(idle_duration_ms, TIMESTAMPDIFF(MICROSECOND, ?, event_time) / 1000)), 0) as sum_idle_ms, - COALESCE(MAX(LEAST(idle_duration_ms, TIMESTAMPDIFF(MICROSECOND, ?, event_time) / 1000)), 0) as max_idle_ms, + CAST(COALESCE(SUM(LEAST(idle_duration_ms, TIMESTAMPDIFF(MICROSECOND, ?, event_time) / 1000)), 0) AS SIGNED) as sum_idle_ms, + CAST(COALESCE(MAX(LEAST(idle_duration_ms, TIMESTAMPDIFF(MICROSECOND, ?, event_time) / 1000)), 0) AS SIGNED) as max_idle_ms, COUNT(*) as count FROM worker_events WHERE endpoint = ? AND event_time >= ? AND event_time < ? @@ -334,13 +334,13 @@ func (r *MonitoringRepository) AggregateMinuteStats(ctx context.Context, endpoin // 6. Worker lifecycle from worker_events var lifecycleStats struct { - Created int `gorm:"column:created"` - Terminated int `gorm:"column:terminated"` + Created int `gorm:"column:created_count"` + Terminated int `gorm:"column:terminated_count"` } r.ds.DB(ctx).Raw(` SELECT - COUNT(CASE WHEN event_type = 'WORKER_REGISTERED' THEN 1 END) as created, - COUNT(CASE WHEN event_type = 'WORKER_OFFLINE' THEN 1 END) as `+"`terminated`"+` + COUNT(CASE WHEN event_type = 'WORKER_REGISTERED' THEN 1 END) as created_count, + COUNT(CASE WHEN event_type = 'WORKER_OFFLINE' THEN 1 END) as terminated_count FROM worker_events WHERE endpoint = ? AND event_time >= ? AND event_time < ? `, endpoint, from, to).Scan(&lifecycleStats)
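
For reference, the sketch below shows one way the new provider can be constructed and exercised outside the normal wiring. It is an illustration only: in the real flow the provider is created through the factory registration added in pkg/provider/factory.go ("gmi"), and the config values come from the gmi: section of config.yaml rather than the placeholder literals used here.

package main

import (
	"context"
	"fmt"
	"log"

	"waverless/pkg/config"
	"waverless/pkg/provider/gmi"
)

func main() {
	// Placeholder config values for illustration; real deployments load
	// these from the gmi: section of config.yaml.
	cfg := &config.Config{
		GMI: config.GMIConfig{
			Enabled:      true,
			APIKey:       "your-gmi-api-key",
			BaseURL:      "https://gmi.example.com",
			CallbackURL:  "http://waverless-svc",
			PollInterval: 10,
		},
	}

	dp, err := gmi.NewGMIDeploymentProvider(cfg)
	if err != nil {
		log.Fatalf("create GMI provider: %v", err)
	}

	// The constructor returns the generic DeploymentProvider interface;
	// initializers.go type-asserts to the concrete type to reach the
	// GMI-specific lifecycle methods, and the same assertion is used here.
	p, ok := dp.(*gmi.GMIDeploymentProvider)
	if !ok {
		log.Fatal("unexpected provider type")
	}

	apps, err := p.ListApps(context.Background())
	if err != nil {
		log.Fatalf("list GMI endpoints: %v", err)
	}
	for _, app := range apps {
		fmt.Printf("%s: %s (%d/%d ready)\n", app.Name, app.Status, app.ReadyReplicas, app.Replicas)
	}
}

Deploy and UpdateDeployment also inject the RUNPOD_*/WAVERLESS_* webhook variables from buildDefaultEnv; with callback_url http://waverless-svc and endpoint demo, for example, RUNPOD_WEBHOOK_PING resolves to http://waverless-svc/v2/demo/ping/$RUNPOD_POD_ID.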
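
The lifecycle manager's RegisterGMIProvider boils down to fetching the provider's lifecycle and registering callbacks, which starts the worker-status, worker-delete, and endpoint-status watchers. A minimal sketch of that wiring follows, with logging callbacks standing in for the manager's real handlers; the package and function names here are hypothetical.

package gmiexample

import (
	"log"

	"waverless/pkg/interfaces"
	"waverless/pkg/provider/gmi"
)

// wireGMIWatchers mirrors what lifecycle.Manager.RegisterGMIProvider does:
// build a GMILifecycleCallbacks value and hand it to the provider's
// lifecycle, which registers the polling-based watchers.
func wireGMIWatchers(p *gmi.GMIDeploymentProvider) error {
	callbacks := &gmi.GMILifecycleCallbacks{
		OnWorkerStatusChange: func(workerID, endpoint string, podInfo *interfaces.PodInfo) {
			log.Printf("worker %s (%s) status: %s", workerID, endpoint, podInfo.Status)
		},
		OnWorkerDelete: func(workerID, endpoint string) {
			log.Printf("worker %s (%s) deleted", workerID, endpoint)
		},
		OnWorkerFailure: func(workerID, endpoint string, failureInfo *interfaces.WorkerFailureInfo) {
			log.Printf("worker %s (%s) failed: %s", workerID, endpoint, failureInfo.SanitizedMsg)
		},
		OnEndpointStatusChange: func(endpoint, status string, desired, ready, available int) {
			log.Printf("endpoint %s -> %s (desired=%d ready=%d available=%d)", endpoint, status, desired, ready, available)
		},
	}
	return p.GetLifecycle().RegisterWatchers(callbacks)
}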
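
GMIWorkerStatusMonitor.DetectFailure classifies a failed worker from its polled PodInfo (image-pull, container-crash, resource-limit, timeout, or unknown) and runs the result through the status sanitizer. Below is a small sketch using a hypothetical image-pull failure; passing nil for the repository matches how the lifecycle watcher constructs the monitor.

package main

import (
	"fmt"

	"waverless/pkg/interfaces"
	"waverless/pkg/provider/gmi"
)

func main() {
	// The lifecycle watcher also passes nil for the worker repository and
	// only calls DetectFailure, so no database is needed for this check.
	monitor := gmi.NewGMIWorkerStatusMonitor(nil)

	// Hypothetical worker state as the poller might report it.
	info := &interfaces.PodInfo{
		Name:    "demo-worker-0",
		Phase:   "Failed",
		Status:  "Error",
		Reason:  "ImagePullBackOff",
		Message: "failed to pull image: manifest not found",
	}

	if failure := monitor.DetectFailure(info); failure != nil {
		// With the input above this classifies as an image-pull failure.
		fmt.Printf("type=%s reason=%s sanitized=%q\n", failure.Type, failure.Reason, failure.SanitizedMsg)
	}
}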