diff --git a/test/integration/controller/jobs/pytorchjob/pytorchjob_controller_test.go b/test/integration/controller/jobs/pytorchjob/pytorchjob_controller_test.go
index fa02d31fac..966e2e1f7f 100644
--- a/test/integration/controller/jobs/pytorchjob/pytorchjob_controller_test.go
+++ b/test/integration/controller/jobs/pytorchjob/pytorchjob_controller_test.go
@@ -609,3 +609,95 @@ var _ = ginkgo.Describe("Job controller interacting with scheduler", ginkgo.Orde
 		})
 	})
 })
+
+var _ = ginkgo.Describe("Gang scheduled job doesn't get partial admission", ginkgo.Ordered, ginkgo.ContinueOnFailure, func() {
+	var (
+		ns            *corev1.Namespace
+		clusterQueue  *kueue.ClusterQueue
+		localQueue    *kueue.LocalQueue
+		defaultFlavor *kueue.ResourceFlavor
+	)
+
+	ginkgo.BeforeAll(func() {
+		fwk = &framework.Framework{
+			CRDPath:     crdPath,
+			DepCRDPaths: []string{pytorchCrdPath},
+		}
+		cfg := fwk.Init()
+		ctx, k8sClient = fwk.RunManager(cfg, managerAndSchedulerSetup())
+	})
+	ginkgo.AfterAll(func() {
+		fwk.Teardown()
+	})
+
+	ginkgo.BeforeEach(func() {
+		ns = &corev1.Namespace{
+			ObjectMeta: metav1.ObjectMeta{
+				GenerateName: "core-",
+			},
+		}
+		gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed())
+
+		// Give the cluster queue enough CPU and memory quota, but only half of the GPUs the job will request.
+		defaultFlavor = testing.MakeResourceFlavor("default").NodeLabel(instanceKey, "default").Obj()
+		gomega.Expect(k8sClient.Create(ctx, defaultFlavor)).Should(gomega.Succeed())
+		clusterQueue = testing.MakeClusterQueue("dev-clusterqueue").
+			ResourceGroup(
+				*testing.MakeFlavorQuotas("default").
+					Resource(corev1.ResourceCPU, "172").
+					Resource(corev1.ResourceMemory, "2074").
+					Resource("example.com/gpu", "8").
+					Obj(),
+			).Obj()
+		gomega.Expect(k8sClient.Create(ctx, clusterQueue)).Should(gomega.Succeed())
+	})
+
+	ginkgo.AfterEach(func() {
+		gomega.Expect(util.DeleteNamespace(ctx, k8sClient, ns)).To(gomega.Succeed())
+		util.ExpectObjectToBeDeleted(ctx, k8sClient, clusterQueue, true)
+		util.ExpectObjectToBeDeleted(ctx, k8sClient, defaultFlavor, true)
+	})
+
+	ginkgo.It("Should not admit the job if it does not fit in the ClusterQueue", func() {
+		ginkgo.By("creating localQueue")
+		localQueue = testing.MakeLocalQueue("local-queue", ns.Name).ClusterQueue(clusterQueue.Name).Obj()
+		gomega.Expect(k8sClient.Create(ctx, localQueue)).Should(gomega.Succeed())
+
+		// A job that requests twice the available GPU quota (8 per replica, 16 in total).
+		kfJob := testingpytorchjob.MakePyTorchJob(jobName, ns.Name).
+			PyTorchReplicaSpecsDefault().
+			Queue(localQueue.Name).
+			Request(kftraining.PyTorchJobReplicaTypeMaster, corev1.ResourceCPU, "86").
+			Request(kftraining.PyTorchJobReplicaTypeWorker, corev1.ResourceCPU, "86").
+			Request(kftraining.PyTorchJobReplicaTypeMaster, corev1.ResourceMemory, "1037").
+			Request(kftraining.PyTorchJobReplicaTypeWorker, corev1.ResourceMemory, "1037").
+			Request(kftraining.PyTorchJobReplicaTypeMaster, "example.com/gpu", "8").
+			Request(kftraining.PyTorchJobReplicaTypeWorker, "example.com/gpu", "8").
+			Obj()
+
+		ginkgo.By("creating the job", func() {
+			gomega.Expect(k8sClient.Create(ctx, kfJob)).Should(gomega.Succeed())
+		})
+
+		jobLookupKey := types.NamespacedName{Name: jobName, Namespace: ns.Name}
+		ginkgo.By("fetch the job and verify it is suspended", func() {
+			createdJob := &kftraining.PyTorchJob{}
+			gomega.Eventually(func() *bool {
+				gomega.Expect(k8sClient.Get(ctx, jobLookupKey, createdJob)).Should(gomega.Succeed())
+				return createdJob.Spec.RunPolicy.Suspend
+			}, util.Timeout, util.Interval).Should(gomega.Equal(ptr.To(true)))
+		})
+
+		wlLookupKey := types.NamespacedName{Name: workloadpytorchjob.GetWorkloadNameForPyTorchJob(kfJob.Name, kfJob.UID), Namespace: ns.Name}
+		createdWorkload := util.AwaitAndVerifyCreatedWorkload(ctx, k8sClient, wlLookupKey, kfJob)
+
+		util.ExpectPendingWorkloadsMetric(clusterQueue, 0, 1)
+		util.ExpectReservingActiveWorkloadsMetric(clusterQueue, 0)
+
+		gomega.Consistently(func() bool {
+			lookupKey := types.NamespacedName{Name: createdWorkload.Name, Namespace: createdWorkload.Namespace}
+			gomega.Expect(k8sClient.Get(ctx, lookupKey, createdWorkload)).Should(gomega.Succeed())
+			return !workload.HasQuotaReservation(createdWorkload)
+		}, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue())
+	})
+})