Skip to content

Commit 48682c5

Browse files
author
JimmyYang20
committed
support for modelarts adapter
Signed-off-by: JimmyYang20 <[email protected]>
1 parent 78220b9 commit 48682c5

File tree

4 files changed

+177
-22
lines changed

4 files changed

+177
-22
lines changed

Diff for: pkg/globalmanager/controllers/lifelonglearning/lifelonglearningjob.go

+79-22
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package lifelonglearning
1818

1919
import (
2020
"context"
21+
"crypto/sha256"
2122
"encoding/json"
2223
"fmt"
2324
"k8s.io/apimachinery/pkg/types"
@@ -27,6 +28,7 @@ import (
2728
v1 "k8s.io/api/core/v1"
2829
"k8s.io/apimachinery/pkg/api/errors"
2930
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
31+
lruexpirecache "k8s.io/apimachinery/pkg/util/cache"
3032
utilrand "k8s.io/apimachinery/pkg/util/rand"
3133
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
3234
"k8s.io/apimachinery/pkg/util/wait"
@@ -52,6 +54,8 @@ const (
5254
KindName = "LifelongLearningJob"
5355
// Name is this controller name
5456
Name = "LifelongLearning"
57+
// VirtualKubeletNode is virtual node
58+
VirtualKubeletNode = "virtual-kubelet"
5559
)
5660

5761
// Kind contains the schema.GroupVersionKind for this controller type.
@@ -82,6 +86,8 @@ type Controller struct {
8286
cfg *config.ControllerConfig
8387

8488
sendToEdgeFunc runtime.DownstreamSendFunc
89+
90+
lruExpireCache *lruexpirecache.LRUExpireCache
8591
}
8692

8793
// Run starts the main goroutine responsible for watching and syncing jobs.
@@ -379,14 +385,17 @@ func (c *Controller) transitJobState(job *sednav1.LifelongLearningJob) (bool, er
379385
// include train, eval, deploy pod
380386
var err error
381387
if jobStage == sednav1.LLJobDeploy {
382-
err = c.restartInferPod(job)
383-
if err != nil {
384-
klog.V(2).Infof("lifelonglearning job %v/%v inference pod failed to restart, err:%s", job.Namespace, job.Name, err)
385-
return needUpdated, err
386-
}
388+
if !c.hasJobInCache(job) {
389+
err = c.restartInferPod(job)
390+
if err != nil {
391+
klog.V(2).Infof("lifelonglearning job %v/%v inference pod failed to restart, err:%s", job.Namespace, job.Name, err)
392+
return needUpdated, err
393+
}
387394

388-
klog.V(2).Infof("lifelonglearning job %v/%v inference pod restarts successfully", job.Namespace, job.Name)
389-
newConditionType = sednav1.LLJobStageCondCompleted
395+
klog.V(2).Infof("lifelonglearning job %v/%v inference pod restarts successfully", job.Namespace, job.Name)
396+
newConditionType = sednav1.LLJobStageCondCompleted
397+
c.addJobToCache(job)
398+
}
390399
} else {
391400
if podStatus != v1.PodPending && podStatus != v1.PodRunning {
392401
err = c.createPod(job, jobStage)
@@ -406,10 +415,6 @@ func (c *Controller) transitJobState(job *sednav1.LifelongLearningJob) (bool, er
406415

407416
// watch pod status, if pod running, set type running
408417
newConditionType = sednav1.LLJobStageCondRunning
409-
} else if podStatus == v1.PodSucceeded {
410-
// watch pod status, if pod completed, set type completed
411-
newConditionType = sednav1.LLJobStageCondCompleted
412-
klog.V(2).Infof("lifelonglearning job %v/%v %v stage completed!", job.Namespace, job.Name, jobStage)
413418
} else if podStatus == v1.PodFailed {
414419
newConditionType = sednav1.LLJobStageCondFailed
415420
klog.V(2).Infof("lifelonglearning job %v/%v %v stage failed!", job.Namespace, job.Name, jobStage)
@@ -491,6 +496,25 @@ func (c *Controller) getSpecifiedPods(job *sednav1.LifelongLearningJob, podType
491496
return latestPod
492497
}
493498

499+
func (c *Controller) getHas256(target interface{}) string {
500+
h := sha256.New()
501+
h.Write([]byte(fmt.Sprintf("%v", target)))
502+
return fmt.Sprintf("%x", h.Sum(nil))
503+
}
504+
505+
func (c *Controller) addJobToCache(job *sednav1.LifelongLearningJob) {
506+
c.lruExpireCache.Add(c.getHas256(job.Status), job, 10*time.Second)
507+
}
508+
509+
func (c *Controller) hasJobInCache(job *sednav1.LifelongLearningJob) bool {
510+
_, ok := c.lruExpireCache.Get(c.getHas256(job.Status))
511+
if !ok {
512+
return false
513+
}
514+
515+
return true
516+
}
517+
494518
func (c *Controller) restartInferPod(job *sednav1.LifelongLearningJob) error {
495519
inferPod := c.getSpecifiedPods(job, runtime.InferencePodType)
496520
if inferPod == nil {
@@ -542,6 +566,18 @@ func IsJobFinished(j *sednav1.LifelongLearningJob) bool {
542566
return false
543567
}
544568

569+
func (c *Controller) addPodAnnotations(spec *v1.PodTemplateSpec, key string, value string) {
570+
ann := spec.GetAnnotations()
571+
if ann == nil {
572+
ann = make(map[string]string)
573+
}
574+
575+
if _, ok := ann[key]; !ok {
576+
ann[key] = value
577+
spec.SetAnnotations(ann)
578+
}
579+
}
580+
545581
func (c *Controller) createPod(job *sednav1.LifelongLearningJob, podtype sednav1.LLJobStage) (err error) {
546582
ctx := context.Background()
547583
var podTemplate *v1.PodTemplateSpec
@@ -592,12 +628,20 @@ func (c *Controller) createPod(job *sednav1.LifelongLearningJob, podtype sednav1
592628
}
593629

594630
var workerParam *runtime.WorkerParam = new(runtime.WorkerParam)
631+
595632
if podtype == sednav1.LLJobTrain {
596-
workerParam.WorkerType = "Train"
633+
workerParam.WorkerType = runtime.TrainPodType
597634

598635
podTemplate = &job.Spec.TrainSpec.Template
599636
// Env parameters for train
600637

638+
c.addPodAnnotations(podTemplate, "type", workerParam.WorkerType)
639+
c.addPodAnnotations(podTemplate, "data", dataURL)
640+
datasetUseInitializer := true
641+
if podTemplate.Spec.NodeName == VirtualKubeletNode {
642+
datasetUseInitializer = false
643+
}
644+
601645
workerParam.Env = map[string]string{
602646
"NAMESPACE": job.Namespace,
603647
"JOB_NAME": job.Name,
@@ -621,7 +665,7 @@ func (c *Controller) createPod(job *sednav1.LifelongLearningJob, podtype sednav1
621665
URL: &runtime.MountURL{
622666
URL: dataURL,
623667
Secret: jobSecret,
624-
DownloadByInitializer: true,
668+
DownloadByInitializer: datasetUseInitializer,
625669
},
626670
EnvName: "TRAIN_DATASET_URL",
627671
},
@@ -632,14 +676,25 @@ func (c *Controller) createPod(job *sednav1.LifelongLearningJob, podtype sednav1
632676
Secret: datasetSecret,
633677
URL: originalDataURLOrIndex,
634678
Indirect: dataset.Spec.URL != originalDataURLOrIndex,
635-
DownloadByInitializer: true,
679+
DownloadByInitializer: datasetUseInitializer,
636680
},
637681
EnvName: "ORIGINAL_DATASET_URL",
638682
},
639683
)
640684
} else {
641685
podTemplate = &job.Spec.EvalSpec.Template
642-
workerParam.WorkerType = "Eval"
686+
workerParam.WorkerType = runtime.EvalPodType
687+
688+
c.addPodAnnotations(podTemplate, "type", workerParam.WorkerType)
689+
c.addPodAnnotations(podTemplate, "data", dataURL)
690+
datasetUseInitializer := true
691+
if podTemplate.Spec.NodeName == VirtualKubeletNode {
692+
datasetUseInitializer = false
693+
}
694+
modelUseInitializer := true
695+
if podTemplate.Spec.NodeName == VirtualKubeletNode {
696+
modelUseInitializer = false
697+
}
643698

644699
// Configure Env information for eval by initial WorkerParam
645700
workerParam.Env = map[string]string{
@@ -656,7 +711,7 @@ func (c *Controller) createPod(job *sednav1.LifelongLearningJob, podtype sednav1
656711
modelMountURLs = append(modelMountURLs, runtime.MountURL{
657712
URL: url,
658713
Secret: jobSecret,
659-
DownloadByInitializer: true,
714+
DownloadByInitializer: modelUseInitializer,
660715
})
661716
}
662717
workerParam.Mounts = append(workerParam.Mounts,
@@ -679,7 +734,7 @@ func (c *Controller) createPod(job *sednav1.LifelongLearningJob, podtype sednav1
679734
URL: &runtime.MountURL{
680735
URL: dataURL,
681736
Secret: datasetSecret,
682-
DownloadByInitializer: true,
737+
DownloadByInitializer: datasetUseInitializer,
683738
},
684739
Name: "datasets",
685740
EnvName: "TEST_DATASET_URL",
@@ -689,7 +744,7 @@ func (c *Controller) createPod(job *sednav1.LifelongLearningJob, podtype sednav1
689744
URL: &runtime.MountURL{
690745
Secret: datasetSecret,
691746
URL: originalDataURLOrIndex,
692-
DownloadByInitializer: true,
747+
DownloadByInitializer: datasetUseInitializer,
693748
Indirect: dataset.Spec.URL != originalDataURLOrIndex,
694749
},
695750
Name: "origin-dataset",
@@ -744,6 +799,7 @@ func (c *Controller) createInferPod(job *sednav1.LifelongLearningJob) error {
744799
}
745800

746801
workerParam.WorkerType = runtime.InferencePodType
802+
c.addPodAnnotations(&job.Spec.DeploySpec.Template, "type", workerParam.WorkerType)
747803
workerParam.HostNetwork = true
748804

749805
// create edge pod
@@ -764,10 +820,11 @@ func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {
764820
eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: cc.KubeClient.CoreV1().Events("")})
765821

766822
jc := &Controller{
767-
kubeClient: cc.KubeClient,
768-
client: cc.SednaClient.SednaV1alpha1(),
769-
queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), Name),
770-
cfg: cfg,
823+
kubeClient: cc.KubeClient,
824+
client: cc.SednaClient.SednaV1alpha1(),
825+
queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), Name),
826+
cfg: cfg,
827+
lruExpireCache: lruexpirecache.NewLRUExpireCache(10),
771828
}
772829

773830
jobInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{

Diff for: pkg/localcontroller/managers/lifelonglearning/lifelonglearningjob.go

+64
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ import (
2020
"bufio"
2121
"encoding/json"
2222
"fmt"
23+
"io/ioutil"
2324
"os"
2425
"path"
26+
"path/filepath"
2527
"strconv"
2628
"strings"
2729
"sync"
@@ -61,6 +63,9 @@ const (
6163
AnnotationsRoundsKey = "sedna.io/rounds"
6264
AnnotationsNumberOfSamplesKey = "sedna.io/number-of-samples"
6365
AnnotationsDataFileOfEvalKey = "sedna.io/data-file-of-eval"
66+
67+
// WorkerS3StatusHandlerIntervalSeconds is interval time of handling s3 status of worker
68+
WorkerS3StatusHandlerIntervalSeconds = 30
6469
)
6570

6671
// LifelongLearningJobManager defines lifelong-learning-job Manager
@@ -257,6 +262,8 @@ func (lm *Manager) trainTask(job *Job) error {
257262
// continue anyway
258263
}
259264

265+
go lm.monitorS3Worker(job, sednav1.LLJobTrain)
266+
260267
jobConfig.TrainTriggerStatus = TriggerCompletedStatus
261268
klog.Infof("job(name=%s) complete the %sing phase triggering task successfully",
262269
jobConfig.UniqueIdentifier, jobStage)
@@ -297,6 +304,8 @@ func (lm *Manager) evalTask(job *Job) error {
297304

298305
forwardSamples(jobConfig, jobStage)
299306

307+
go lm.monitorS3Worker(job, sednav1.LLJobEval)
308+
300309
jobConfig.EvalTriggerStatus = TriggerCompletedStatus
301310
klog.Infof("job(%s) completed the %sing phase triggering task successfully",
302311
jobConfig.UniqueIdentifier, jobStage)
@@ -968,7 +977,62 @@ func (lm *Manager) monitorWorker() {
968977
if err := lm.Client.WriteMessage(msg, job.getHeader()); err != nil {
969978
klog.Errorf("job(%s) failed to write message: %v", name, err)
970979
continue
980+
} else {
981+
klog.Infof("job(%s) write message(%v) to GM", name, msg)
982+
}
983+
}
984+
}
985+
986+
func (lm *Manager) monitorS3Worker(job *Job, stage sednav1.LLJobStage) {
987+
jobConfig := job.JobConfig
988+
var statusFile string
989+
switch stage {
990+
case sednav1.LLJobTrain:
991+
statusFile = strings.Join([]string{jobConfig.OutputConfig.TrainOutput, strconv.Itoa(jobConfig.Rounds), "status.json"}, "/")
992+
case sednav1.LLJobEval:
993+
statusFile = strings.Join([]string{jobConfig.OutputConfig.EvalOutput, strconv.Itoa(jobConfig.Rounds), "status.json"}, "/")
994+
}
995+
996+
tempLocalFile := filepath.Join(os.TempDir(), "status.json")
997+
for {
998+
time.Sleep(WorkerS3StatusHandlerIntervalSeconds * time.Second)
999+
localFile, err := jobConfig.Storage.Download(statusFile, tempLocalFile)
1000+
if err != nil {
1001+
continue
9711002
}
1003+
1004+
bytes, _ := ioutil.ReadFile(localFile)
1005+
workerMessage := workertypes.MessageContent{}
1006+
err = json.Unmarshal(bytes, &workerMessage)
1007+
if err != nil {
1008+
continue
1009+
}
1010+
1011+
wo := clienttypes.Output{}
1012+
wo.Models = workerMessage.Results
1013+
wo.OwnerInfo = workerMessage.OwnerInfo
1014+
1015+
msg := &clienttypes.UpstreamMessage{
1016+
Phase: workerMessage.Kind,
1017+
Status: workerMessage.Status,
1018+
Output: &wo,
1019+
}
1020+
1021+
name := util.GetUniqueIdentifier(workerMessage.Namespace, workerMessage.OwnerName, workerMessage.OwnerKind)
1022+
if err := lm.Client.WriteMessage(msg, job.getHeader()); err != nil {
1023+
klog.Errorf("job(%s) failed to write message: %v", name, err)
1024+
continue
1025+
}
1026+
1027+
if err = jobConfig.Storage.DeleteFile(statusFile); err != nil {
1028+
continue
1029+
}
1030+
1031+
if err = jobConfig.Storage.DeleteFile(tempLocalFile); err != nil {
1032+
continue
1033+
}
1034+
1035+
break
9721036
}
9731037
}
9741038

Diff for: pkg/localcontroller/storage/minio.go

+17
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,20 @@ func (mc *MinioClient) parseURL(URL string) (string, string, error) {
145145

146146
return "", "", fmt.Errorf("invalid url(%s)", URL)
147147
}
148+
149+
// deleteFile deletes file
150+
func (mc *MinioClient) deleteFile(objectURL string) error {
151+
bucket, absPath, err := mc.parseURL(objectURL)
152+
if err != nil {
153+
return err
154+
}
155+
156+
ctx, cancel := context.WithTimeout(context.Background(), MaxTimeOut)
157+
defer cancel()
158+
159+
if err = mc.Client.RemoveObject(ctx, bucket, absPath, minio.RemoveObjectOptions{}); err != nil {
160+
return fmt.Errorf("delete file(url=%s) failed, error: %+v", objectURL, err)
161+
}
162+
163+
return nil
164+
}

Diff for: pkg/localcontroller/storage/storage.go

+17
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"encoding/json"
2121
"fmt"
2222
"net/url"
23+
"os"
2324
"path"
2425
"path/filepath"
2526

@@ -232,3 +233,19 @@ func (s *Storage) CopyFile(srcURL string, objectURL string) error {
232233

233234
return nil
234235
}
236+
237+
// DeleteFile deletes file
238+
func (s *Storage) DeleteFile(objectURL string) error {
239+
prefix, err := s.CheckURL(objectURL)
240+
if err != nil {
241+
return err
242+
}
243+
switch prefix {
244+
case S3Prefix:
245+
return s.MinioClient.deleteFile(objectURL)
246+
case LocalPrefix:
247+
return os.Remove(objectURL)
248+
default:
249+
return fmt.Errorf("invalid url(%s)", objectURL)
250+
}
251+
}

0 commit comments

Comments
 (0)