@@ -9,7 +9,14 @@
 
 import torch
 import torch.distributed as dist
-import transformer_engine as te
+
+try:
+    import transformer_engine as te
+
+    has_te = True
+except (ModuleNotFoundError, ImportError):
+    has_te = False
+
 from torch import nn
 
 from internlm.accelerator import get_accelerator
@@ -1011,161 +1018,163 @@ def __init__(
         self.full_weight_shape = torch.Size((num_groups, in_features, out_features))
 
 
-class TEColumnParallelLinear(te.pytorch.Linear):
-    """
-    Wrapper for the Transformer-Engine's `Linear` layer.
-    """
+if has_te:
 
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool,
-        skip_bias_add: bool,
-        is_expert: bool,
-        tp_comm_buffer_name: str = None,
-    ):
-        if is_expert:
-            raise ValueError("Transformer Engine linear layers do not yet support MoE")
-
-        # TE returns a zero length Tensor when bias=False and
-        # return_bias=True, but we prefer None. So in that case we
-        # tell TE to not return the bias, and return None
-        # ourselves. This way our forward always returns two values
-        # and we don't have to deal with the zero length Tensor.
-        self.te_return_bias = skip_bias_add and bias
-        self.is_first_microbatch = True
-
-        extra_kwargs = {"params_dtype": gpc.config.model.dtype}
-        if is_te_min_version("0.12.0"):
-            extra_kwargs["device"] = torch.cuda.current_device()
-
-        if gpc.config.parallel["tensor"]["tp_overlap"]:
-            extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                "tp_comm_bulk_wgrad", True
-            )
-            extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                "tp_comm_bulk_dgrad", True
-            )
-            if is_te_min_version("1.5.0", check_equality=False):
-                extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_overlap_ag", True
-                )
-            else:
-                raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
-            assert (
-                tp_comm_buffer_name is not None
-            ), "Buffer name should be set to configure communication overlap settings"
-            extra_kwargs["ub_name"] = tp_comm_buffer_name
-
-        parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
-        tp_size = gpc.get_world_size(parallel_mode)
-        tp_group = gpc.get_group(parallel_mode)
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            tp_group=tp_group,
-            tp_size=tp_size,
-            bias=bias,
-            return_bias=self.te_return_bias,
-            parallel_mode="column",
-            **extra_kwargs,
-        )
-
-    def forward(self, x):
-        """Forward."""
-        _is_first_microbatch = self.is_first_microbatch
-        x = x.transpose(0, 1)
-        out = super().forward(x, is_first_microbatch=_is_first_microbatch)
-        out = out.transpose(0, 1)
-
-        self.is_first_microbatch = False
+    class TEColumnParallelLinear(te.pytorch.Linear):
+        """
+        Wrapper for the Transformer-Engine's `Linear` layer.
+        """
 
-        return out
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool,
+            skip_bias_add: bool,
+            is_expert: bool,
+            tp_comm_buffer_name: str = None,
+        ):
+            if is_expert:
+                raise ValueError("Transformer Engine linear layers do not yet support MoE")
+
+            # TE returns a zero length Tensor when bias=False and
+            # return_bias=True, but we prefer None. So in that case we
+            # tell TE to not return the bias, and return None
+            # ourselves. This way our forward always returns two values
+            # and we don't have to deal with the zero length Tensor.
+            self.te_return_bias = skip_bias_add and bias
+            self.is_first_microbatch = True
+
+            extra_kwargs = {"params_dtype": gpc.config.model.dtype}
+            if is_te_min_version("0.12.0"):
+                extra_kwargs["device"] = torch.cuda.current_device()
+
+            if gpc.config.parallel["tensor"]["tp_overlap"]:
+                extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                    "tp_comm_bulk_wgrad", True
+                )
+                extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                    "tp_comm_bulk_dgrad", True
+                )
+                if is_te_min_version("1.5.0", check_equality=False):
+                    extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                        "tp_comm_overlap_ag", True
+                    )
+                else:
+                    raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
+                assert (
+                    tp_comm_buffer_name is not None
+                ), "Buffer name should be set to configure communication overlap settings"
+                extra_kwargs["ub_name"] = tp_comm_buffer_name
+
+            parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
+            tp_size = gpc.get_world_size(parallel_mode)
+            tp_group = gpc.get_group(parallel_mode)
+            super().__init__(
+                in_features=in_features,
+                out_features=out_features,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                tp_group=tp_group,
+                tp_size=tp_size,
+                bias=bias,
+                return_bias=self.te_return_bias,
+                parallel_mode="column",
+                **extra_kwargs,
+            )
 
+        def forward(self, x):
+            """Forward."""
+            _is_first_microbatch = self.is_first_microbatch
+            x = x.transpose(0, 1)
+            out = super().forward(x, is_first_microbatch=_is_first_microbatch)
+            out = out.transpose(0, 1)
 
-class TERowParallelLinear(te.pytorch.Linear):
-    """
-    Wrapper for the Transformer-Engine's `Linear` layer.
-    """
+            self.is_first_microbatch = False
 
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool,
-        skip_bias_add: bool,
-        is_expert: bool = False,
-        tp_comm_buffer_name: str = None,
-    ):
-        # TE returns a zero length Tensor when bias=False and
-        # return_bias=True. Here we need a single Tensor
-        self.te_return_bias = skip_bias_add and bias
-        self.is_first_microbatch = True
-
-        extra_kwargs = {"params_dtype": gpc.config.model.dtype}
-        if is_te_min_version("0.12.0"):
-            extra_kwargs["device"] = torch.cuda.current_device()
-
-        if gpc.config.parallel["tensor"]["tp_overlap"]:
-            if is_te_min_version("1.5.0"):
-                extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_overlap_ag", True
-                )
-                extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_overlap_rs", True
-                )
-                # Disable ub overlap for experts.
-                if is_expert:
-                    extra_kwargs["ub_overlap_ag"] = False
-                    extra_kwargs["ub_overlap_rs"] = False
-            else:
-                raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
-            assert (
-                tp_comm_buffer_name is not None
-            ), "Buffer name should be set to configure communication overlap settings"
-            extra_kwargs["ub_name"] = tp_comm_buffer_name
+            return out
 
-        self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
-        parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
-        # Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
-        tp_size = gpc.get_world_size(parallel_mode)
-        tp_group = gpc.get_group(parallel_mode)
-        explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
-
-        split_mode = "row"
-        if explicit_expert_comm:
-            assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
-            in_features = in_features // tp_size
-            split_mode = None
-            tp_size = 1
-            tp_group = None
+    class TERowParallelLinear(te.pytorch.Linear):
+        """
+        Wrapper for the Transformer-Engine's `Linear` layer.
+        """
 
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            tp_group=tp_group,
-            tp_size=tp_size,
-            bias=bias,
-            return_bias=self.te_return_bias,
-            parallel_mode=split_mode,
-            **extra_kwargs,
-        )
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool,
+            skip_bias_add: bool,
+            is_expert: bool = False,
+            tp_comm_buffer_name: str = None,
+        ):
+            # TE returns a zero length Tensor when bias=False and
+            # return_bias=True. Here we need a single Tensor
+            self.te_return_bias = skip_bias_add and bias
+            self.is_first_microbatch = True
+
+            extra_kwargs = {"params_dtype": gpc.config.model.dtype}
+            if is_te_min_version("0.12.0"):
+                extra_kwargs["device"] = torch.cuda.current_device()
+
+            if gpc.config.parallel["tensor"]["tp_overlap"]:
+                if is_te_min_version("1.5.0"):
+                    extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                        "tp_comm_overlap_ag", True
+                    )
+                    extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                        "tp_comm_overlap_rs", True
+                    )
+                    # Disable ub overlap for experts.
+                    if is_expert:
+                        extra_kwargs["ub_overlap_ag"] = False
+                        extra_kwargs["ub_overlap_rs"] = False
+                else:
+                    raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
+                assert (
+                    tp_comm_buffer_name is not None
+                ), "Buffer name should be set to configure communication overlap settings"
+                extra_kwargs["ub_name"] = tp_comm_buffer_name
+
+            self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
+            parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
+            # Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
+            tp_size = gpc.get_world_size(parallel_mode)
+            tp_group = gpc.get_group(parallel_mode)
+            explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
+
+            split_mode = "row"
+            if explicit_expert_comm:
+                assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
+                in_features = in_features // tp_size
+                split_mode = None
+                tp_size = 1
+                tp_group = None
+
+            super().__init__(
+                in_features=in_features,
+                out_features=out_features,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                tp_group=tp_group,
+                tp_size=tp_size,
+                bias=bias,
+                return_bias=self.te_return_bias,
+                parallel_mode=split_mode,
+                **extra_kwargs,
+            )
 
-        for param in self.parameters():
-            setattr(param, "allreduce", not (is_expert and self.expert_parallel))
+        def forward(self, x):
+            """Forward."""
+            _is_first_microbatch = self.is_first_microbatch
+            x = x.transpose(0, 1)
+            out = super().forward(x, is_first_microbatch=_is_first_microbatch)
+            out = out.transpose(0, 1)
+            self.is_first_microbatch = False
 
-    def forward(self, x):
-        """Forward."""
-        _is_first_microbatch = self.is_first_microbatch
-        x = x.transpose(0, 1)
-        out = super().forward(x, is_first_microbatch=_is_first_microbatch)
-        out = out.transpose(0, 1)
-        self.is_first_microbatch = False
+            return out
 
-        return out
+else:
+    TEColumnParallelLinear = ColumnParallelLinear
+    TERowParallelLinear = RowParallelLinear
 
 
 def new_linear(
@@ -1217,7 +1226,7 @@ def new_linear(
             weight_scale=weight_scale,
             norm_head=norm_head,
         )
-    elif split_mode == "column":
+    elif split_mode == "column" or (split_mode == "tecolumn" and not has_te):
         return ColumnParallelLinear(
             in_features,
             out_features,
@@ -1236,7 +1245,7 @@ def new_linear(
             is_expert,
             tp_comm_buffer_name,
         )
-    elif split_mode == "row":
+    elif split_mode == "row" or (split_mode == "terow" and not has_te):
         return RowParallelLinear(
             in_features,
             out_features,
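For readers without the full file at hand, the sketch below condenses the pattern this change introduces: guard the optional transformer_engine import, define the TE-backed wrappers only when the import succeeds, and alias the public names to plain layers otherwise, so the TE-specific split_mode values degrade gracefully instead of failing at import time. This is a minimal illustration, not code from the patch: the `*Demo` names and `new_linear_demo` helper are hypothetical, and plain `nn.Linear` stands in for InternEvo's actual `ColumnParallelLinear`/`RowParallelLinear` fallbacks.

# Minimal, self-contained sketch (hypothetical names) of the guarded-import pattern.
from torch import nn

try:
    import transformer_engine as te  # optional dependency, same guard as the patch

    has_te = True
except (ModuleNotFoundError, ImportError):
    has_te = False

if has_te:

    class TEColumnLinearDemo(te.pytorch.Linear):
        """Stand-in for the TE-backed column-parallel wrapper."""

    class TERowLinearDemo(te.pytorch.Linear):
        """Stand-in for the TE-backed row-parallel wrapper."""

else:
    # With TE absent the names still resolve, so imports and isinstance checks
    # elsewhere in the code base keep working.
    TEColumnLinearDemo = nn.Linear
    TERowLinearDemo = nn.Linear


def new_linear_demo(in_features: int, out_features: int, split_mode: str) -> nn.Module:
    """Dispatch mirroring the new `elif` conditions: TE modes fall back when TE is missing."""
    if split_mode == "column" or (split_mode == "tecolumn" and not has_te):
        return nn.Linear(in_features, out_features)  # plain fallback path
    if split_mode == "row" or (split_mode == "terow" and not has_te):
        return nn.Linear(in_features, out_features)  # plain fallback path
    if split_mode == "tecolumn":
        return TEColumnLinearDemo(in_features, out_features)
    if split_mode == "terow":
        return TERowLinearDemo(in_features, out_features)
    raise ValueError(f"unknown split_mode: {split_mode}")

With transformer_engine unavailable, `new_linear_demo(16, 32, "tecolumn")` returns a plain layer rather than raising, which is the same degradation the patch applies through the `else:` aliases and the widened `elif` conditions in `new_linear`.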