
Commit 3463abf

fix transformer_engine import error

1 parent a71091c commit 3463abf

File tree: 2 files changed, +159 -148 lines changed


configs/7B_sft.py (+2)

@@ -199,6 +199,8 @@
         tp_overlap_cfg=dict(
             tp_comm_overlap_ag=True,
             tp_comm_overlap_rs=True,
+            tp_comm_bulk_wgrad=True,
+            tp_comm_bulk_dgrad=True,
         ),
     ),
     pipeline=dict(size=1, interleaved_overlap=True),
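For context, here is a sketch of what the tensor-parallel block of this config plausibly looks like after the change. Only the four tp_overlap_cfg flags are taken from this commit; the surrounding keys and values (size, mode, tp_overlap) are assumptions for illustration. As the linear.py diff below shows, tp_comm_bulk_wgrad and tp_comm_bulk_dgrad feed TE's ub_bulk_wgrad / ub_bulk_dgrad kwargs, defaulting to True when the keys are absent.

# Illustrative config sketch; `size`, `mode`, and `tp_overlap` values are assumed,
# only the four overlap flags below are pinned by this commit.
parallel = dict(
    tensor=dict(
        size=8,           # assumed tensor-parallel size
        mode="mtp",       # assumed tensor-parallel mode
        tp_overlap=True,  # tp_overlap_cfg is only read when this is True
        tp_overlap_cfg=dict(
            tp_comm_overlap_ag=True,   # -> ub_overlap_ag in the TE linear wrappers
            tp_comm_overlap_rs=True,   # -> ub_overlap_rs (row-parallel linear)
            tp_comm_bulk_wgrad=True,   # -> ub_bulk_wgrad (column-parallel linear)
            tp_comm_bulk_dgrad=True,   # -> ub_bulk_dgrad (column-parallel linear)
        ),
    ),
    pipeline=dict(size=1, interleaved_overlap=True),
)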

internlm/model/modules/linear.py (+157 -148)
@@ -9,7 +9,14 @@
 
 import torch
 import torch.distributed as dist
-import transformer_engine as te
+
+try:
+    import transformer_engine as te
+
+    has_te = True
+except (ModuleNotFoundError, ImportError):
+    has_te = False
+
 from torch import nn
 
 from internlm.accelerator import get_accelerator
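The guarded import above is the heart of the fix: when transformer_engine is missing or fails to import, has_te is set to False and, further down in this file, the TE-backed classes are aliased to the plain parallel linears. A minimal, self-contained sketch of that pattern follows; fancy_backend, FastLinear, and PlainLinear are hypothetical stand-ins, not names from this repository.

# Sketch of the optional-dependency pattern used by this commit (hypothetical names).
try:
    import fancy_backend  # optional accelerator library; may be absent or broken

    has_backend = True
except (ModuleNotFoundError, ImportError):
    has_backend = False


class PlainLinear:
    """Fallback that works everywhere."""


if has_backend:

    class FastLinear(fancy_backend.Linear):
        """Only defined when the optional import succeeded."""

else:
    # Callers keep using the name `FastLinear`; they transparently get the fallback.
    FastLinear = PlainLinear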
@@ -1011,161 +1018,163 @@ def __init__(
         self.full_weight_shape = torch.Size((num_groups, in_features, out_features))
 
 
-class TEColumnParallelLinear(te.pytorch.Linear):
-    """
-    Wrapper for the Transformer-Engine's `Linear` layer.
-    """
+if has_te:
 
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool,
-        skip_bias_add: bool,
-        is_expert: bool,
-        tp_comm_buffer_name: str = None,
-    ):
-        if is_expert:
-            raise ValueError("Transformer Engine linear layers do not yet support MoE")
-
-        # TE returns a zero length Tensor when bias=False and
-        # return_bias=True, but we prefer None. So in that case we
-        # tell TE to not return the bias, and return None
-        # ourselves. This way our forward always returns two values
-        # and we don't have to deal with the zero length Tensor.
-        self.te_return_bias = skip_bias_add and bias
-        self.is_first_microbatch = True
-
-        extra_kwargs = {"params_dtype": gpc.config.model.dtype}
-        if is_te_min_version("0.12.0"):
-            extra_kwargs["device"] = torch.cuda.current_device()
-
-        if gpc.config.parallel["tensor"]["tp_overlap"]:
-            extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                "tp_comm_bulk_wgrad", True
-            )
-            extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                "tp_comm_bulk_dgrad", True
-            )
-            if is_te_min_version("1.5.0", check_equality=False):
-                extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_overlap_ag", True
-                )
-            else:
-                raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
-            assert (
-                tp_comm_buffer_name is not None
-            ), "Buffer name should be set to configure communication overlap settings"
-            extra_kwargs["ub_name"] = tp_comm_buffer_name
-
-        parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
-        tp_size = gpc.get_world_size(parallel_mode)
-        tp_group = gpc.get_group(parallel_mode)
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            tp_group=tp_group,
-            tp_size=tp_size,
-            bias=bias,
-            return_bias=self.te_return_bias,
-            parallel_mode="column",
-            **extra_kwargs,
-        )
-
-    def forward(self, x):
-        """Forward."""
-        _is_first_microbatch = self.is_first_microbatch
-        x = x.transpose(0, 1)
-        out = super().forward(x, is_first_microbatch=_is_first_microbatch)
-        out = out.transpose(0, 1)
-
-        self.is_first_microbatch = False
+    class TEColumnParallelLinear(te.pytorch.Linear):
+        """
+        Wrapper for the Transformer-Engine's `Linear` layer.
+        """
 
-        return out
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool,
+            skip_bias_add: bool,
+            is_expert: bool,
+            tp_comm_buffer_name: str = None,
+        ):
+            if is_expert:
+                raise ValueError("Transformer Engine linear layers do not yet support MoE")
+
+            # TE returns a zero length Tensor when bias=False and
+            # return_bias=True, but we prefer None. So in that case we
+            # tell TE to not return the bias, and return None
+            # ourselves. This way our forward always returns two values
+            # and we don't have to deal with the zero length Tensor.
+            self.te_return_bias = skip_bias_add and bias
+            self.is_first_microbatch = True
+
+            extra_kwargs = {"params_dtype": gpc.config.model.dtype}
+            if is_te_min_version("0.12.0"):
+                extra_kwargs["device"] = torch.cuda.current_device()
+
+            if gpc.config.parallel["tensor"]["tp_overlap"]:
+                extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                    "tp_comm_bulk_wgrad", True
+                )
+                extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                    "tp_comm_bulk_dgrad", True
+                )
+                if is_te_min_version("1.5.0", check_equality=False):
+                    extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                        "tp_comm_overlap_ag", True
+                    )
+                else:
+                    raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
+                assert (
+                    tp_comm_buffer_name is not None
+                ), "Buffer name should be set to configure communication overlap settings"
+                extra_kwargs["ub_name"] = tp_comm_buffer_name
+
+            parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
+            tp_size = gpc.get_world_size(parallel_mode)
+            tp_group = gpc.get_group(parallel_mode)
+            super().__init__(
+                in_features=in_features,
+                out_features=out_features,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                tp_group=tp_group,
+                tp_size=tp_size,
+                bias=bias,
+                return_bias=self.te_return_bias,
+                parallel_mode="column",
+                **extra_kwargs,
+            )
 
+        def forward(self, x):
+            """Forward."""
+            _is_first_microbatch = self.is_first_microbatch
+            x = x.transpose(0, 1)
+            out = super().forward(x, is_first_microbatch=_is_first_microbatch)
+            out = out.transpose(0, 1)
 
-class TERowParallelLinear(te.pytorch.Linear):
-    """
-    Wrapper for the Transformer-Engine's `Linear` layer.
-    """
+            self.is_first_microbatch = False
 
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool,
-        skip_bias_add: bool,
-        is_expert: bool = False,
-        tp_comm_buffer_name: str = None,
-    ):
-        # TE returns a zero length Tensor when bias=False and
-        # return_bias=True. Here we need a single Tensor
-        self.te_return_bias = skip_bias_add and bias
-        self.is_first_microbatch = True
-
-        extra_kwargs = {"params_dtype": gpc.config.model.dtype}
-        if is_te_min_version("0.12.0"):
-            extra_kwargs["device"] = torch.cuda.current_device()
-
-        if gpc.config.parallel["tensor"]["tp_overlap"]:
-            if is_te_min_version("1.5.0"):
-                extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_overlap_ag", True
-                )
-                extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_overlap_rs", True
-                )
-                # Disable ub overlap for experts.
-                if is_expert:
-                    extra_kwargs["ub_overlap_ag"] = False
-                    extra_kwargs["ub_overlap_rs"] = False
-            else:
-                raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
-            assert (
-                tp_comm_buffer_name is not None
-            ), "Buffer name should be set to configure communication overlap settings"
-            extra_kwargs["ub_name"] = tp_comm_buffer_name
+            return out
 
-        self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
-        parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
-        # Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
-        tp_size = gpc.get_world_size(parallel_mode)
-        tp_group = gpc.get_group(parallel_mode)
-        explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
-
-        split_mode = "row"
-        if explicit_expert_comm:
-            assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
-            in_features = in_features // tp_size
-            split_mode = None
-            tp_size = 1
-            tp_group = None
+    class TERowParallelLinear(te.pytorch.Linear):
+        """
+        Wrapper for the Transformer-Engine's `Linear` layer.
+        """
 
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            tp_group=tp_group,
-            tp_size=tp_size,
-            bias=bias,
-            return_bias=self.te_return_bias,
-            parallel_mode=split_mode,
-            **extra_kwargs,
-        )
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool,
+            skip_bias_add: bool,
+            is_expert: bool = False,
+            tp_comm_buffer_name: str = None,
+        ):
+            # TE returns a zero length Tensor when bias=False and
+            # return_bias=True. Here we need a single Tensor
+            self.te_return_bias = skip_bias_add and bias
+            self.is_first_microbatch = True
+
+            extra_kwargs = {"params_dtype": gpc.config.model.dtype}
+            if is_te_min_version("0.12.0"):
+                extra_kwargs["device"] = torch.cuda.current_device()
+
+            if gpc.config.parallel["tensor"]["tp_overlap"]:
+                if is_te_min_version("1.5.0"):
+                    extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                        "tp_comm_overlap_ag", True
+                    )
+                    extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                        "tp_comm_overlap_rs", True
+                    )
+                    # Disable ub overlap for experts.
+                    if is_expert:
+                        extra_kwargs["ub_overlap_ag"] = False
+                        extra_kwargs["ub_overlap_rs"] = False
+                else:
+                    raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
+                assert (
+                    tp_comm_buffer_name is not None
+                ), "Buffer name should be set to configure communication overlap settings"
+                extra_kwargs["ub_name"] = tp_comm_buffer_name
+
+            self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
+            parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
+            # Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
+            tp_size = gpc.get_world_size(parallel_mode)
+            tp_group = gpc.get_group(parallel_mode)
+            explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
+
+            split_mode = "row"
+            if explicit_expert_comm:
+                assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
+                in_features = in_features // tp_size
+                split_mode = None
+                tp_size = 1
+                tp_group = None
+
+            super().__init__(
+                in_features=in_features,
+                out_features=out_features,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                tp_group=tp_group,
+                tp_size=tp_size,
+                bias=bias,
+                return_bias=self.te_return_bias,
+                parallel_mode=split_mode,
+                **extra_kwargs,
+            )
 
-        for param in self.parameters():
-            setattr(param, "allreduce", not (is_expert and self.expert_parallel))
+        def forward(self, x):
+            """Forward."""
+            _is_first_microbatch = self.is_first_microbatch
+            x = x.transpose(0, 1)
+            out = super().forward(x, is_first_microbatch=_is_first_microbatch)
+            out = out.transpose(0, 1)
+            self.is_first_microbatch = False
 
-    def forward(self, x):
-        """Forward."""
-        _is_first_microbatch = self.is_first_microbatch
-        x = x.transpose(0, 1)
-        out = super().forward(x, is_first_microbatch=_is_first_microbatch)
-        out = out.transpose(0, 1)
-        self.is_first_microbatch = False
+            return out
 
-        return out
+else:
+    TEColumnParallelLinear = ColumnParallelLinear
+    TERowParallelLinear = RowParallelLinear
 
 
 def new_linear(
@@ -1217,7 +1226,7 @@ def new_linear(
             weight_scale=weight_scale,
             norm_head=norm_head,
         )
-    elif split_mode == "column":
+    elif split_mode == "column" or (split_mode == "tecolumn" and not has_te):
         return ColumnParallelLinear(
             in_features,
             out_features,
@@ -1236,7 +1245,7 @@ def new_linear(
             is_expert,
             tp_comm_buffer_name,
        )
-    elif split_mode == "row":
+    elif split_mode == "row" or (split_mode == "terow" and not has_te):
         return RowParallelLinear(
             in_features,
             out_features,
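Taken together, the two dispatch changes mean that the TE-specific split modes degrade to the vanilla ColumnParallelLinear / RowParallelLinear when Transformer Engine is unavailable. Below is a simplified sketch of that behaviour with has_te forced to False; the placeholder classes and the trimmed-down factory signature are illustrative only, the real new_linear takes many more arguments and branches.

# Simplified dispatch sketch; placeholder classes, not the real InternEvo layers.
has_te = False  # e.g. transformer_engine could not be imported


class ColumnParallelLinear: ...


class RowParallelLinear: ...


def new_linear(split_mode: str):
    if split_mode == "column" or (split_mode == "tecolumn" and not has_te):
        return ColumnParallelLinear()
    elif split_mode == "row" or (split_mode == "terow" and not has_te):
        return RowParallelLinear()
    raise ValueError(f"unsupported split_mode: {split_mode}")


# Requests for the TE-specific modes fall back to the plain implementations.
assert isinstance(new_linear("tecolumn"), ColumnParallelLinear)
assert isinstance(new_linear("terow"), RowParallelLinear)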
