16
16
import re
17
17
from enum import Enum
18
18
from typing import Optional
19
+ from packaging .version import Version
19
20
20
21
from sagemaker import utils
21
- from sagemaker .image_uris import _validate_version_and_set_if_needed , _version_for_config , \
22
- _config_for_framework_and_scope , _validate_py_version_and_set_if_needed , _registry_from_region , ECR_URI_TEMPLATE , \
23
- _get_latest_versions , _validate_instance_deprecation , _get_image_tag , _validate_arg
24
- from packaging .version import Version
22
+ from sagemaker .image_uris import (
23
+ _validate_version_and_set_if_needed ,
24
+ _version_for_config ,
25
+ _config_for_framework_and_scope ,
26
+ _validate_py_version_and_set_if_needed ,
27
+ _registry_from_region ,
28
+ ECR_URI_TEMPLATE ,
29
+ _get_latest_versions ,
30
+ _validate_instance_deprecation ,
31
+ _get_image_tag ,
32
+ _validate_arg ,
33
+ )
25
34
26
35
DEFAULT_TOLERATE_MODEL = False
27
36
28
37
29
38
class Framework (Enum ):
39
+ """Framework enum class."""
40
+
30
41
HUGGING_FACE = "huggingface"
31
42
HUGGING_FACE_NEURON = "huggingface-neuron"
32
43
HUGGING_FACE_NEURON_X = "huggingface-neuronx"
@@ -46,12 +57,16 @@ class Framework(Enum):
46
57
47
58
48
59
class ImageScope(Enum):
    """Enumeration of the image scopes an image specification can target.

    Each member's value is the plain string name of the scope.
    """

    TRAINING = "training"
    INFERENCE = "inference"
    INFERENCE_GRAVITON = "inference-graviton"
52
65
53
66
54
67
class Processor (Enum ):
68
+ """Processor enum class."""
69
+
55
70
INF = "inf"
56
71
NEURON = "neuron"
57
72
GPU = "gpu"
@@ -60,22 +75,53 @@ class Processor(Enum):
60
75
61
76
62
77
class ImageSpec :
63
- """ImageSpec class to get image URI for a specific framework version."""
64
-
65
- def __init__ (self ,
66
- framework : Framework ,
67
- processor : Optional [Processor ] = Processor .CPU ,
68
- region : Optional [str ] = "us-west-2" ,
69
- version = None ,
70
- py_version = None ,
71
- instance_type = None ,
72
- accelerator_type = None ,
73
- image_scope : ImageScope = ImageScope .TRAINING ,
74
- container_version = None ,
75
- distribution = None ,
76
- base_framework_version = None ,
77
- sdk_version = None ,
78
- inference_tool = None ):
78
+ """ImageSpec class to get image URI for a specific framework version.
79
+
80
+ Attributes:
81
+ framework (Framework): The name of the framework or algorithm.
82
+ processor (Processor): The name of the processor (CPU, GPU, etc.).
83
+ region (str): The AWS region.
84
+ version (str): The framework or algorithm version. This is required if there is
85
+ more than one supported version for the given framework or algorithm.
86
+ py_version (str): The Python version. This is required if there is
87
+ more than one supported Python version for the given framework version.
88
+ instance_type (str): The SageMaker instance type. For supported types, see
89
+ https://aws.amazon.com/sagemaker/pricing. This is required if
90
+ there are different images for different processor types.
91
+ accelerator_type (str): Elastic Inference accelerator type. For more, see
92
+ https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
93
+ image_scope (str): The image type, i.e. what it is used for.
94
+ Valid values: "training", "inference", "inference_graviton", "eia".
95
+ If ``accelerator_type`` is set, ``image_scope`` is ignored.
96
+ container_version (str): The version of the Docker image.
97
+ Ideally, the value of this parameter should be created inside the framework.
98
+ For custom use, see the list of supported container versions:
99
+ https://github.com/aws/deep-learning-containers/blob/master/available_images.md
100
+ (default: None).
101
+ distribution (dict): A dictionary with information on how to run distributed training.
102
+ sdk_version (str): the version of python-sdk that will be used in the image retrieval.
103
+ (default: None).
104
+ inference_tool (str): the tool that will be used to aid in the inference.
105
+ Valid values: "neuron", "neuronx", None
106
+ (default: None).
107
+ """
108
+
109
+ def __init__ (
110
+ self ,
111
+ framework : Framework ,
112
+ processor : Optional [Processor ] = Processor .CPU ,
113
+ region : Optional [str ] = "us-west-2" ,
114
+ version = None ,
115
+ py_version = None ,
116
+ instance_type = None ,
117
+ accelerator_type = None ,
118
+ image_scope : ImageScope = ImageScope .TRAINING ,
119
+ container_version = None ,
120
+ distribution = None ,
121
+ base_framework_version = None ,
122
+ sdk_version = None ,
123
+ inference_tool = None ,
124
+ ):
79
125
self .framework = framework
80
126
self .processor = processor
81
127
self .version = version
@@ -91,44 +137,14 @@ def __init__(self,
91
137
self .inference_tool = inference_tool
92
138
93
139
def update_image_spec(self, **kwargs):
    """Update existing attributes of this object from keyword arguments.

    Every keyword whose name is already an attribute of the object is
    reassigned to the supplied value. Keywords that do not match an existing
    attribute are ignored, so this method never creates new attributes.

    Args:
        **kwargs: Attribute names mapped to the new values to assign.
    """
    for key, value in kwargs.items():
        # Only overwrite what already exists; unknown names are skipped silently.
        # NOTE(review): hasattr() is also true for methods, so a keyword named
        # after a method would replace it on the instance - confirm callers
        # only pass data attribute names.
        if hasattr(self, key):
            setattr(self, key, value)
97
144
98
145
def retrieve (self ) -> str :
99
146
"""Retrieves the ECR URI for the Docker image matching the given arguments.
100
147
101
- Ideally this function should not be called directly, rather it should be called from the
102
- fit() function inside framework estimator.
103
-
104
- Args:
105
- framework (Framework): The name of the framework or algorithm.
106
- processor (Processor): The name of the processor (CPU, GPU, etc.).
107
- region (str): The AWS region.
108
- version (str): The framework or algorithm version. This is required if there is
109
- more than one supported version for the given framework or algorithm.
110
- py_version (str): The Python version. This is required if there is
111
- more than one supported Python version for the given framework version.
112
- instance_type (str): The SageMaker instance type. For supported types, see
113
- https://aws.amazon.com/sagemaker/pricing. This is required if
114
- there are different images for different processor types.
115
- accelerator_type (str): Elastic Inference accelerator type. For more, see
116
- https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
117
- image_scope (str): The image type, i.e. what it is used for.
118
- Valid values: "training", "inference", "inference_graviton", "eia".
119
- If ``accelerator_type`` is set, ``image_scope`` is ignored.
120
- container_version (str): the version of docker image.
121
- Ideally the value of parameter should be created inside the framework.
122
- For custom use, see the list of supported container versions:
123
- https://github.com/aws/deep-learning-containers/blob/master/available_images.md
124
- (default: None).
125
- distribution (dict): A dictionary with information on how to run distributed training
126
- sdk_version (str): the version of python-sdk that will be used in the image retrieval.
127
- (default: None).
128
- inference_tool (str): the tool that will be used to aid in the inference.
129
- Valid values: "neuron, neuronx, None"
130
- (default: None).
131
-
132
148
Returns:
133
149
str: The ECR URI for the corresponding SageMaker Docker image.
134
150
@@ -140,13 +156,14 @@ def retrieve(self) -> str:
140
156
known security vulnerabilities.
141
157
DeprecatedJumpStartModelError: If the version of the model is deprecated.
142
158
"""
143
- config = _config_for_framework_and_scope (self .framework .value ,
144
- self .image_scope .value ,
145
- self .accelerator_type )
146
-
159
+ config = _config_for_framework_and_scope (
160
+ self .framework .value , self .image_scope .value , self .accelerator_type
161
+ )
147
162
original_version = self .version
148
163
try :
149
- version = _validate_version_and_set_if_needed (self .version , config , self .framework .value )
164
+ version = _validate_version_and_set_if_needed (
165
+ self .version , config , self .framework .value
166
+ )
150
167
except ValueError :
151
168
version = None
152
169
if not version :
@@ -159,12 +176,14 @@ def retrieve(self) -> str:
159
176
full_base_framework_version = version_config ["version_aliases" ].get (
160
177
self .base_framework_version , self .base_framework_version
161
178
)
162
- _validate_arg (full_base_framework_version , list (version_config .keys ()), "base framework" )
179
+ _validate_arg (
180
+ full_base_framework_version , list (version_config .keys ()), "base framework"
181
+ )
163
182
version_config = version_config .get (full_base_framework_version )
164
183
165
- self .py_version = _validate_py_version_and_set_if_needed (self . py_version ,
166
- version_config ,
167
- self . framework . value )
184
+ self .py_version = _validate_py_version_and_set_if_needed (
185
+ self . py_version , version_config , self . framework . value
186
+ )
168
187
version_config = version_config .get (self .py_version ) or version_config
169
188
170
189
registry = _registry_from_region (self .region , version_config ["registries" ])
@@ -206,16 +225,18 @@ def retrieve(self) -> str:
206
225
if config .get ("version_aliases" ).get (original_version ):
207
226
_version = config .get ("version_aliases" )[original_version ]
208
227
if (
209
- config .get ("versions" , {})
210
- .get (_version , {})
211
- .get ("version_aliases" , {})
212
- .get (self .base_framework_version , {})
228
+ config .get ("versions" , {})
229
+ .get (_version , {})
230
+ .get ("version_aliases" , {})
231
+ .get (self .base_framework_version , {})
213
232
):
214
233
_base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
215
234
self .base_framework_version
216
235
]
217
236
pt_or_tf_version = (
218
- re .compile ("^(pytorch|tensorflow)(.*)$" ).match (_base_framework_version ).group (2 )
237
+ re .compile ("^(pytorch|tensorflow)(.*)$" )
238
+ .match (_base_framework_version )
239
+ .group (2 )
219
240
)
220
241
221
242
tag_prefix = f"{ pt_or_tf_version } -transformers{ _version } "
@@ -224,29 +245,28 @@ def retrieve(self) -> str:
224
245
225
246
if repo == f"{ self .framework .value } -inference-graviton" :
226
247
self .container_version = f"{ self .container_version } -sagemaker"
227
- _validate_instance_deprecation (self .framework ,
228
- self .instance_type ,
229
- version )
248
+ _validate_instance_deprecation (self .framework , self .instance_type , version )
230
249
231
250
tag = _get_image_tag (
232
251
self .container_version ,
233
252
self .distribution ,
234
253
self .image_scope .value ,
235
- self .framework ,
254
+ self .framework . value ,
236
255
self .inference_tool ,
237
256
self .instance_type ,
238
257
self .processor .value ,
239
258
self .py_version ,
240
259
tag_prefix ,
241
- version )
260
+ version ,
261
+ )
242
262
243
263
if tag :
244
264
repo += ":{}" .format (tag )
245
265
246
266
return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
247
267
248
- def _fetch_latest_version_from_config (self ,
249
- framework_config : dict ) -> str :
268
+ def _fetch_latest_version_from_config (self , framework_config : dict ) -> str :
269
+ """Fetches the latest version from the framework config."""
250
270
if self .image_scope .value in framework_config :
251
271
if image_scope_config := framework_config [self .image_scope .value ]:
252
272
if version_aliases := image_scope_config ["version_aliases" ]:
0 commit comments