Skip to content

Commit ebd48c9

Browse files
authored
feature: [huggingface] Add torch.distributed support for Trainium and torchrun (#3759)
* add torch distributed
* change order
* entrypoint
* make tox happy
* add HF test
* change test name
* Update test_huggingface_torch_distributed.py
* fix tox and black
* removed mark
* fix flake8
* fix test
* fix instance type
1 parent c217e1c commit ebd48c9

File tree

2 files changed

+135
-28
lines changed

2 files changed

+135
-28
lines changed

src/sagemaker/huggingface/estimator.py

+85-28
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@
1717
import re
1818
from typing import Optional, Union, Dict
1919

20-
from sagemaker.deprecations import renamed_kwargs
2120
from sagemaker.estimator import Framework, EstimatorBase
2221
from sagemaker.fw_utils import (
2322
framework_name_from_image,
24-
warn_if_parameter_server_with_multi_gpu,
25-
validate_smdistributed,
23+
validate_distribution,
2624
)
2725
from sagemaker.huggingface.model import HuggingFaceModel
2826
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -37,6 +35,9 @@ class HuggingFace(Framework):
3735
"""Handle training of custom HuggingFace code."""
3836

3937
_framework_name = "huggingface"
38+
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
39+
LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
40+
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
4041

4142
def __init__(
4243
self,
@@ -142,6 +143,36 @@ def __init__(
142143
}
143144
}
144145
146+
**To enable PyTorch DDP:**
147+
148+
.. code:: python
149+
150+
{
151+
"pytorchddp": {
152+
"enabled": True
153+
}
154+
}
155+
156+
To learn more, see `Distributed PyTorch Training
157+
<https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.
158+
159+
**To enable Torch Distributed:**
160+
161+
This is available for general distributed training on
162+
GPU instances from PyTorch v1.13.1 and later.
163+
164+
.. code:: python
165+
166+
{
167+
"torch_distributed": {
168+
"enabled": True
169+
}
170+
}
171+
172+
This option also supports distributed training on Trn1.
173+
To learn more, see `Distributed PyTorch Training on Trainium
174+
<https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.
175+
145176
To enable distributed training with
146177
`SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
147178
for Hugging Face Transformers with PyTorch:
@@ -182,29 +213,6 @@ def __init__(
182213

183214
self._validate_args(image_uri=image_uri)
184215

185-
instance_type = renamed_kwargs(
186-
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
187-
)
188-
189-
base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
190-
base_framework_version = (
191-
tensorflow_version if tensorflow_version is not None else pytorch_version
192-
)
193-
194-
if distribution is not None:
195-
validate_smdistributed(
196-
instance_type=instance_type,
197-
framework_name=base_framework_name,
198-
framework_version=base_framework_version,
199-
py_version=self.py_version,
200-
distribution=distribution,
201-
image_uri=image_uri,
202-
)
203-
204-
warn_if_parameter_server_with_multi_gpu(
205-
training_instance_type=instance_type, distribution=distribution
206-
)
207-
208216
if "enable_sagemaker_metrics" not in kwargs:
209217
kwargs["enable_sagemaker_metrics"] = True
210218

@@ -214,6 +222,25 @@ def __init__(
214222
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
215223
)
216224

225+
if "entry_point" not in kwargs:
226+
kwargs["entry_point"] = entry_point
227+
228+
self.base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
229+
self.base_framework_version = (
230+
tensorflow_version if tensorflow_version is not None else pytorch_version
231+
)
232+
233+
if distribution is not None:
234+
distribution = validate_distribution(
235+
distribution,
236+
self.instance_groups,
237+
self.base_framework_name,
238+
self.base_framework_version,
239+
py_version,
240+
image_uri,
241+
kwargs,
242+
)
243+
217244
self.distribution = distribution or {}
218245

219246
if compiler_config is not None:
@@ -267,14 +294,44 @@ def _validate_args(self, image_uri):
267294
"transformers_version, tensorflow_version and pytorch_version."
268295
)
269296

297+
def _huggingface_distribution_configuration(self, distribution):
298+
"""Returns a dict of distribution config for Hugging Face training
299+
300+
Args:
301+
distribution (dict): A dictionary with information on how to run distributed training.
302+
Returns:
303+
dict containing Pytorch DDP config
304+
"""
305+
distribution_config = {}
306+
pytorch_ddp_enabled = False
307+
torch_distributed_enabled = False
308+
309+
if "pytorchddp" in distribution:
310+
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
311+
elif "torch_distributed" in distribution:
312+
torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
313+
314+
if pytorch_ddp_enabled:
315+
distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled
316+
if self.instance_type is not None:
317+
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
318+
elif torch_distributed_enabled:
319+
distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
320+
if self.instance_type is not None:
321+
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
322+
else:
323+
distribution_config = self._distribution_configuration(distribution=distribution)
324+
325+
return distribution_config
326+
270327
def hyperparameters(self):
271328
"""Return hyperparameters used by your custom PyTorch code during model training."""
272329
hyperparameters = super(HuggingFace, self).hyperparameters()
273-
distributed_training_hyperparameters = self._distribution_configuration(
330+
additional_hyperparameters = self._huggingface_distribution_configuration(
274331
distribution=self.distribution
275332
)
276333
hyperparameters.update(
277-
EstimatorBase._json_encode_hyperparameters(distributed_training_hyperparameters)
334+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
278335
)
279336

280337
if self.compiler_config:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
from sagemaker.huggingface import HuggingFace
17+
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, timeout
18+
19+
20+
def test_huggingface_torch_distributed_g5_glue(
    sagemaker_session,
    huggingface_training_latest_version,
    huggingface_training_pytorch_latest_version,
    huggingface_pytorch_latest_training_py_version,
):
    """Run GLUE/WNLI fine-tuning with torch_distributed on one ml.g5.12xlarge instance."""
    hf_data_dir = os.path.join(DATA_DIR, "huggingface")
    glue_hyperparameters = {
        "model_name_or_path": "distilbert-base-cased",
        "task_name": "wnli",
        "do_train": True,
        "do_eval": True,
        "max_seq_length": 128,
        "fp16": True,
        "per_device_train_batch_size": 32,
        "output_dir": "/opt/ml/model",
    }
    with timeout.timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        hf_estimator = HuggingFace(
            entry_point=os.path.join(hf_data_dir, "run_glue.py"),
            role="SageMakerRole",
            transformers_version=huggingface_training_latest_version,
            pytorch_version=huggingface_training_pytorch_latest_version,
            py_version=huggingface_pytorch_latest_training_py_version,
            instance_count=1,
            instance_type="ml.g5.12xlarge",
            hyperparameters=glue_hyperparameters,
            distribution={"torch_distributed": {"enabled": True}},
            sagemaker_session=sagemaker_session,
            disable_profiler=True,
        )
        hf_estimator.fit()

0 commit comments

Comments
 (0)