
Commit e692eaa
committed Feb 20, 2025 · 1 parent a70aaf6

support fp8 gemm

File tree: 2 files changed, +61 -1 lines changed
internlm/core/engine.py (+32 -1)
@@ -3,13 +3,18 @@
 
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 
+from contextlib import nullcontext
 from typing import List, Optional
 
 import torch
+import transformer_engine.pytorch as te
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim.lr_scheduler import _LRScheduler
+from transformer_engine.common.recipe import DelayedScaling, Format
 
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
 from internlm.core.gradient_handler import BaseGradientHandler
 from internlm.solver.optimizer import BaseOptimizer
 from internlm.solver.schedulers import Beta2Scheduler
@@ -78,6 +83,28 @@ def __init__(
         # build gradient handler
         self._gradient_handlers = gradient_handlers if gradient_handlers else []
 
+        # FP8 GEMM
+        fp8_cfg = gpc.config.get("fp8", None)
+        self.use_fp8 = fp8_cfg is not None
+        self.fp8_recipe = None
+        self.fp8_group = None
+        if self.use_fp8:
+            self.fp8_group = gpc.get_group(ParallelMode.GLOBAL)
+            if fp8_cfg.format == "e4m3":
+                fp8_format = Format.E4M3
+            elif fp8_cfg.format == "hybrid":
+                fp8_format = Format.HYBRID
+            else:
+                raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
+            self.fp8_recipe = DelayedScaling(
+                margin=fp8_cfg.margin,
+                interval=fp8_cfg.interval,
+                fp8_format=fp8_format,
+                amax_history_len=fp8_cfg.amax_history_len,
+                amax_compute_algo=fp8_cfg.amax_compute_algo,
+                override_linear_precision=(False, False, not fp8_cfg.fp8_wgrad),
+            )
+
     @property
     def model(self):
         """Returns the model attached to the engine."""
@@ -166,7 +193,11 @@ def __call__(self, *args, **kwargs):
         Returns:
             torch.Tensor: The output of the model.
         """
-        return self.model(*args, **kwargs)
+        with te.fp8_autocast(
+            enabled=self.use_fp8, fp8_recipe=self.fp8_recipe, fp8_group=self.fp8_group
+        ) if self.use_fp8 else nullcontext():
+            output = self.model(*args, **kwargs)
+        return output
 
     def load_batch(self, data_iter, to_gpu=True):
         """

internlm/model/model_implementations/builder.py (+29 -0)
@@ -1,6 +1,7 @@
 from typing import List, Union
 
 import torch
+import transformer_engine.pytorch as te
 from torch import nn
 
 from internlm.core.context import ParallelMode
@@ -22,6 +23,31 @@
 logger = get_logger(__file__)
 
 
+def simple_swap(model, device):
+    for submodule_name, submodule in model.named_modules():
+        if isinstance(submodule, torch.nn.Linear):
+            path_in_state_dict = submodule_name.split(".")
+            current_module = model
+
+            # traverse to leaf module
+            leaf_path = path_in_state_dict[:-1]
+            leaf_name = path_in_state_dict[-1]
+            for child_name in leaf_path:
+                current_module = getattr(current_module, child_name)
+
+            # perform a swap
+            old_leaf = getattr(current_module, leaf_name)
+            new_leaf = te.Linear(old_leaf.in_features, old_leaf.out_features, old_leaf.bias is not None, device=device)
+            with torch.no_grad():
+                new_leaf.weight.copy_(old_leaf.weight)
+                assert torch.equal(new_leaf.weight, old_leaf.weight)
+                if old_leaf.bias is not None:
+                    new_leaf.bias.copy_(old_leaf.bias)
+                    assert torch.equal(new_leaf.bias, old_leaf.bias)
+
+            setattr(current_module, leaf_name, new_leaf)
+
+
 def create_model() -> Union[nn.Module, List[nn.Module]]:
     if is_using_hf():
         model = create_model_hf(hf=gpc.config.hf)
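
A hypothetical smoke test for `simple_swap` (not part of the commit; assumes a CUDA build of Transformer Engine). Nested containers exercise the dotted-path traversal, and since `te.Linear` does not subclass `torch.nn.Linear`, already-swapped modules are not matched again:

```python
import torch
import transformer_engine.pytorch as te

model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Sequential(torch.nn.Linear(128, 64)),
)
simple_swap(model, device="cuda")

# Every torch.nn.Linear should now be a te.Linear carrying identical weights.
assert not any(type(m) is torch.nn.Linear for m in model.modules())
assert sum(isinstance(m, te.Linear) for m in model.modules()) == 2
```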
@@ -130,4 +156,7 @@ def traverse(module):
     else:
         traverse(model)
 
+    if gpc.config.get("fp8", None) is not None:
+        simple_swap(model, fsdp_init_method)
+
     return model
