 """Placeholder docstring"""
 from __future__ import absolute_import
 
+import json
 import logging
 import math
 import os
@@ -35,9 +36,11 @@
     profiler_config_deprecation_warning,
 )
 from sagemaker.git_utils import _run_clone_command
+from sagemaker.image_uris import retrieve
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel
 from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
+from sagemaker.session import Session
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
 from sagemaker.workflow.entities import PipelineVariable
 
@@ -67,15 +70,6 @@ class PyTorch(Framework):
 
     # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
     # to retrieve the image uri below before GA.
-    SM_ADAPTER_REPO = "git@github.com:aws/private-sagemaker-training-adapter-for-nemo-staging.git"
-    SM_LAUNCHER_REPO = "git@github.com:aws/private-sagemaker-training-launcher-staging.git"
-    SM_TRAINING_RECIPE_GPU_IMG = (
-        "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
-    )
-    SM_NEURONX_DIST_REPO = "https://github.com/aws-neuron/neuronx-distributed-training.git"
-    SM_NEURONX_DIST_IMG = (
-        "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:neuron_sept26_v1"
-    )
 
     def __init__(
         self,
@@ -561,6 +555,16 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
             dict containing arg values for estimator initialization and setup.
 
         """
+        if kwargs.get("sagemaker_session") is not None:
+            region_name = kwargs.get("sagemaker_session").boto_region_name
+        else:
+            region_name = Session().boto_region_name
+        training_recipes_cfg_filename = os.path.join(
+            os.path.dirname(__file__), "training_recipes.json"
+        )
+        with open(training_recipes_cfg_filename) as training_recipes_cfg_file:
+            training_recipes_cfg = json.load(training_recipes_cfg_file)
+
         if recipe_overrides is None:
             recipe_overrides = dict()
         cls.recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
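For reference, the hunk above assumes a training_recipes.json file shipped alongside estimator.py; that file is not part of this excerpt. A minimal sketch of the shape the new code expects, written as the dict json.load() would return (all values are placeholders; only the key names come from the surrounding code), where an image entry may be either a plain ECR URI string or a dict forwarded to sagemaker.image_uris.retrieve():

# Hypothetical shape of training_recipes.json, shown as the dict returned by json.load().
# Every value is a placeholder; only the key names are taken from the code in this diff.
training_recipes_cfg = {
    "launcher_repo": "<git URL of the recipe launcher repository>",   # used when the recipe is not a local .yaml
    "adapter_repo": "<git URL of the training adapter repository>",   # cloned for GPU recipes
    "neuron_dist_repo": "<git URL of neuronx-distributed-training>",  # cloned for trainium recipes
    "gpu_image": {                        # dict form is passed to image_uris.retrieve()
        "framework": "<framework name>",
        "version": "<framework version>",
        "additional_args": {},            # extra keyword arguments for retrieve()
    },
    "neuron_image": "<ECR image URI used verbatim when the entry is a string>",
}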
@@ -580,7 +584,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
                         f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
                     )
         else:
-            launcher_repo = os.environ.get("training_launcher_git", None) or cls.SM_LAUNCHER_REPO
+            launcher_repo = os.environ.get(
+                "training_launcher_git", None
+            ) or training_recipes_cfg.get("launcher_repo")
             _run_clone_command(launcher_repo, cls.recipe_launcher_dir.name)
             recipe = os.path.join(
                 cls.recipe_launcher_dir.name,
@@ -629,7 +635,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
         # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
         # to retrieve the image uri below before we go GA.
         if device_type == "gpu":
-            adapter_repo = os.environ.get("training_adapter_git", None) or cls.SM_ADAPTER_REPO
+            adapter_repo = os.environ.get("training_adapter_git", None) or training_recipes_cfg.get(
+                "adapter_repo"
+            )
             _run_clone_command(adapter_repo, cls.recipe_train_dir.name)
 
             model_type_to_entry = {
@@ -650,7 +658,17 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
                 cls.recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
             )
             args["entry_point"] = model_type_to_entry[model_type][1]
-            args["default_image_uri"] = cls.SM_TRAINING_RECIPE_GPU_IMG
+            gpu_image_cfg = training_recipes_cfg.get("gpu_image")
+            if isinstance(gpu_image_cfg, str):
+                args["default_image_uri"] = gpu_image_cfg
+            else:
+                args["default_image_uri"] = retrieve(
+                    gpu_image_cfg.get("framework"),
+                    region=region_name,
+                    version=gpu_image_cfg.get("version"),
+                    image_scope="training",
+                    **gpu_image_cfg.get("additional_args"),
+                )
             smp_options = {
                 "enabled": True,
                 "parameters": {
@@ -662,10 +680,22 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
                 "torch_distributed": {"enabled": True},
             }
         elif device_type == "trainium":
-            _run_clone_command(cls.SM_NEURONX_DIST_REPO, cls.recipe_train_dir.name)
+            _run_clone_command(
+                training_recipes_cfg.get("neuron_dist_repo"), cls.recipe_train_dir.name
+            )
             args["source_dir"] = os.path.join(cls.recipe_train_dir.name, "examples")
             args["entry_point"] = "training_orchestrator.py"
-            args["default_image_uri"] = cls.SM_NEURONX_DIST_IMG
+            neuron_image_cfg = training_recipes_cfg.get("neuron_image")
+            if isinstance(neuron_image_cfg, str):
+                args["default_image_uri"] = neuron_image_cfg
+            else:
+                args["default_image_uri"] = retrieve(
+                    neuron_image_cfg.get("framework"),
+                    region=region_name,
+                    version=neuron_image_cfg.get("version"),
+                    image_scope="training",
+                    **neuron_image_cfg.get("additional_args"),
+                )
             args["distribution"] = {
                 "torch_distributed": {"enabled": True},
             }
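Taken together with the launcher and adapter changes earlier in this diff, every repository and default image is now resolved from training_recipes.json, with the training_launcher_git and training_adapter_git environment variables still taking precedence for the repositories. Below is a hedged usage sketch of the estimator path that reaches this code; the recipe name, role, and instance settings are illustrative placeholders, not values from this PR:

# Minimal sketch, assuming a recipe name that exists in the configured launcher repository.
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    training_recipe="<recipe path within the launcher repo>",  # placeholder
    recipe_overrides={"trainer": {"max_steps": 50}},           # placeholder override
    role="<IAM role ARN>",                                     # placeholder
    instance_count=1,
    instance_type="ml.p4d.24xlarge",  # GPU branch above; a trn1 type would take the trainium branch
)
# _setup_for_training_recipe() clones the launcher/adapter (or Neuron) repos and fills in
# source_dir, entry_point, default_image_uri, and distribution before training starts.
estimator.fit()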