Skip to content

Commit a28c603

Browse files
schinmayee authored and pintaoz-aws committed
Feature: Move image uris and git repos for training recipes to json (#1547)
1 parent e195b0b commit a28c603

File tree

3 files changed

+55
-16
lines changed

3 files changed

+55
-16
lines changed

src/sagemaker/pytorch/estimator.py

+44-14
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import logging
1718
import math
1819
import os
@@ -35,9 +36,11 @@
3536
profiler_config_deprecation_warning,
3637
)
3738
from sagemaker.git_utils import _run_clone_command
39+
from sagemaker.image_uris import retrieve
3840
from sagemaker.pytorch import defaults
3941
from sagemaker.pytorch.model import PyTorchModel
4042
from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
43+
from sagemaker.session import Session
4144
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
4245
from sagemaker.workflow.entities import PipelineVariable
4346

@@ -67,15 +70,6 @@ class PyTorch(Framework):
6770

6871
# [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
6972
# to retrieve the image uri below before GA.
70-
SM_ADAPTER_REPO = "git@github.com:aws/private-sagemaker-training-adapter-for-nemo-staging.git"
71-
SM_LAUNCHER_REPO = "git@github.com:aws/private-sagemaker-training-launcher-staging.git"
72-
SM_TRAINING_RECIPE_GPU_IMG = (
73-
"855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
74-
)
75-
SM_NEURONX_DIST_REPO = "https://github.com/aws-neuron/neuronx-distributed-training.git"
76-
SM_NEURONX_DIST_IMG = (
77-
"855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:neuron_sept26_v1"
78-
)
7973

8074
def __init__(
8175
self,
@@ -561,6 +555,16 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
561555
dict containing arg values for estimator initialization and setup.
562556
563557
"""
558+
if kwargs.get("sagemaker_session") is not None:
559+
region_name = kwargs.get("sagemaker_session").boto_region_name
560+
else:
561+
region_name = Session().boto_region_name
562+
training_recipes_cfg_filename = os.path.join(
563+
os.path.dirname(__file__), "training_recipes.json"
564+
)
565+
with open(training_recipes_cfg_filename) as training_recipes_cfg_file:
566+
training_recipes_cfg = json.load(training_recipes_cfg_file)
567+
564568
if recipe_overrides is None:
565569
recipe_overrides = dict()
566570
cls.recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
@@ -580,7 +584,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
580584
f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
581585
)
582586
else:
583-
launcher_repo = os.environ.get("training_launcher_git", None) or cls.SM_LAUNCHER_REPO
587+
launcher_repo = os.environ.get(
588+
"training_launcher_git", None
589+
) or training_recipes_cfg.get("launcher_repo")
584590
_run_clone_command(launcher_repo, cls.recipe_launcher_dir.name)
585591
recipe = os.path.join(
586592
cls.recipe_launcher_dir.name,
@@ -629,7 +635,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
629635
# [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
630636
# to retrieve the image uri below before we go GA.
631637
if device_type == "gpu":
632-
adapter_repo = os.environ.get("training_adapter_git", None) or cls.SM_ADAPTER_REPO
638+
adapter_repo = os.environ.get("training_adapter_git", None) or training_recipes_cfg.get(
639+
"adapter_repo"
640+
)
633641
_run_clone_command(adapter_repo, cls.recipe_train_dir.name)
634642

635643
model_type_to_entry = {
@@ -650,7 +658,17 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
650658
cls.recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
651659
)
652660
args["entry_point"] = model_type_to_entry[model_type][1]
653-
args["default_image_uri"] = cls.SM_TRAINING_RECIPE_GPU_IMG
661+
gpu_image_cfg = training_recipes_cfg.get("gpu_image")
662+
if isinstance(gpu_image_cfg, str):
663+
args["default_image_uri"] = gpu_image_cfg
664+
else:
665+
args["default_image_uri"] = retrieve(
666+
gpu_image_cfg.get("framework"),
667+
region=region_name,
668+
version=gpu_image_cfg.get("version"),
669+
image_scope="training",
670+
**gpu_image_cfg.get("additional_args"),
671+
)
654672
smp_options = {
655673
"enabled": True,
656674
"parameters": {
@@ -662,10 +680,22 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
662680
"torch_distributed": {"enabled": True},
663681
}
664682
elif device_type == "trainium":
665-
_run_clone_command(cls.SM_NEURONX_DIST_REPO, cls.recipe_train_dir.name)
683+
_run_clone_command(
684+
training_recipes_cfg.get("neuron_dist_repo"), cls.recipe_train_dir.name
685+
)
666686
args["source_dir"] = os.path.join(cls.recipe_train_dir.name, "examples")
667687
args["entry_point"] = "training_orchestrator.py"
668-
args["default_image_uri"] = cls.SM_NEURONX_DIST_IMG
688+
neuron_image_cfg = training_recipes_cfg.get("neuron_image")
689+
if isinstance(neuron_image_cfg, str):
690+
args["default_image_uri"] = neuron_image_cfg
691+
else:
692+
args["default_image_uri"] = retrieve(
693+
neuron_image_cfg.get("framework"),
694+
region=region_name,
695+
version=neuron_image_cfg.get("version"),
696+
image_scope="training",
697+
**neuron_image_cfg.get("additional_args"),
698+
)
669699
args["distribution"] = {
670700
"torch_distributed": {"enabled": True},
671701
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"adapter_repo": "git@github.com:aws/private-sagemaker-training-adapter-for-nemo-staging.git",
3+
"launcher_repo": "git@github.com:aws/private-sagemaker-training-launcher-staging.git",
4+
"neuron_dist_repo": "https://github.com/aws-neuron/neuronx-distributed-training.git",
5+
"gpu_image" : {
6+
"framework": "pytorch-smp",
7+
"version": "2.3.1",
8+
"additional_args": {}
9+
},
10+
"neuron_image": "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:neuron_sept26_v1"
11+
}

tests/unit/test_pytorch.py

-2
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,6 @@ def test_training_recipe_for_gpu(sagemaker_session, recipe, model):
902902

903903
assert pytorch.source_dir == os.path.join(pytorch.recipe_train_dir.name, "examples", model)
904904
assert pytorch.entry_point == f"{model}_pretrain.py"
905-
assert pytorch.image_uri == pytorch.SM_TRAINING_RECIPE_GPU_IMG
906905
expected_distribution = {
907906
"torch_distributed": {
908907
"enabled": True,
@@ -982,7 +981,6 @@ def test_training_recipe_for_trainium(sagemaker_session):
982981

983982
assert pytorch.source_dir == os.path.join(pytorch.recipe_train_dir.name, "examples")
984983
assert pytorch.entry_point == "training_orchestrator.py"
985-
assert pytorch.image_uri == pytorch.SM_NEURONX_DIST_IMG
986984
expected_distribution = {
987985
"torch_distributed": {
988986
"enabled": True,

0 commit comments

Comments
 (0)