diff --git a/kubectl-plugin/pkg/cmd/job/job_submit.go b/kubectl-plugin/pkg/cmd/job/job_submit.go index 189cbee82c4..68fc0c97c25 100644 --- a/kubectl-plugin/pkg/cmd/job/job_submit.go +++ b/kubectl-plugin/pkg/cmd/job/job_submit.go @@ -16,7 +16,6 @@ import ( "github.com/google/shlex" "github.com/spf13/cobra" - "k8s.io/apimachinery/pkg/api/meta" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/cli-runtime/pkg/genericclioptions" "k8s.io/cli-runtime/pkg/genericiooptions" @@ -36,6 +35,8 @@ const ( dashboardAddr = "http://localhost:8265" clusterTimeout = 120.0 portForwardTimeout = 60.0 + pollInterval = 2 + httpTimeout = 5 ) type SubmitJobOptions struct { @@ -392,40 +393,19 @@ func (options *SubmitJobOptions) Run(ctx context.Context, factory cmdutil.Factor if err != nil { return fmt.Errorf("Failed to get Ray Job status") } - time.Sleep(2 * time.Second) + time.Sleep(pollInterval * time.Second) } options.cluster = options.RayJob.Status.RayClusterName } else { return fmt.Errorf("Unknown cluster and did not provide Ray Job. One of the fields must be set") } - // Wait til the cluster is ready - var clusterReady bool - clusterWaitStartTime := time.Now() - currTime := clusterWaitStartTime - fmt.Printf("Waiting for RayCluster\n") - fmt.Printf("Checking Cluster Status for cluster %s...\n", options.cluster) - for !clusterReady && currTime.Sub(clusterWaitStartTime).Seconds() <= clusterTimeout { - time.Sleep(2 * time.Second) - currCluster, err := k8sClients.RayClient().RayV1().RayClusters(options.namespace).Get(ctx, options.cluster, v1.GetOptions{}) - if err != nil { - return fmt.Errorf("Failed to get cluster information with error: %w", err) - } - clusterReady = isRayClusterReady(currCluster) - if !clusterReady { - fmt.Println("Cluster is not ready") + // Wait until the RayCluster with type=RayClusterProvisioned and status=true + fmt.Printf("Waiting for RayCluster %s to be ready...\n", options.cluster) + if err := k8sClients.WaitRayClusterProvisioned(ctx, options.namespace, options.cluster, time.Duration(clusterTimeout)*time.Second); err != nil { + if cleanupErr := options.cleanupRayJob(ctx, k8sClients); cleanupErr != nil { + return fmt.Errorf("Failed to clean up Ray job after timeout: %w (original error: %w)", cleanupErr, err) } - currTime = time.Now() - } - - if !clusterReady { - fmt.Printf("Deleting RayJob...\n") - err = k8sClients.RayClient().RayV1().RayJobs(options.namespace).Delete(ctx, options.RayJob.GetName(), v1.DeleteOptions{}) - if err != nil { - return fmt.Errorf("Failed to clean up Ray job after time out.: %w", err) - } - fmt.Printf("Cleaned Up RayJob: %s\n", options.RayJob.GetName()) - return fmt.Errorf("Timed out waiting for cluster") } @@ -450,35 +430,9 @@ func (options *SubmitJobOptions) Run(ctx context.Context, factory cmdutil.Factor }() // Wait for port forward to be ready - var portForwardReady bool - portForwardWaitStartTime := time.Now() - currTime = portForwardWaitStartTime - - portforwardCheckRequest, err := http.NewRequestWithContext(ctx, http.MethodGet, dashboardAddr, nil) - if err != nil { - return fmt.Errorf("Error occurred when trying to create request to probe cluster endpoint: %w", err) - } - httpClient := http.Client{ - Timeout: 5 * time.Second, - } fmt.Printf("Waiting for port forwarding...") - for !portForwardReady && currTime.Sub(portForwardWaitStartTime).Seconds() <= portForwardTimeout { - time.Sleep(2 * time.Second) - rayDashboardResponse, err := httpClient.Do(portforwardCheckRequest) - if err != nil { - err = fmt.Errorf("Error occurred when waiting for port forwarding: %w", err) - fmt.Println(err) - currTime = time.Now() - continue - } - if rayDashboardResponse.StatusCode >= 200 && rayDashboardResponse.StatusCode < 300 { - portForwardReady = true - } - rayDashboardResponse.Body.Close() - currTime = time.Now() - } - if !portForwardReady { - return fmt.Errorf("Timed out waiting for port forwarding") + if err := waitForPortForward(ctx); err != nil { + return fmt.Errorf("Failed to establish port forwarding: %w", err) } options.address = dashboardAddr fmt.Printf("Port forwarding started on %s\n", options.address) @@ -722,10 +676,6 @@ func runtimeEnvHasWorkingDir(runtimePath string) (string, error) { return "", nil } -func isRayClusterReady(rayCluster *rayv1.RayCluster) bool { - return meta.IsStatusConditionTrue(rayCluster.Status.Conditions, "Ready") || rayCluster.Status.State == rayv1.Ready -} - // Generates a 16-character random ID with a prefix, mimicking Ray Job submission_id. // ref: ray/python/ray/dashboard/modules/job/job_manager.py func generateSubmissionID() (string, error) { @@ -743,3 +693,38 @@ func generateSubmissionID() (string, error) { } return fmt.Sprintf("raysubmit_%s", string(idRunes)), nil } + +// waitForPortForward waits for port forwarding to be ready +func waitForPortForward(ctx context.Context) error { + httpClient := http.Client{Timeout: httpTimeout * time.Second} + portforwardCheckRequest, err := http.NewRequestWithContext(ctx, http.MethodGet, dashboardAddr, nil) + if err != nil { + return fmt.Errorf("Error occurred when trying to create request to probe cluster endpoint: %w", err) + } + + startTime := time.Now() + for time.Since(startTime).Seconds() <= portForwardTimeout { + time.Sleep(pollInterval * time.Second) + rayDashboardResponse, err := httpClient.Do(portforwardCheckRequest) + if err != nil { + fmt.Printf("Error occurred when waiting for port forwarding: %v\n", err) + continue + } + if rayDashboardResponse.StatusCode >= 200 && rayDashboardResponse.StatusCode < 300 { + rayDashboardResponse.Body.Close() + return nil + } + rayDashboardResponse.Body.Close() + } + return fmt.Errorf("Timed out waiting for port forwarding") +} + +func (options *SubmitJobOptions) cleanupRayJob(ctx context.Context, k8sClients client.Client) error { + fmt.Printf("Deleting RayJob...\n") + err := k8sClients.RayClient().RayV1().RayJobs(options.namespace).Delete(ctx, options.RayJob.GetName(), v1.DeleteOptions{}) + if err != nil { + return fmt.Errorf("Failed to clean up Ray job: %w", err) + } + fmt.Printf("Cleaned Up RayJob: %s\n", options.RayJob.GetName()) + return nil +}