17
17
import re
18
18
from typing import Optional , Union , Dict
19
19
20
- from sagemaker .deprecations import renamed_kwargs
21
20
from sagemaker .estimator import Framework , EstimatorBase
22
21
from sagemaker .fw_utils import (
23
22
framework_name_from_image ,
24
- warn_if_parameter_server_with_multi_gpu ,
25
- validate_smdistributed ,
23
+ validate_distribution ,
26
24
)
27
25
from sagemaker .huggingface .model import HuggingFaceModel
28
26
from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
@@ -37,6 +35,9 @@ class HuggingFace(Framework):
37
35
"""Handle training of custom HuggingFace code."""
38
36
39
37
# Framework identifier; used by SageMaker tooling for this estimator
# (e.g. image URI resolution / argument validation for Hugging Face containers).
_framework_name = "huggingface"

# Config keys emitted by _huggingface_distribution_configuration (below) when
# the corresponding launcher is requested via the `distribution` argument.
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
# Key carrying the training instance type alongside a launcher flag.
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
40
41
41
42
def __init__ (
42
43
self ,
@@ -142,6 +143,36 @@ def __init__(
142
143
}
143
144
}
144
145
146
+ **To enable PyTorch DDP:**
147
+
148
+ .. code:: python
149
+
150
+ {
151
+ "pytorchddp": {
152
+ "enabled": True
153
+ }
154
+ }
155
+
156
+ To learn more, see `Distributed PyTorch Training
157
+ <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.
158
+
159
+ **To enable Torch Distributed:**
160
+
161
+ This is available for general distributed training on
162
+ GPU instances from PyTorch v1.13.1 and later.
163
+
164
+ .. code:: python
165
+
166
+ {
167
+ "torch_distributed": {
168
+ "enabled": True
169
+ }
170
+ }
171
+
172
+ This option also supports distributed training on Trn1.
173
+ To learn more, see `Distributed PyTorch Training on Trainium
174
+ <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.
175
+
145
176
To enable distributed training with
146
177
`SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
147
178
for Hugging Face Transformers with PyTorch:
@@ -182,29 +213,6 @@ def __init__(
182
213
183
214
self ._validate_args (image_uri = image_uri )
184
215
185
- instance_type = renamed_kwargs (
186
- "train_instance_type" , "instance_type" , kwargs .get ("instance_type" ), kwargs
187
- )
188
-
189
- base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
190
- base_framework_version = (
191
- tensorflow_version if tensorflow_version is not None else pytorch_version
192
- )
193
-
194
- if distribution is not None :
195
- validate_smdistributed (
196
- instance_type = instance_type ,
197
- framework_name = base_framework_name ,
198
- framework_version = base_framework_version ,
199
- py_version = self .py_version ,
200
- distribution = distribution ,
201
- image_uri = image_uri ,
202
- )
203
-
204
- warn_if_parameter_server_with_multi_gpu (
205
- training_instance_type = instance_type , distribution = distribution
206
- )
207
-
208
216
if "enable_sagemaker_metrics" not in kwargs :
209
217
kwargs ["enable_sagemaker_metrics" ] = True
210
218
@@ -214,6 +222,25 @@ def __init__(
214
222
entry_point , source_dir , hyperparameters , image_uri = image_uri , ** kwargs
215
223
)
216
224
225
+ if "entry_point" not in kwargs :
226
+ kwargs ["entry_point" ] = entry_point
227
+
228
+ self .base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
229
+ self .base_framework_version = (
230
+ tensorflow_version if tensorflow_version is not None else pytorch_version
231
+ )
232
+
233
+ if distribution is not None :
234
+ distribution = validate_distribution (
235
+ distribution ,
236
+ self .instance_groups ,
237
+ self .base_framework_name ,
238
+ self .base_framework_version ,
239
+ py_version ,
240
+ image_uri ,
241
+ kwargs ,
242
+ )
243
+
217
244
self .distribution = distribution or {}
218
245
219
246
if compiler_config is not None :
@@ -267,14 +294,44 @@ def _validate_args(self, image_uri):
267
294
"transformers_version, tensorflow_version and pytorch_version."
268
295
)
269
296
297
+ def _huggingface_distribution_configuration (self , distribution ):
298
+ """Returns a dict of distribution config for Hugging Face training
299
+
300
+ Args:
301
+ distribution (dict): A dictionary with information on how to run distributed training.
302
+ Returns:
303
+ dict containing Pytorch DDP config
304
+ """
305
+ distribution_config = {}
306
+ pytorch_ddp_enabled = False
307
+ torch_distributed_enabled = False
308
+
309
+ if "pytorchddp" in distribution :
310
+ pytorch_ddp_enabled = distribution .get ("pytorchddp" ).get ("enabled" , False )
311
+ elif "torch_distributed" in distribution :
312
+ torch_distributed_enabled = distribution .get ("torch_distributed" ).get ("enabled" , False )
313
+
314
+ if pytorch_ddp_enabled :
315
+ distribution_config [self .LAUNCH_PYTORCH_DDP_ENV_NAME ] = pytorch_ddp_enabled
316
+ if self .instance_type is not None :
317
+ distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
318
+ elif torch_distributed_enabled :
319
+ distribution_config [self .LAUNCH_TORCH_DISTRIBUTED_ENV_NAME ] = torch_distributed_enabled
320
+ if self .instance_type is not None :
321
+ distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
322
+ else :
323
+ distribution_config = self ._distribution_configuration (distribution = distribution )
324
+
325
+ return distribution_config
326
+
270
327
def hyperparameters (self ):
271
328
"""Return hyperparameters used by your custom PyTorch code during model training."""
272
329
hyperparameters = super (HuggingFace , self ).hyperparameters ()
273
- distributed_training_hyperparameters = self ._distribution_configuration (
330
+ additional_hyperparameters = self ._huggingface_distribution_configuration (
274
331
distribution = self .distribution
275
332
)
276
333
hyperparameters .update (
277
- EstimatorBase ._json_encode_hyperparameters (distributed_training_hyperparameters )
334
+ EstimatorBase ._json_encode_hyperparameters (additional_hyperparameters )
278
335
)
279
336
280
337
if self .compiler_config :
0 commit comments