@@ -3,13 +3,18 @@
 
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 
+from contextlib import nullcontext
 from typing import List, Optional
 
 import torch
+import transformer_engine.pytorch as te
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim.lr_scheduler import _LRScheduler
+from transformer_engine.common.recipe import DelayedScaling, Format
 
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
 from internlm.core.gradient_handler import BaseGradientHandler
 from internlm.solver.optimizer import BaseOptimizer
 from internlm.solver.schedulers import Beta2Scheduler
|
@@ -78,6 +83,28 @@ def __init__(
         # build gradient handler
         self._gradient_handlers = gradient_handlers if gradient_handlers else []
 
+        # FP8 GEMM: build a TransformerEngine DelayedScaling recipe from the optional "fp8" config section
+        fp8_cfg = gpc.config.get("fp8", None)
+        self.use_fp8 = fp8_cfg is not None
+        self.fp8_recipe = None
+        self.fp8_group = None
+        if self.use_fp8:
+            self.fp8_group = gpc.get_group(ParallelMode.GLOBAL)
+            if fp8_cfg.format == "e4m3":
+                fp8_format = Format.E4M3
+            elif fp8_cfg.format == "hybrid":
+                fp8_format = Format.HYBRID
+            else:
+                raise ValueError(f"Unsupported fp8 format {fp8_cfg.format!r}: the DelayedScaling recipe only supports 'e4m3' and 'hybrid'.")
+            self.fp8_recipe = DelayedScaling(
+                margin=fp8_cfg.margin,
+                interval=fp8_cfg.interval,
+                fp8_format=fp8_format,
+                amax_history_len=fp8_cfg.amax_history_len,
+                amax_compute_algo=fp8_cfg.amax_compute_algo,
+                override_linear_precision=(False, False, not fp8_cfg.fp8_wgrad),  # (fprop, dgrad, wgrad)
+            )
+
     @property
     def model(self):
         """Returns the model attached to the engine."""
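
For reference, the constructor above reads six keys from the optional `fp8` section of `gpc.config`. A minimal sketch of what that section could look like in an InternLM-style Python config file; the key names come from the reads above, but the values shown are purely illustrative and not taken from this commit:

```python
# Hypothetical `fp8` config section; its mere presence flips `use_fp8` to True.
fp8 = dict(
    format="hybrid",          # "e4m3", or "hybrid" (E4M3 forward, E5M2 backward)
    margin=0,                 # margin for the scaling-factor computation
    interval=1,               # steps between scaling-factor recomputations
    amax_history_len=1024,    # length of the amax history window for delayed scaling
    amax_compute_algo="max",  # "max" or "most_recent" over the history window
    fp8_wgrad=True,           # False forces the wgrad GEMM to run in higher precision
)
```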
@@ -166,7 +193,14 @@ def __call__(self, *args, **kwargs):
         Returns:
             torch.Tensor: The output of the model.
         """
-        return self.model(*args, **kwargs)
+        # Enter fp8_autocast only when FP8 is configured; otherwise fall back to a no-op context.
+        fp8_context = (
+            te.fp8_autocast(enabled=True, fp8_recipe=self.fp8_recipe, fp8_group=self.fp8_group)
+            if self.use_fp8
+            else nullcontext()
+        )
+        with fp8_context:
+            return self.model(*args, **kwargs)
 
     def load_batch(self, data_iter, to_gpu=True):
         """
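
With this change, every forward pass through the engine runs under `te.fp8_autocast` whenever FP8 is configured, while the backward pass stays outside the context (TransformerEngine expects the autocast to wrap only the forward pass). A self-contained sketch of the same pattern, closely following the TransformerEngine quickstart; the `model` here is a placeholder `te.Linear`, since `fp8_autocast` only switches TE modules to FP8 GEMMs:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Placeholder model: fp8_autocast only affects TransformerEngine modules
# (te.Linear, te.LayerNormMLP, ...), so the model must be built from them.
model = te.Linear(768, 768, bias=True)
inp = torch.randn(32, 768, device="cuda")

fp8_recipe = DelayedScaling(
    margin=0, fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max"
)

# Wrap only the forward pass, exactly as the engine's __call__ does above.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

# The backward pass stays outside the autocast.
out.sum().backward()
```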
|
|