Commit
fix fp16 precision issue (#9376)
* fix fp16 precision issue by disabling enable_autocast

Signed-off-by: dimapihtar <[email protected]>

* revert config

Signed-off-by: dimapihtar <[email protected]>

* add fp16 precision test

Signed-off-by: dimapihtar <[email protected]>

---------

Signed-off-by: dimapihtar <[email protected]>
dimapihtar authored Jun 5, 2024
1 parent f2dffaa commit d02bb32
Showing 2 changed files with 8 additions and 0 deletions.
.github/workflows/cicd-main.yml (4 additions, 0 deletions)
@@ -3322,8 +3322,10 @@ jobs:
 trainer.limit_val_batches=2 \
 trainer.accumulate_grad_batches=1 \
 trainer.max_steps=3 \
+trainer.precision=16 \
 trainer.gradient_clip_val=1.0 \
 exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \
+model.megatron_amp_O2=False \
 model.tensor_model_parallel_size=2 \
 model.optim.name=fused_adam \
 model.optim.lr=2e-4 \
@@ -3355,9 +3357,11 @@ jobs:
 trainer.limit_val_batches=2 \
 trainer.accumulate_grad_batches=1 \
 trainer.max_steps=6 \
+trainer.precision=16 \
 trainer.gradient_clip_val=1.0 \
 exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \
 exp_manager.resume_if_exists=True \
+model.megatron_amp_O2=False \
 model.tensor_model_parallel_size=2 \
 model.optim.name=fused_adam \
 model.optim.lr=2e-4 \
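The two new overrides pin this GPT pretraining smoke test to fp16 with O1-style mixed precision (megatron_amp_O2=False): master weights stay fp32 and the forward pass runs under autocast with dynamic loss scaling. As a rough illustration only (not the NeMo test itself), a minimal PyTorch sketch of that recipe, with a toy linear model standing in for GPT; it assumes a CUDA device is available:

import torch

model = torch.nn.Linear(8, 8).cuda()                       # toy stand-in for the GPT model
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)  # lr matches the test override
scaler = torch.cuda.amp.GradScaler()                       # dynamic loss scaling for fp16

x = torch.randn(4, 8, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(x).square().mean()   # matmuls run in fp16 under autocast
scaler.scale(loss).backward()         # scaled backward guards against fp16 gradient underflow
scaler.step(optimizer)                # unscales grads; skips the step on inf/nan
scaler.update()

The enable_autocast flag touched in the second file controls, as I understand it, whether Megatron-style modules wrap their forward in an autocast context of this kind.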
Second changed file (4 additions, 0 deletions)
@@ -1140,6 +1140,10 @@ def build_model_parallel_config(self) -> ModelParallelConfig:
"tp_comm_overlap": self.cfg.get('ub_tp_comm_overlap', False),
}

# Set enable_autocast to False when precision is fp16 and not using bias
if not megatron_amp_O2 and not self.cfg.get('bias', True):
config_mapping["enable_autocast"] = False

# instantitate ModelParallelConfig from this dict
mp_config_dict = {}

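To make the surrounding logic concrete, here is a minimal, self-contained sketch of how a mapping like config_mapping gets filtered into mp_config_dict and used to instantiate the config object. StubModelParallelConfig and the literal cfg values are stand-ins for illustration, not the real megatron.core class or a real model config:

import dataclasses

# Stand-in for megatron.core's ModelParallelConfig; illustrative only.
@dataclasses.dataclass
class StubModelParallelConfig:
    tp_comm_overlap: bool = False
    enable_autocast: bool = True

cfg = {"ub_tp_comm_overlap": False, "bias": False}  # hypothetical model config
megatron_amp_O2 = False

config_mapping = {"tp_comm_overlap": cfg.get("ub_tp_comm_overlap", False)}

# The guard added in this commit: with O1 (megatron_amp_O2=False) and
# bias disabled, autocast is turned off.
if not megatron_amp_O2 and not cfg.get("bias", True):
    config_mapping["enable_autocast"] = False

# Keep only keys the target dataclass declares, then instantiate it.
field_names = {f.name for f in dataclasses.fields(StubModelParallelConfig)}
mp_config_dict = {k: v for k, v in config_mapping.items() if k in field_names}
mp_config = StubModelParallelConfig(**mp_config_dict)
print(mp_config)  # tp_comm_overlap=False, enable_autocast=False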
