Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions cmd/initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"waverless/pkg/interfaces"
"waverless/pkg/logger"
"waverless/pkg/provider"
"waverless/pkg/provider/gmi"
"waverless/pkg/provider/k8s"
"waverless/pkg/provider/novita"
"waverless/pkg/resource"
Expand Down Expand Up @@ -204,9 +205,17 @@ func (app *Application) initServices() error {
}
}

// Get GMI deployment provider for status sync
var gmiDeployProvider *gmi.GMIDeploymentProvider
if app.config.GMI.Enabled {
if gmiProv, ok := app.deploymentProvider.(*gmi.GMIDeploymentProvider); ok {
gmiDeployProvider = gmiProv
}
}

// Setup Lifecycle Manager for unified watcher management
// This replaces all the individual setup*Watcher methods
if err := app.setupLifecycleManager(k8sDeployProvider, novitaDeployProvider); err != nil {
if err := app.setupLifecycleManager(k8sDeployProvider, novitaDeployProvider, gmiDeployProvider); err != nil {
logger.WarnCtx(app.ctx, "Failed to setup lifecycle manager: %v (non-critical, continuing)", err)
}

Expand Down Expand Up @@ -237,7 +246,7 @@ func (app *Application) initServices() error {

// setupLifecycleManager sets up the unified lifecycle manager for all providers
// This replaces the individual setup*Watcher methods with a centralized approach
func (app *Application) setupLifecycleManager(k8sProvider *k8s.K8sDeploymentProvider, novitaProvider *novita.NovitaDeploymentProvider) error {
func (app *Application) setupLifecycleManager(k8sProvider *k8s.K8sDeploymentProvider, novitaProvider *novita.NovitaDeploymentProvider, gmiProvider *gmi.GMIDeploymentProvider) error {
// Create lifecycle manager
app.lifecycleManager = lifecycle.NewManager(
app.ctx,
Expand Down Expand Up @@ -276,6 +285,14 @@ func (app *Application) setupLifecycleManager(k8sProvider *k8s.K8sDeploymentProv
}
}

// Register GMI provider if enabled
if gmiProvider != nil {
logger.InfoCtx(app.ctx, "Registering GMI provider with lifecycle manager...")
if err := app.lifecycleManager.RegisterGMIProvider(gmiProvider); err != nil {
logger.WarnCtx(app.ctx, "Failed to register GMI provider: %v", err)
}
}

logger.InfoCtx(app.ctx, "Lifecycle manager setup completed with providers: %v", app.lifecycleManager.GetRegisteredProviders())
return nil
}
Expand All @@ -288,8 +305,8 @@ func (app *Application) initHandlers() error {
app.statisticsHandler = handler.NewStatisticsHandler(app.statisticsService, app.workerService)
app.monitoringHandler = handler.NewMonitoringHandler(app.monitoringService)

// Initialize Endpoint Handler (for K8s or Novita)
if app.config.K8s.Enabled || app.config.Novita.Enabled {
// Initialize Endpoint Handler (for K8s, Novita or GMI)
if app.config.K8s.Enabled || app.config.Novita.Enabled || app.config.GMI.Enabled {
if app.deploymentProvider == nil {
logger.ErrorCtx(app.ctx, "Deployment provider is enabled but provider is nil")
} else {
Expand All @@ -300,6 +317,9 @@ func (app *Application) initHandlers() error {
if app.config.Novita.Enabled {
logger.InfoCtx(app.ctx, "Endpoint handler initialized for Novita")
}
if app.config.GMI.Enabled {
logger.InfoCtx(app.ctx, "Endpoint handler initialized for GMI")
}
}
}

Expand Down Expand Up @@ -347,8 +367,7 @@ func (app *Application) initAutoScaler() error {
logger.InfoCtx(app.ctx, "Spec service injected into SpecManager - specs will be read from database first")
}
} else {
logger.WarnCtx(app.ctx, "AutoScaler requires K8s deployment provider, skipping initialization")
return nil
logger.WarnCtx(app.ctx, "K8s deployment provider not available, autoscaler will run without resource limits")
}

autoscalerConfig := &autoscaler.Config{
Expand Down
9 changes: 9 additions & 0 deletions config/config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,12 @@ novita:
base_url: "https://api.novita.ai" # Novita API base URL
config_dir: "./config" # Configuration directory (contains specs.yaml and templates/)
poll_interval: 10 # Poll interval for status updates (seconds, default: 10)

# GMI Serverless Configuration
gmi:
enabled: false # Enable GMI serverless provider
api_key: "xxx" # Your GMI API key (Bearer token)
base_url: "http://34.105.52.179:31818" # GMI API base URL
callback_url: "http://waverless-svc" # External URL of waverless server for worker callbacks
config_dir: "./config" # Configuration directory (contains specs.yaml and templates/)
poll_interval: 10 # Poll interval for status updates (seconds, default: 10)
68 changes: 68 additions & 0 deletions internal/service/lifecycle/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"waverless/pkg/interfaces"
"waverless/pkg/logger"
"waverless/pkg/provider"
"waverless/pkg/provider/gmi"
"waverless/pkg/provider/k8s"
"waverless/pkg/provider/novita"
"waverless/pkg/status"
Expand Down Expand Up @@ -151,6 +152,38 @@ func (m *Manager) RegisterNovitaProvider(provider *novita.NovitaDeploymentProvid
return nil
}

// RegisterGMIProvider registers the GMI deployment provider with the manager
// and hands the manager's callbacks to the provider's lifecycle watchers.
//
// It returns an error when provider is nil, when the provider exposes no
// lifecycle, or when watcher registration fails. If a provider named "gmi" is
// already registered, a warning is logged and nil is returned without touching
// the existing registration.
func (m *Manager) RegisterGMIProvider(provider *gmi.GMIDeploymentProvider) error {
	const providerName = "gmi"

	if provider == nil {
		return fmt.Errorf("gmi provider is nil")
	}

	// The lock is held until return, so it covers both the "already
	// registered" check and the insert into m.providers.
	m.mu.Lock()
	defer m.mu.Unlock()

	if _, registered := m.providers[providerName]; registered {
		logger.WarnCtx(m.ctx, "Provider %s already registered, skipping", providerName)
		return nil
	}

	lc := provider.GetLifecycle()
	if lc == nil {
		return fmt.Errorf("gmi provider lifecycle is nil")
	}

	// The provider is only recorded once its watchers registered successfully.
	if err := lc.RegisterWatchers(m.createGMICallbacks()); err != nil {
		return fmt.Errorf("failed to register watchers for gmi provider: %w", err)
	}

	m.providers[providerName] = lc
	logger.InfoCtx(m.ctx, "GMI Provider registered successfully")
	return nil
}

// createK8sCallbacks creates K8s callback functions
func (m *Manager) createK8sCallbacks(k8sProvider *k8s.K8sDeploymentProvider) *k8s.K8sLifecycleCallbacks {
return &k8s.K8sLifecycleCallbacks{
Expand Down Expand Up @@ -262,6 +295,41 @@ func (m *Manager) createNovitaCallbacks() *novita.NovitaLifecycleCallbacks {
}
}

// createGMICallbacks builds the callback set passed to the GMI provider's
// lifecycle. Each callback repackages its arguments into the matching
// provider event struct and forwards it to m.callbackHandler.
func (m *Manager) createGMICallbacks() *gmi.GMILifecycleCallbacks {
	callbacks := &gmi.GMILifecycleCallbacks{}

	// Worker status updates carry the pod information through unchanged.
	callbacks.OnWorkerStatusChange = func(workerID, endpoint string, podInfo *interfaces.PodInfo) {
		event := &provider.WorkerStatusEvent{
			WorkerID: workerID,
			Endpoint: endpoint,
			PodInfo:  podInfo,
		}
		m.callbackHandler.HandleWorkerStatusChange(event)
	}

	callbacks.OnWorkerDelete = func(workerID, endpoint string) {
		event := &provider.WorkerDeleteEvent{
			WorkerID: workerID,
			Endpoint: endpoint,
		}
		m.callbackHandler.HandleWorkerDelete(event)
	}

	callbacks.OnWorkerFailure = func(workerID, endpoint string, failureInfo *interfaces.WorkerFailureInfo) {
		event := &provider.WorkerFailureEvent{
			WorkerID:    workerID,
			Endpoint:    endpoint,
			FailureInfo: failureInfo,
		}
		m.callbackHandler.HandleWorkerFailure(event)
	}

	callbacks.OnEndpointStatusChange = func(endpoint, state string, desired, ready, available int) {
		event := &provider.EndpointStatusEvent{
			Endpoint:          endpoint,
			Status:            state,
			DesiredReplicas:   desired,
			ReadyReplicas:     ready,
			AvailableReplicas: available,
		}
		m.callbackHandler.HandleEndpointStatusChange(event)
	}

	return callbacks
}

// UnregisterProvider unregisters a Provider and stops its watchers
func (m *Manager) UnregisterProvider(name string) error {
m.mu.Lock()
Expand Down
7 changes: 6 additions & 1 deletion pkg/autoscaler/resource_calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ func (c *ResourceCalculator) CalculateEndpointResource(ctx context.Context, endp
specName = meta.SpecName
}

// Get spec
// Get spec - if specManager is nil (e.g., GMI/Novita provider), skip resource calculation
if c.specManager == nil {
logger.DebugCtx(ctx, "specManager is nil, skipping resource calculation for endpoint %s (non-K8s provider)", endpoint.Name)
return &Resources{}, nil
}

spec, err := c.specManager.GetSpec(specName)
if err != nil {
return nil, fmt.Errorf("failed to get spec %s: %w", specName, err)
Expand Down
11 changes: 11 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Config struct {
Notification NotificationConfig `yaml:"notification"` // Notification configuration
Providers *ProvidersConfig `yaml:"providers,omitempty"` // Providers configuration (optional)
Novita NovitaConfig `yaml:"novita"` // Novita serverless configuration
GMI GMIConfig `yaml:"gmi"` // GMI serverless configuration
ImageValidation ImageValidationConfig `yaml:"imageValidation"` // Image validation configuration
ResourceReleaser ResourceReleaserConfig `yaml:"resourceReleaser"` // Resource releaser configuration
}
Expand Down Expand Up @@ -217,6 +218,16 @@ type NovitaConfig struct {
PollInterval int `yaml:"poll_interval"` // Poll interval for status updates (seconds, default: 10)
}

// GMIConfig holds the settings for the GMI serverless provider. It is read
// from the "gmi" section of the YAML configuration.
type GMIConfig struct {
Enabled bool `yaml:"enabled"` // Whether the GMI provider is enabled
APIKey string `yaml:"api_key"` // GMI API key, sent as a Bearer token
BaseURL string `yaml:"base_url"` // Base URL of the GMI API
CallbackURL string `yaml:"callback_url"` // External URL of the waverless server that workers call back to (e.g., http://<host>:8087)
ConfigDir string `yaml:"config_dir"` // Configuration directory (contains specs.yaml and templates)
PollInterval int `yaml:"poll_interval"` // Poll interval for status updates, in seconds (documented default: 10; the defaulting code is not visible here — TODO confirm)
}

// Init initializes configuration
func Init() error {
configPath := os.Getenv("CONFIG_PATH")
Expand Down
4 changes: 2 additions & 2 deletions pkg/interfaces/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ type UpdateDeploymentRequest struct {
VolumeMounts *[]VolumeMount `json:"volumeMounts,omitempty"` // New volume mounts (optional, use pointer to distinguish empty from unset)
ShmSize *string `json:"shmSize,omitempty"` // New shared memory size (optional, use pointer to distinguish empty from unset)
EnablePtrace *bool `json:"enablePtrace,omitempty"` // Enable SYS_PTRACE capability (optional, use pointer to distinguish false from unset)
Env *map[string]string `json:"env,omitempty"` // New environment variables (optional, use pointer to distinguish empty from unset)
TaskTimeout *int `json:"taskTimeout,omitempty"` // New task timeout (optional)
Env *map[string]string `json:"env,omitempty"` // New environment variables (optional, use pointer to distinguish empty from unset)
TaskTimeout *int `json:"taskTimeout,omitempty"` // New task timeout (optional)
}

// UpdateEndpointConfigRequest update Endpoint configuration request (metadata + autoscaling configuration)
Expand Down
2 changes: 2 additions & 0 deletions pkg/provider/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"waverless/pkg/config"
"waverless/pkg/interfaces"
"waverless/pkg/provider/docker"
"waverless/pkg/provider/gmi"
"waverless/pkg/provider/k8s"
"waverless/pkg/provider/novita"
)
Expand Down Expand Up @@ -80,6 +81,7 @@ func init() {
RegisterDeploymentProvider("kubernetes", k8s.NewK8sDeploymentProvider)
RegisterDeploymentProvider("docker", docker.NewDockerDeploymentProvider)
RegisterDeploymentProvider("novita", novita.NewNovitaDeploymentProvider)
RegisterDeploymentProvider("gmi", gmi.NewGMIDeploymentProvider)
}

func (f *ProviderFactory) CreateDeploymentProvider(providerType string) (interfaces.DeploymentProvider, error) {
Expand Down
74 changes: 74 additions & 0 deletions pkg/provider/gmi/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package gmi

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"

"waverless/pkg/logger"
)

// doRequest sends one HTTP request to the GMI API and returns the raw
// response body.
//
// A non-nil body is JSON-encoded, logged at debug level, and sent with a
// "Content-Type: application/json" header. Every request carries
// "Authorization: Bearer <p.token>" and "Accept: application/json".
//
// An error is returned when the body cannot be marshalled, the request cannot
// be built or sent, the response cannot be read, or the status code is outside
// the 200-299 range. In the last case the error includes the status code and
// the response body.
func (p *GMIDeploymentProvider) doRequest(ctx context.Context, method, url string, body interface{}) ([]byte, error) {
	hasBody := body != nil

	// payload stays a nil interface when there is nothing to send, so the
	// request is built without a body.
	var payload io.Reader
	if hasBody {
		encoded, marshalErr := json.Marshal(body)
		if marshalErr != nil {
			return nil, fmt.Errorf("failed to marshal request: %w", marshalErr)
		}
		logger.Debugf("GMI API request: %s %s body=%s", method, url, string(encoded))
		payload = bytes.NewBuffer(encoded)
	}

	req, err := http.NewRequestWithContext(ctx, method, url, payload)
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %w", err)
	}

	// gmiless uses "Authorization: Bearer <token>" for /api/v1 routes.
	headers := req.Header
	headers.Set("Authorization", "Bearer "+p.token)
	headers.Set("Accept", "application/json")
	if hasBody {
		headers.Set("Content-Type", "application/json")
	}

	resp, err := p.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to call GMI API: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so that it can be included in
	// the error for a non-2xx response.
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	if code := resp.StatusCode; code < http.StatusOK || code >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("GMI API returned error status %d: %s", code, string(data))
	}

	return data, nil
}

// getEndpointID resolves an endpoint name to its gmiless endpoint ID.
//
// The ID is served from p.endpointCache when present. On a cache miss,
// ListApps is called and the cache is consulted once more. ListApps is relied
// on to fill the cache; its body is not visible here. An error is returned if
// ListApps fails or the endpoint is still absent from the cache afterwards.
func (p *GMIDeploymentProvider) getEndpointID(ctx context.Context, endpoint string) (string, error) {
	cached, hit := p.endpointCache.Load(endpoint)
	if !hit {
		// Cache miss: refresh through the list API, then retry the lookup.
		if _, err := p.ListApps(ctx); err != nil {
			return "", fmt.Errorf("failed to list endpoints: %w", err)
		}
		cached, hit = p.endpointCache.Load(endpoint)
	}

	if !hit {
		return "", fmt.Errorf("endpoint %s not found in GMI", endpoint)
	}
	return cached.(string), nil
}
Loading