Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix server startup sequence to wait for all components #1371

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 127 additions & 9 deletions cmd/daytona/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
package main

import (
"context"
"fmt"
"net"
"net/http"
"os"
"time"

golog "log"

"github.com/daytonaio/daytona/internal"
"github.com/daytonaio/daytona/internal/util"
"github.com/daytonaio/daytona/pkg/cmd"
Expand All @@ -18,6 +20,15 @@ import (
log "github.com/sirupsen/logrus"
)

// Settings for the startup health checks. They are declared as variables
// (not constants) because the tests in this package reassign them.
var (
	// Endpoints probed by the individual component checks.
	apiServerAddr = "http://localhost:3986" // Daytona's default API port
	headscaleAddr = "http://localhost:3986" // headscale is reached on the same port as the API server
	registryAddr  = "localhost:5000"        // default local registry port

	// Per-probe timeout and retry policy shared by all components.
	defaultTimeout = 5 * time.Second
	retryDelay     = time.Second
	maxRetries     = 3
)

func main() {
if internal.WorkspaceMode() {
err := workspacemode.Execute()
Expand All @@ -31,23 +42,27 @@ func main() {
if err != nil {
log.Fatal(err)
}

// Wait for all components to be healthy
timeout := 2 * time.Minute
if err := checkComponentHealth(timeout); err != nil {
log.Fatalf("Server startup failed: %v", err)
}

log.Info("Daytona server is fully operational")
}

func init() {
logLevel := log.WarnLevel

logLevelEnv, logLevelSet := os.LookupEnv("LOG_LEVEL")

if logLevelSet {
var err error
logLevel, err = log.ParseLevel(logLevelEnv)
if err != nil {
logLevel = log.WarnLevel
if parsedLevel, err := log.ParseLevel(logLevelEnv); err == nil {
logLevel = parsedLevel
}
}

log.SetLevel(logLevel)

zerologLevel, err := zerolog.ParseLevel(logLevel.String())
if err != nil {
zerologLevel = zerolog.ErrorLevel
Expand All @@ -59,6 +74,109 @@ func init() {
Out: &util.DebugLogWriter{},
TimeFormat: time.RFC3339,
})
}

// checkComponentHealth probes, in order, the API server, the providers, the
// local registry and the headscale server.
//
// Every component gets up to maxRetries attempts with retryDelay between
// them, and the whole sequence is bounded by timeout. It returns nil once all
// components have passed, or an error naming the first component that failed
// or timed out.
func checkComponentHealth(timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// Route the standard library logger through the debug log writer.
	golog.SetOutput(&util.DebugLogWriter{})

	components := []struct {
		name  string
		check func(context.Context) error
	}{
		{"API Server", checkAPIServer},
		{"Providers", checkProviders},
		{"Local Registry", checkLocalRegistry},
		{"Headscale Server", checkHeadscaleServer},
	}

	for _, component := range components {
		if err := waitForComponent(ctx, component.name, component.check); err != nil {
			return err
		}
		log.Infof("%s is healthy", component.name)
	}
	return nil
}

// waitForComponent runs check until it succeeds, maxRetries attempts have been
// used up, or ctx is done, whichever happens first.
func waitForComponent(ctx context.Context, name string, check func(context.Context) error) error {
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("%s health check timed out: %w", name, err)
		}

		err := check(ctx)
		if err == nil {
			return nil
		}
		log.Warnf("%s health check failed (attempt %d/%d): %v",
			name, attempt, maxRetries, err)
		if attempt == maxRetries {
			return fmt.Errorf("%s health check failed after %d attempts: %w",
				name, maxRetries, err)
		}

		// Back off before the next attempt, but stop waiting as soon as the
		// overall deadline expires instead of sleeping through it.
		timer := time.NewTimer(retryDelay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return fmt.Errorf("%s health check timed out: %w", name, ctx.Err())
		case <-timer.C:
		}
	}

	// Only reachable when maxRetries < 1: no attempt is made and the component
	// is treated as passed, as before.
	return nil
}

// checkAPIServer issues a GET to the API server's /api/health endpoint and
// reports an error unless the response status is 200 OK. The request is bound
// to ctx and additionally limited by defaultTimeout.
func checkAPIServer(ctx context.Context) error {
	healthURL := apiServerAddr + "/api/health"

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}

	httpClient := http.Client{Timeout: defaultTimeout}
	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to connect to API server: %w", err)
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusOK {
		return fmt.Errorf("API server returned non-OK status: %d", status)
	}
	return nil
}

// checkProviders reports whether the Docker provider is available.
//
// NOTE: this is a placeholder. No provider is actually queried: the function
// pauses briefly, logs that the hard-coded provider/version is available and
// returns nil. The only failure mode is ctx already being done on entry.
func checkProviders(ctx context.Context) error {
	const (
		provider = "docker-provider"
		version  = "v0.12.1"
	)

	if err := ctx.Err(); err != nil {
		return fmt.Errorf("provider check timed out for %s: %w", provider, err)
	}

	// Stand-in for a real probe of the Docker provider.
	time.Sleep(100 * time.Millisecond)
	log.Printf("Docker provider (%s %s) is available", provider, version)
	return nil
}

// checkLocalRegistry treats the local registry as healthy when a TCP
// connection to registryAddr can be established. The dial is bound to ctx and
// limited by defaultTimeout; the connection is closed again right away.
func checkLocalRegistry(ctx context.Context) error {
	dialer := &net.Dialer{Timeout: defaultTimeout}

	conn, dialErr := dialer.DialContext(ctx, "tcp", registryAddr)
	if dialErr != nil {
		return fmt.Errorf("failed to connect to local registry: %w", dialErr)
	}

	// Being able to connect is the whole check; a close error is not reported.
	_ = conn.Close()
	return nil
}

// checkHeadscaleServer issues a GET to the headscale server's /health endpoint
// and reports an error unless the response status is 200 OK. The request is
// bound to ctx and additionally limited by defaultTimeout.
func checkHeadscaleServer(ctx context.Context) error {
	healthURL := headscaleAddr + "/health"

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}

	httpClient := http.Client{Timeout: defaultTimeout}
	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to connect to headscale server: %w", err)
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusOK {
		return fmt.Errorf("headscale server returned non-OK status: %d", status)
	}
	return nil
}
214 changes: 214 additions & 0 deletions cmd/daytona/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
package main

import (
"context"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
)

// TestCheckComponentHealth runs checkComponentHealth against healthy mocks of
// the API server, the headscale server and the local registry and expects it
// to succeed.
func TestCheckComponentHealth(t *testing.T) {
	origAPIAddr := apiServerAddr
	origHeadscaleAddr := headscaleAddr
	origRegistryAddr := registryAddr
	origMaxRetries := maxRetries
	origRetryDelay := retryDelay

	defer func() {
		apiServerAddr = origAPIAddr
		headscaleAddr = origHeadscaleAddr
		registryAddr = origRegistryAddr
		maxRetries = origMaxRetries
		retryDelay = origRetryDelay
	}()

	maxRetries = 3
	retryDelay = 100 * time.Millisecond

	// newHealthServer starts a mock that answers 200 only on healthPath and
	// 404 everywhere else, so a check hitting the wrong endpoint fails the test.
	newHealthServer := func(healthPath string) *httptest.Server {
		return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path != healthPath {
				w.WriteHeader(http.StatusNotFound)
				return
			}
			w.WriteHeader(http.StatusOK)
		}))
	}

	// Mock API server
	apiServer := newHealthServer("/api/health")
	defer apiServer.Close()
	apiServerAddr = apiServer.URL

	// Mock Headscale server
	headscaleServer := newHealthServer("/health")
	defer headscaleServer.Close()
	headscaleAddr = headscaleServer.URL

	// Mock registry: the registry check only needs something to connect to.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to create test listener: %v", err)
	}
	defer listener.Close()
	registryAddr = listener.Addr().String()

	err = checkComponentHealth(2 * time.Second)
	if err != nil {
		t.Errorf("checkComponentHealth failed: %v", err)
	}
}

// TestFailedAPIServer points the API server check at a mock that always
// answers 500 and expects checkComponentHealth to return an error.
func TestFailedAPIServer(t *testing.T) {
	savedAddr, savedRetries, savedDelay := apiServerAddr, maxRetries, retryDelay
	t.Cleanup(func() {
		apiServerAddr, maxRetries, retryDelay = savedAddr, savedRetries, savedDelay
	})

	maxRetries = 2
	retryDelay = 100 * time.Millisecond

	alwaysFailing := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})
	failingAPI := httptest.NewServer(alwaysFailing)
	t.Cleanup(failingAPI.Close)
	apiServerAddr = failingAPI.URL

	if err := checkComponentHealth(1 * time.Second); err == nil {
		t.Error("Expected error for failed API server, got nil")
	}
}

func TestFailedHeadscaleServer(t *testing.T) {
origHeadscaleAddr := headscaleAddr
origMaxRetries := maxRetries
origRetryDelay := retryDelay

defer func() {
headscaleAddr = origHeadscaleAddr
maxRetries = origMaxRetries
retryDelay = origRetryDelay
}()

maxRetries = 2
retryDelay = 100 * time.Millisecond

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
headscaleAddr = server.URL

err := checkComponentHealth(1 * time.Second)
if err == nil {
t.Error("Expected error for failed headscale server, got nil")
}
}

// TestContextTimeout aims the API server check at an unusable address with a
// very short per-request timeout and an overall budget shorter than the retry
// delay, and expects checkComponentHealth to return an error.
func TestContextTimeout(t *testing.T) {
	savedAddr, savedTimeout := apiServerAddr, defaultTimeout
	savedRetries, savedDelay := maxRetries, retryDelay
	t.Cleanup(func() {
		apiServerAddr, defaultTimeout = savedAddr, savedTimeout
		maxRetries, retryDelay = savedRetries, savedDelay
	})

	maxRetries = 2
	retryDelay = 100 * time.Millisecond
	defaultTimeout = 1 * time.Millisecond
	apiServerAddr = "http://localhost:0"

	const overallBudget = 50 * time.Millisecond
	if err := checkComponentHealth(overallBudget); err == nil {
		t.Error("Expected timeout error, got nil")
	}
}

// TestProviderCheck expects checkProviders to succeed when given a context
// that is still live.
func TestProviderCheck(t *testing.T) {
	const budget = 500 * time.Millisecond

	ctx, cancel := context.WithTimeout(context.Background(), budget)
	defer cancel()

	if err := checkProviders(ctx); err != nil {
		t.Errorf("Provider check failed: %v", err)
	}
}

// TestLocalRegistryCheck expects checkLocalRegistry to succeed when a TCP
// listener is accepting connections on registryAddr.
func TestLocalRegistryCheck(t *testing.T) {
	savedAddr := registryAddr
	t.Cleanup(func() { registryAddr = savedAddr })

	// Any listening socket will do; the check only needs to connect.
	fakeRegistry, listenErr := net.Listen("tcp", "127.0.0.1:0")
	if listenErr != nil {
		t.Fatalf("Failed to create test listener: %v", listenErr)
	}
	t.Cleanup(func() { fakeRegistry.Close() })
	registryAddr = fakeRegistry.Addr().String()

	if err := checkLocalRegistry(context.Background()); err != nil {
		t.Errorf("Local registry check failed: %v", err)
	}
}

// TestRetryBehavior verifies that a component which fails twice and then
// recovers is retried until it passes: the API server mock rejects the first
// two requests, and the overall check must still succeed after exactly three
// requests to it.
func TestRetryBehavior(t *testing.T) {
	origAPIAddr := apiServerAddr
	origHeadscaleAddr := headscaleAddr
	origRegistryAddr := registryAddr
	origMaxRetries := maxRetries
	origRetryDelay := retryDelay

	defer func() {
		apiServerAddr = origAPIAddr
		headscaleAddr = origHeadscaleAddr
		registryAddr = origRegistryAddr
		maxRetries = origMaxRetries
		retryDelay = origRetryDelay
	}()

	maxRetries = 3
	retryDelay = 100 * time.Millisecond

	// Setup API server with retry behavior: fail twice, then succeed.
	attempts := 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		attempts++
		if attempts < 3 {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	apiServerAddr = server.URL

	// Setup a healthy headscale mock on its own server. Without it the final
	// component check would go to the default address and fail the whole run;
	// keeping it separate also keeps its request out of the attempts counter.
	headscaleServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer headscaleServer.Close()
	headscaleAddr = headscaleServer.URL

	// Setup mock registry
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to create test listener: %v", err)
	}
	defer listener.Close()
	registryAddr = listener.Addr().String()

	err = checkComponentHealth(2 * time.Second)
	if err != nil {
		t.Errorf("Expected success after retries, got error: %v", err)
	}

	if attempts != 3 {
		t.Errorf("Expected 3 attempts, got %d", attempts)
	}
}