@@ -1020,7 +1020,7 @@ def __init__(

 if has_te:

-    class TEColumnParallelLinear(te.pytorch.Linear):
+    class TELinear(te.pytorch.Linear):
         """
         Wrapper for the Transformer-Engine's `Linear` layer.
         """
@@ -1032,34 +1032,36 @@ def __init__(
             bias: bool,
             skip_bias_add: bool,
             is_expert: bool,
+            split_mode: str = "none",
             tp_comm_buffer_name: str = None,
         ):
             if is_expert:
                 raise ValueError("Transformer Engine linear layers do not yet support MoE")

             # TE returns a zero length Tensor when bias=False and
-            # return_bias=True, but we prefer None. So in that case we
-            # tell TE to not return the bias, and return None
-            # ourselves. This way our forward always returns two values
-            # and we don't have to deal with the zero length Tensor.
+            # return_bias=True. Here we need a single Tensor
             self.te_return_bias = skip_bias_add and bias
             self.is_first_microbatch = True

             extra_kwargs = {"params_dtype": gpc.config.model.dtype}
-            if is_te_min_version("0.12.0"):
-                extra_kwargs["device"] = torch.cuda.current_device()
+            extra_kwargs["device"] = torch.cuda.current_device()

             if gpc.config.parallel["tensor"].get("tp_overlap", False):
-                extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_bulk_wgrad", True
-                )
-                extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_bulk_dgrad", True
-                )
                 if is_te_min_version("1.5.0", check_equality=False):
                     extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
                         "tp_comm_overlap_ag", True
                     )
+                    if split_mode == "column":
+                        extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                            "tp_comm_bulk_wgrad", True
+                        )
+                        extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                            "tp_comm_bulk_dgrad", True
+                        )
+                    elif split_mode == "row":
+                        extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                            "tp_comm_overlap_rs", True
+                        )
                 else:
                     raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
                 assert (
@@ -1070,6 +1072,7 @@ def __init__(
             parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
             tp_size = gpc.get_world_size(parallel_mode)
             tp_group = gpc.get_group(parallel_mode)
+
             super().__init__(
                 in_features=in_features,
                 out_features=out_features,
@@ -1078,7 +1081,7 @@ def __init__(
                 tp_size=tp_size,
                 bias=bias,
                 return_bias=self.te_return_bias,
-                parallel_mode="column",
+                parallel_mode=split_mode,
                 **extra_kwargs,
             )

@@ -1088,14 +1091,13 @@ def forward(self, x):
             x = x.transpose(0, 1)
             out = super().forward(x, is_first_microbatch=_is_first_microbatch)
             out = out.transpose(0, 1)
-
             self.is_first_microbatch = False

             return out

-    class TERowParallelLinear(te.pytorch.Linear):
+    class TEColumnParallelLinear(TELinear):
         """
-        Wrapper for the Transformer-Engine's `Linear` layer.
+        Wrapper for the TELinear layer.
         """

         def __init__(
@@ -1107,72 +1109,42 @@ def __init__(
             is_expert: bool = False,
             tp_comm_buffer_name: str = None,
         ):
-            # TE returns a zero length Tensor when bias=False and
-            # return_bias=True. Here we need a single Tensor
-            self.te_return_bias = skip_bias_add and bias
-            self.is_first_microbatch = True
-
-            extra_kwargs = {"params_dtype": gpc.config.model.dtype}
-            if is_te_min_version("0.12.0"):
-                extra_kwargs["device"] = torch.cuda.current_device()
-
-            if gpc.config.parallel["tensor"].get("tp_overlap", False):
-                if is_te_min_version("1.5.0"):
-                    extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                        "tp_comm_overlap_ag", True
-                    )
-                    extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                        "tp_comm_overlap_rs", True
-                    )
-                    # Disable ub overlap for experts.
-                    if is_expert:
-                        extra_kwargs["ub_overlap_ag"] = False
-                        extra_kwargs["ub_overlap_rs"] = False
-                else:
-                    raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
-                assert (
-                    tp_comm_buffer_name is not None
-                ), "Buffer name should be set to configure communication overlap settings"
-                extra_kwargs["ub_name"] = tp_comm_buffer_name
-
-            self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
-            parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
-            # Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
-            tp_size = gpc.get_world_size(parallel_mode)
-            tp_group = gpc.get_group(parallel_mode)
-            explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
-
-            split_mode = "row"
-            if explicit_expert_comm:
-                assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
-                in_features = in_features // tp_size
-                split_mode = None
-                tp_size = 1
-                tp_group = None
-
             super().__init__(
-                in_features=in_features,
-                out_features=out_features,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                tp_group=tp_group,
-                tp_size=tp_size,
+                in_features,
+                out_features,
                 bias=bias,
-                return_bias=self.te_return_bias,
-                parallel_mode=split_mode,
-                **extra_kwargs,
+                skip_bias_add=skip_bias_add,
+                is_expert=is_expert,
+                split_mode="column",
+                tp_comm_buffer_name=tp_comm_buffer_name,
             )

-        def forward(self, x):
-            """Forward."""
-            _is_first_microbatch = self.is_first_microbatch
-            x = x.transpose(0, 1)
-            out = super().forward(x, is_first_microbatch=_is_first_microbatch)
-            out = out.transpose(0, 1)
-            self.is_first_microbatch = False
+    class TERowParallelLinear(TELinear):
+        """
+        Wrapper for the TELinear layer.
+        """

-            return out
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool,
+            skip_bias_add: bool,
+            is_expert: bool = False,
+            tp_comm_buffer_name: str = None,
+        ):
+            super().__init__(
+                in_features,
+                out_features,
+                bias=bias,
+                skip_bias_add=skip_bias_add,
+                is_expert=is_expert,
+                split_mode="row",
+                tp_comm_buffer_name=tp_comm_buffer_name,
+            )

 else:
+    TELinear = ParallelLinearWithCommExt
     TEColumnParallelLinear = ColumnParallelLinear
     TERowParallelLinear = RowParallelLinear

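For reference, the refactored TELinear keys its Transformer Engine userbuffer kwargs off split_mode when tp_overlap is enabled, and the column/row wrappers only differ in the split_mode they pass. Below is a minimal, framework-free sketch of that kwarg selection; the select_ub_kwargs helper and the plain tp_overlap_cfg dict argument are illustrative assumptions standing in for the values the diff reads from gpc.config.parallel["tensor"]["tp_overlap_cfg"], not part of the PR.

# Illustrative sketch only: mirrors the split_mode-dependent kwarg selection
# that the new TELinear performs when tp_overlap is enabled (TE >= 1.5.0).
def select_ub_kwargs(split_mode: str, tp_overlap_cfg: dict) -> dict:
    # All-gather overlap is configured for every split mode.
    extra_kwargs = {"ub_overlap_ag": tp_overlap_cfg.get("tp_comm_overlap_ag", True)}
    if split_mode == "column":
        # Column-parallel layers additionally overlap bulk wgrad/dgrad communication.
        extra_kwargs["ub_bulk_wgrad"] = tp_overlap_cfg.get("tp_comm_bulk_wgrad", True)
        extra_kwargs["ub_bulk_dgrad"] = tp_overlap_cfg.get("tp_comm_bulk_dgrad", True)
    elif split_mode == "row":
        # Row-parallel layers overlap the reduce-scatter instead.
        extra_kwargs["ub_overlap_rs"] = tp_overlap_cfg.get("tp_comm_overlap_rs", True)
    return extra_kwargs


print(select_ub_kwargs("column", {}))
# {'ub_overlap_ag': True, 'ub_bulk_wgrad': True, 'ub_bulk_dgrad': True}
print(select_ub_kwargs("row", {"tp_comm_overlap_rs": False}))
# {'ub_overlap_ag': True, 'ub_overlap_rs': False}

TEColumnParallelLinear and TERowParallelLinear simply forward to TELinear with split_mode="column" and split_mode="row" respectively, so the same selection applies to both wrappers.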