Skip to content

Commit 4841929

Browse files
benieric authored and pintaoz-aws committed
Add example notebook (#1528)
* Add testing notebook
* format
* use smaller data
* remove large dataset
* update
* pylint
* flake8
* ignore docstyle in directories with test
* format
* format
1 parent 632c10b commit 4841929

File tree

16 files changed

+1176
-92
lines changed

16 files changed

+1176
-92
lines changed

.pydocstylerc

+1
Original file line number | Diff line number | Diff line change
@@ -2,3 +2,4 @@
22
inherit = false
33
ignore = D104,D107,D202,D203,D213,D214,D400,D401,D404,D406,D407,D411,D413,D414,D415,D417
44
match = (?!record_pb2).*\.py
5+
match-dir = (?!.*test).*

src/sagemaker/modules/image_spec.py

+92-72
Original file line number | Diff line number | Diff line change
@@ -16,17 +16,28 @@
1616
import re
1717
from enum import Enum
1818
from typing import Optional
19+
from packaging.version import Version
1920

2021
from sagemaker import utils
21-
from sagemaker.image_uris import _validate_version_and_set_if_needed, _version_for_config, \
22-
_config_for_framework_and_scope, _validate_py_version_and_set_if_needed, _registry_from_region, ECR_URI_TEMPLATE, \
23-
_get_latest_versions, _validate_instance_deprecation, _get_image_tag, _validate_arg
24-
from packaging.version import Version
22+
from sagemaker.image_uris import (
23+
_validate_version_and_set_if_needed,
24+
_version_for_config,
25+
_config_for_framework_and_scope,
26+
_validate_py_version_and_set_if_needed,
27+
_registry_from_region,
28+
ECR_URI_TEMPLATE,
29+
_get_latest_versions,
30+
_validate_instance_deprecation,
31+
_get_image_tag,
32+
_validate_arg,
33+
)
2534

2635
DEFAULT_TOLERATE_MODEL = False
2736

2837

2938
class Framework(Enum):
39+
"""Framework enum class."""
40+
3041
HUGGING_FACE = "huggingface"
3142
HUGGING_FACE_NEURON = "huggingface-neuron"
3243
HUGGING_FACE_NEURON_X = "huggingface-neuronx"
@@ -46,12 +57,16 @@ class Framework(Enum):
4657

4758

4859
class ImageScope(Enum):
60+
"""ImageScope enum class."""
61+
4962
TRAINING = "training"
5063
INFERENCE = "inference"
5164
INFERENCE_GRAVITON = "inference-graviton"
5265

5366

5467
class Processor(Enum):
68+
"""Processor enum class."""
69+
5570
INF = "inf"
5671
NEURON = "neuron"
5772
GPU = "gpu"
@@ -60,22 +75,53 @@ class Processor(Enum):
6075

6176

6277
class ImageSpec:
63-
"""ImageSpec class to get image URI for a specific framework version."""
64-
65-
def __init__(self,
66-
framework: Framework,
67-
processor: Optional[Processor] = Processor.CPU,
68-
region: Optional[str] = "us-west-2",
69-
version=None,
70-
py_version=None,
71-
instance_type=None,
72-
accelerator_type=None,
73-
image_scope: ImageScope = ImageScope.TRAINING,
74-
container_version=None,
75-
distribution=None,
76-
base_framework_version=None,
77-
sdk_version=None,
78-
inference_tool=None):
78+
"""ImageSpec class to get image URI for a specific framework version.
79+
80+
Attributes:
81+
framework (Framework): The name of the framework or algorithm.
82+
processor (Processor): The name of the processor (CPU, GPU, etc.).
83+
region (str): The AWS region.
84+
version (str): The framework or algorithm version. This is required if there is
85+
more than one supported version for the given framework or algorithm.
86+
py_version (str): The Python version. This is required if there is
87+
more than one supported Python version for the given framework version.
88+
instance_type (str): The SageMaker instance type. For supported types, see
89+
https://aws.amazon.com/sagemaker/pricing. This is required if
90+
there are different images for different processor types.
91+
accelerator_type (str): Elastic Inference accelerator type. For more, see
92+
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
93+
image_scope (str): The image type, i.e. what it is used for.
94+
Valid values: "training", "inference", "inference_graviton", "eia".
95+
If ``accelerator_type`` is set, ``image_scope`` is ignored.
96+
container_version (str): the version of docker image.
97+
Ideally the value of parameter should be created inside the framework.
98+
For custom use, see the list of supported container versions:
99+
https://github.com/aws/deep-learning-containers/blob/master/available_images.md
100+
(default: None).
101+
distribution (dict): A dictionary with information on how to run distributed training
102+
sdk_version (str): the version of python-sdk that will be used in the image retrieval.
103+
(default: None).
104+
inference_tool (str): the tool that will be used to aid in the inference.
105+
Valid values: "neuron, neuronx, None"
106+
(default: None).
107+
"""
108+
109+
def __init__(
110+
self,
111+
framework: Framework,
112+
processor: Optional[Processor] = Processor.CPU,
113+
region: Optional[str] = "us-west-2",
114+
version=None,
115+
py_version=None,
116+
instance_type=None,
117+
accelerator_type=None,
118+
image_scope: ImageScope = ImageScope.TRAINING,
119+
container_version=None,
120+
distribution=None,
121+
base_framework_version=None,
122+
sdk_version=None,
123+
inference_tool=None,
124+
):
79125
self.framework = framework
80126
self.processor = processor
81127
self.version = version
@@ -91,44 +137,14 @@ def __init__(self,
91137
self.inference_tool = inference_tool
92138

93139
def update_image_spec(self, **kwargs):
140+
"""Update the ImageSpec object with the given arguments."""
94141
for key, value in kwargs.items():
95142
if hasattr(self, key):
96143
setattr(self, key, value)
97144

98145
def retrieve(self) -> str:
99146
"""Retrieves the ECR URI for the Docker image matching the given arguments.
100147
101-
Ideally this function should not be called directly, rather it should be called from the
102-
fit() function inside framework estimator.
103-
104-
Args:
105-
framework (Framework): The name of the framework or algorithm.
106-
processor (Processor): The name of the processor (CPU, GPU, etc.).
107-
region (str): The AWS region.
108-
version (str): The framework or algorithm version. This is required if there is
109-
more than one supported version for the given framework or algorithm.
110-
py_version (str): The Python version. This is required if there is
111-
more than one supported Python version for the given framework version.
112-
instance_type (str): The SageMaker instance type. For supported types, see
113-
https://aws.amazon.com/sagemaker/pricing. This is required if
114-
there are different images for different processor types.
115-
accelerator_type (str): Elastic Inference accelerator type. For more, see
116-
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
117-
image_scope (str): The image type, i.e. what it is used for.
118-
Valid values: "training", "inference", "inference_graviton", "eia".
119-
If ``accelerator_type`` is set, ``image_scope`` is ignored.
120-
container_version (str): the version of docker image.
121-
Ideally the value of parameter should be created inside the framework.
122-
For custom use, see the list of supported container versions:
123-
https://github.com/aws/deep-learning-containers/blob/master/available_images.md
124-
(default: None).
125-
distribution (dict): A dictionary with information on how to run distributed training
126-
sdk_version (str): the version of python-sdk that will be used in the image retrieval.
127-
(default: None).
128-
inference_tool (str): the tool that will be used to aid in the inference.
129-
Valid values: "neuron, neuronx, None"
130-
(default: None).
131-
132148
Returns:
133149
str: The ECR URI for the corresponding SageMaker Docker image.
134150
@@ -140,13 +156,14 @@ def retrieve(self) -> str:
140156
known security vulnerabilities.
141157
DeprecatedJumpStartModelError: If the version of the model is deprecated.
142158
"""
143-
config = _config_for_framework_and_scope(self.framework.value,
144-
self.image_scope.value,
145-
self.accelerator_type)
146-
159+
config = _config_for_framework_and_scope(
160+
self.framework.value, self.image_scope.value, self.accelerator_type
161+
)
147162
original_version = self.version
148163
try:
149-
version = _validate_version_and_set_if_needed(self.version, config, self.framework.value)
164+
version = _validate_version_and_set_if_needed(
165+
self.version, config, self.framework.value
166+
)
150167
except ValueError:
151168
version = None
152169
if not version:
@@ -159,12 +176,14 @@ def retrieve(self) -> str:
159176
full_base_framework_version = version_config["version_aliases"].get(
160177
self.base_framework_version, self.base_framework_version
161178
)
162-
_validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
179+
_validate_arg(
180+
full_base_framework_version, list(version_config.keys()), "base framework"
181+
)
163182
version_config = version_config.get(full_base_framework_version)
164183

165-
self.py_version = _validate_py_version_and_set_if_needed(self.py_version,
166-
version_config,
167-
self.framework.value)
184+
self.py_version = _validate_py_version_and_set_if_needed(
185+
self.py_version, version_config, self.framework.value
186+
)
168187
version_config = version_config.get(self.py_version) or version_config
169188

170189
registry = _registry_from_region(self.region, version_config["registries"])
@@ -206,16 +225,18 @@ def retrieve(self) -> str:
206225
if config.get("version_aliases").get(original_version):
207226
_version = config.get("version_aliases")[original_version]
208227
if (
209-
config.get("versions", {})
210-
.get(_version, {})
211-
.get("version_aliases", {})
212-
.get(self.base_framework_version, {})
228+
config.get("versions", {})
229+
.get(_version, {})
230+
.get("version_aliases", {})
231+
.get(self.base_framework_version, {})
213232
):
214233
_base_framework_version = config.get("versions")[_version]["version_aliases"][
215234
self.base_framework_version
216235
]
217236
pt_or_tf_version = (
218-
re.compile("^(pytorch|tensorflow)(.*)$").match(_base_framework_version).group(2)
237+
re.compile("^(pytorch|tensorflow)(.*)$")
238+
.match(_base_framework_version)
239+
.group(2)
219240
)
220241

221242
tag_prefix = f"{pt_or_tf_version}-transformers{_version}"
@@ -224,29 +245,28 @@ def retrieve(self) -> str:
224245

225246
if repo == f"{self.framework.value}-inference-graviton":
226247
self.container_version = f"{self.container_version}-sagemaker"
227-
_validate_instance_deprecation(self.framework,
228-
self.instance_type,
229-
version)
248+
_validate_instance_deprecation(self.framework, self.instance_type, version)
230249

231250
tag = _get_image_tag(
232251
self.container_version,
233252
self.distribution,
234253
self.image_scope.value,
235-
self.framework,
254+
self.framework.value,
236255
self.inference_tool,
237256
self.instance_type,
238257
self.processor.value,
239258
self.py_version,
240259
tag_prefix,
241-
version)
260+
version,
261+
)
242262

243263
if tag:
244264
repo += ":{}".format(tag)
245265

246266
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
247267

248-
def _fetch_latest_version_from_config(self,
249-
framework_config: dict) -> str:
268+
def _fetch_latest_version_from_config(self, framework_config: dict) -> str:
269+
"""Fetches the latest version from the framework config."""
250270
if self.image_scope.value in framework_config:
251271
if image_scope_config := framework_config[self.image_scope.value]:
252272
if version_aliases := image_scope_config["version_aliases"]:

0 commit comments

Comments
 (0)