
Commit 7ac997e

optimize code

1 parent: c392ba5

File tree: 2 files changed, +50 −79 lines


internlm/model/modules/linear.py (+48 −76)
@@ -1020,7 +1020,7 @@ def __init__(
 
 if has_te:
 
-    class TEColumnParallelLinear(te.pytorch.Linear):
+    class TELinear(te.pytorch.Linear):
         """
         Wrapper for the Transformer-Engine's `Linear` layer.
         """
@@ -1032,34 +1032,36 @@ def __init__(
             bias: bool,
             skip_bias_add: bool,
             is_expert: bool,
+            split_mode: str = "none",
             tp_comm_buffer_name: str = None,
         ):
             if is_expert:
                 raise ValueError("Transformer Engine linear layers do not yet support MoE")
 
             # TE returns a zero length Tensor when bias=False and
-            # return_bias=True, but we prefer None. So in that case we
-            # tell TE to not return the bias, and return None
-            # ourselves. This way our forward always returns two values
-            # and we don't have to deal with the zero length Tensor.
+            # return_bias=True. Here we need a single Tensor
             self.te_return_bias = skip_bias_add and bias
             self.is_first_microbatch = True
 
             extra_kwargs = {"params_dtype": gpc.config.model.dtype}
-            if is_te_min_version("0.12.0"):
-                extra_kwargs["device"] = torch.cuda.current_device()
+            extra_kwargs["device"] = torch.cuda.current_device()
 
             if gpc.config.parallel["tensor"].get("tp_overlap", False):
-                extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_bulk_wgrad", True
-                )
-                extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_bulk_dgrad", True
-                )
                 if is_te_min_version("1.5.0", check_equality=False):
                     extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
                         "tp_comm_overlap_ag", True
                     )
+                    if split_mode == "column":
+                        extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                            "tp_comm_bulk_wgrad", True
+                        )
+                        extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                            "tp_comm_bulk_dgrad", True
+                        )
+                    elif split_mode == "row":
+                        extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                            "tp_comm_overlap_rs", True
+                        )
                 else:
                     raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
                 assert (
@@ -1070,6 +1072,7 @@ def __init__(
             parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
             tp_size = gpc.get_world_size(parallel_mode)
             tp_group = gpc.get_group(parallel_mode)
+
             super().__init__(
                 in_features=in_features,
                 out_features=out_features,
@@ -1078,7 +1081,7 @@ def __init__(
                 tp_size=tp_size,
                 bias=bias,
                 return_bias=self.te_return_bias,
-                parallel_mode="column",
+                parallel_mode=split_mode,
                 **extra_kwargs,
             )
 
@@ -1088,14 +1091,13 @@ def forward(self, x):
             x = x.transpose(0, 1)
             out = super().forward(x, is_first_microbatch=_is_first_microbatch)
             out = out.transpose(0, 1)
-
             self.is_first_microbatch = False
 
             return out
 
-    class TERowParallelLinear(te.pytorch.Linear):
+    class TEColumnParallelLinear(TELinear):
         """
-        Wrapper for the Transformer-Engine's `Linear` layer.
+        Wrapper for the TELinear layer.
         """
 
         def __init__(
@@ -1107,72 +1109,42 @@ def __init__(
             is_expert: bool = False,
             tp_comm_buffer_name: str = None,
         ):
-            # TE returns a zero length Tensor when bias=False and
-            # return_bias=True. Here we need a single Tensor
-            self.te_return_bias = skip_bias_add and bias
-            self.is_first_microbatch = True
-
-            extra_kwargs = {"params_dtype": gpc.config.model.dtype}
-            if is_te_min_version("0.12.0"):
-                extra_kwargs["device"] = torch.cuda.current_device()
-
-            if gpc.config.parallel["tensor"].get("tp_overlap", False):
-                if is_te_min_version("1.5.0"):
-                    extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                        "tp_comm_overlap_ag", True
-                    )
-                    extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                        "tp_comm_overlap_rs", True
-                    )
-                    # Disable ub overlap for experts.
-                    if is_expert:
-                        extra_kwargs["ub_overlap_ag"] = False
-                        extra_kwargs["ub_overlap_rs"] = False
-                else:
-                    raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
-                assert (
-                    tp_comm_buffer_name is not None
-                ), "Buffer name should be set to configure communication overlap settings"
-                extra_kwargs["ub_name"] = tp_comm_buffer_name
-
-            self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
-            parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
-            # Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
-            tp_size = gpc.get_world_size(parallel_mode)
-            tp_group = gpc.get_group(parallel_mode)
-            explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
-
-            split_mode = "row"
-            if explicit_expert_comm:
-                assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
-                in_features = in_features // tp_size
-                split_mode = None
-                tp_size = 1
-                tp_group = None
-
             super().__init__(
-                in_features=in_features,
-                out_features=out_features,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                tp_group=tp_group,
-                tp_size=tp_size,
+                in_features,
+                out_features,
                 bias=bias,
-                return_bias=self.te_return_bias,
-                parallel_mode=split_mode,
-                **extra_kwargs,
+                skip_bias_add=skip_bias_add,
+                is_expert=is_expert,
+                split_mode="column",
+                tp_comm_buffer_name=tp_comm_buffer_name,
             )
 
-        def forward(self, x):
-            """Forward."""
-            _is_first_microbatch = self.is_first_microbatch
-            x = x.transpose(0, 1)
-            out = super().forward(x, is_first_microbatch=_is_first_microbatch)
-            out = out.transpose(0, 1)
-            self.is_first_microbatch = False
+    class TERowParallelLinear(TELinear):
+        """
+        Wrapper for the TELinear layer.
+        """
 
-            return out
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool,
+            skip_bias_add: bool,
+            is_expert: bool = False,
+            tp_comm_buffer_name: str = None,
+        ):
+            super().__init__(
+                in_features,
+                out_features,
+                bias=bias,
+                skip_bias_add=skip_bias_add,
+                is_expert=is_expert,
+                split_mode="row",
+                tp_comm_buffer_name=tp_comm_buffer_name,
            )
 
 else:
+    TELinear = ParallelLinearWithCommExt
     TEColumnParallelLinear = ColumnParallelLinear
     TERowParallelLinear = RowParallelLinear
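For readers skimming the diff: the reorganization above is a standard base-class extraction. The logic duplicated between the old TEColumnParallelLinear and TERowParallelLinear moves into a single TELinear wrapper parameterized by split_mode, and the two public classes become thin subclasses that only pin split_mode to "column" or "row". The sketch below is a minimal, self-contained illustration of that pattern, not the repository code: plain nn.Linear stands in for te.pytorch.Linear, and the gpc/tp_overlap configuration is reduced to a small dict holding the overlap option names that appear in the diff.

from torch import nn


class BaseTPLinear(nn.Linear):
    """Shared wrapper: keeps the options that used to be duplicated across the
    column and row classes behind a single split_mode flag. Plain nn.Linear
    stands in for te.pytorch.Linear in this sketch."""

    def __init__(self, in_features, out_features, bias=True, split_mode="none"):
        super().__init__(in_features, out_features, bias=bias)
        # Mode-specific communication-overlap options now live in one place,
        # mirroring the split_mode branches added to TELinear.__init__.
        if split_mode == "column":
            overlap_opts = {"ub_bulk_wgrad": True, "ub_bulk_dgrad": True}
        elif split_mode == "row":
            overlap_opts = {"ub_overlap_rs": True}
        else:
            overlap_opts = {}
        self.split_mode = split_mode
        self.overlap_opts = overlap_opts


class ColumnVariant(BaseTPLinear):
    """Thin subclass: only pins split_mode, like TEColumnParallelLinear."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias, split_mode="column")


class RowVariant(BaseTPLinear):
    """Thin subclass: only pins split_mode, like TERowParallelLinear."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias, split_mode="row")


if __name__ == "__main__":
    col, row = ColumnVariant(8, 16), RowVariant(16, 8)
    # A single isinstance check against the base now covers both variants,
    # which is what enables the simplified check in internlm/train/pipeline.py.
    assert isinstance(col, BaseTPLinear) and isinstance(row, BaseTPLinear)
    print(col.overlap_opts, row.overlap_opts)

The real subclasses additionally forward bias, skip_bias_add, is_expert, and tp_comm_buffer_name to TELinear, exactly as shown in the diff.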

internlm/train/pipeline.py (+2 −3)
@@ -59,8 +59,7 @@
     RewardModelLinear,
     RowParallelLinear,
     ScaleColumnParallelLinear,
-    TEColumnParallelLinear,
-    TERowParallelLinear,
+    TELinear,
     new_linear,
 )
 from internlm.model.modules.norm import new_layer_norm
@@ -209,7 +208,7 @@ def _check_module(name, module):
         elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
             setattr(param, IS_WEIGHT_EXPERT_DATA_PARALLEL, True)
     # for non-moe linear module
-    elif isinstance(module, (ParallelLinearWithCommExt, TERowParallelLinear, TEColumnParallelLinear)):
+    elif isinstance(module, (ParallelLinearWithCommExt, TELinear)):
         for param in module.parameters():
             if gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
                 setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
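The collapsed isinstance check stays valid in both build configurations: with Transformer Engine installed, TEColumnParallelLinear and TERowParallelLinear now inherit from TELinear, and without it the diff adds the alias TELinear = ParallelLinearWithCommExt, so the (ParallelLinearWithCommExt, TELinear) tuple still matches every non-MoE linear. Below is a schematic check with stand-in classes (the real fallbacks are ColumnParallelLinear and RowParallelLinear; they are reduced to plain aliases here purely for illustration).

class ParallelLinearWithCommExt:  # stand-in for the non-TE base class
    pass


HAS_TE = False  # flip to True to model an environment with transformer_engine

if HAS_TE:
    class TELinear:  # wraps te.pytorch.Linear in the real code
        pass

    class TEColumnParallelLinear(TELinear):
        pass

    class TERowParallelLinear(TELinear):
        pass
else:
    # Fallback aliases, as in the diff; the real fallbacks for the column/row
    # classes are ColumnParallelLinear / RowParallelLinear.
    TELinear = ParallelLinearWithCommExt
    TEColumnParallelLinear = ParallelLinearWithCommExt
    TERowParallelLinear = ParallelLinearWithCommExt

module = TERowParallelLinear()
print(isinstance(module, (ParallelLinearWithCommExt, TELinear)))  # True either way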
