add tp overlap feature #416

Open · wants to merge 5 commits into base: develop
12 changes: 11 additions & 1 deletion configs/7B_sft.py
@@ -192,7 +192,17 @@
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
tensor=dict(
size=1,
mode="mtp",
tp_overlap=False,
tp_overlap_cfg=dict(
tp_comm_overlap_ag=True,
tp_comm_overlap_rs=True,
tp_comm_bulk_wgrad=True,
tp_comm_bulk_dgrad=True,
),
),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True),
)
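
For reference, a hedged sketch of what turning the feature on might look like in a user config. Values are illustrative, and the mode is set to "msp" because the sanity check added in internlm/initialize/launch.py only permits tp_overlap in msp/fsp mode; the flag comments reflect the Transformer Engine userbuffer options these keys are mapped to in internlm/model/modules/linear.py.

```python
# Illustrative only: enabling TP communication/GEMM overlap in a user config.
# Assumes 8-way tensor parallelism in "msp" mode; tp_overlap is rejected for
# "mtp"/"isp" by the sanity check added in internlm/initialize/launch.py.
parallel = dict(
    zero1=dict(size=-1),
    tensor=dict(
        size=8,
        mode="msp",
        tp_overlap=True,
        tp_overlap_cfg=dict(
            tp_comm_overlap_ag=True,   # overlap all-gather with GEMM
            tp_comm_overlap_rs=True,   # overlap reduce-scatter with GEMM
            tp_comm_bulk_wgrad=True,   # bulk overlap of wgrad GEMM and comm
            tp_comm_bulk_dgrad=True,   # bulk overlap of dgrad GEMM and comm
        ),
    ),
    pipeline=dict(size=1, interleaved_overlap=True),
    weight=dict(size=1, overlap=True),
)
```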
10 changes: 8 additions & 2 deletions internlm/core/parallel/shard.py
@@ -161,13 +161,19 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str:
if linear_name in ("gate"):
return "gate" # for MoE model
elif linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"):
return "column"
if gpc.config.parallel["tensor"].get("tp_overlap", False):
return "tecolumn"
else:
return "column"
elif linear_name in ("fc1", "fc2", "linear_1", "linear_2"): # for vit model
return "column"
elif linear_name in ("wo", "out_proj", "w2") and tp_mode == TensorParallelMode.isp.name:
return "column"
elif linear_name in ("wo", "out_proj", "w2"):
return "row"
if gpc.config.parallel["tensor"].get("tp_overlap", False):
return "terow"
else:
return "row"
elif linear_name in ("grouped_w1", "grouped_w2", "grouped_w3") and tp_mode == "isp":
return "grouped_wp"
elif linear_name in ("grouped_w1", "grouped_w3"):
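
For readers of this hunk, a minimal self-contained sketch of the split-mode decision it implements, with the gpc config lookup replaced by plain arguments; this is not the project's API, and branches untouched by the PR are elided.

```python
# Minimal sketch of the split-mode selection above, with gpc.config replaced
# by explicit arguments. Illustrative only; unrelated branches are elided.
def pick_split_mode(linear_name: str, tp_mode: str, tp_overlap: bool) -> str:
    if linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"):
        return "tecolumn" if tp_overlap else "column"
    if linear_name in ("wo", "out_proj", "w2") and tp_mode == "isp":
        return "column"  # isp keeps plain column parallelism for output projections
    if linear_name in ("wo", "out_proj", "w2"):
        return "terow" if tp_overlap else "row"
    raise NotImplementedError("other cases elided in this sketch")


assert pick_split_mode("wqkv", "msp", tp_overlap=True) == "tecolumn"
assert pick_split_mode("wo", "msp", tp_overlap=False) == "row"
assert pick_split_mode("wo", "isp", tp_overlap=True) == "column"
```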
1 change: 0 additions & 1 deletion internlm/core/scheduler/pipeline_scheduler_zb.py
@@ -901,7 +901,6 @@ def _run_steady_loop(
else:
next_unit_chunk_id = 1

# import pdb; pdb.set_trace()
if unit_step == num_units_stage1 - 1:
chunk0_B_need_recv_prev_chunk0_output = False
else:
33 changes: 33 additions & 0 deletions internlm/core/trainer_builder.py
@@ -19,6 +19,7 @@
from internlm.initialize.initialize_trainer import initialize_trainer
from internlm.model.losses.ce_loss import InternLoss
from internlm.model.metrics import AccPerplex
from internlm.model.modules.utils import is_te_min_version
from internlm.monitor.monitor import send_alert_message
from internlm.train.pipeline import (
get_scheduler_hooks,
@@ -154,6 +155,9 @@ def __init__(
scheduler_hooks=get_scheduler_hooks(self.metric, optimizer, isp_communicator),
)

if gpc.config.parallel["tensor"].get("tp_overlap", False):
self._initialize_tp_comm_ub()

# set attributes
self._set_attributes(
kwargs["profiling"], train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator
@@ -249,6 +253,35 @@ def _initialize_batch_skipper(self, train_state) -> BatchSkipper:
skip_batches = streaming_simple_resume(train_state)
return BatchSkipper(skip_batches)

def _initialize_tp_comm_ub(self):
"""initializing the communicators with user buffers for high-performance tensor-model-parallel
communication overlap"""
try:
from transformer_engine.pytorch import module as te_module

except ImportError:
raise RuntimeError(
"Tensor Parallel Communication/GEMM Overlap optimization needs 'transformer_engine' package"
)

input_shape = [gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"], gpc.config.model["hidden_size"]]

if is_te_min_version("1.9.0"):
# The process group with the target bootstrap backend is created in Transformer Engine.
te_module.base.initialize_ub(
shape=input_shape,
tp_size=gpc.config.parallel["tensor"]["size"],
use_fp8=False,
bootstrap_backend="nccl",
)
else:
# Create an MPI process group to help with TP communication overlap bootstrap.
torch.distributed.new_group(backend="mpi")

te_module.base.initialize_ub(
shape=input_shape, tp_size=gpc.config.parallel["tensor"]["size"], use_fp8=False
)

def _set_attributes(self, profiling, train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator):
self.profiling = profiling
self.train_dl = train_dl
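
As a reading aid, a small sketch of the buffer shape `_initialize_tp_comm_ub` registers and of the version-dependent bootstrap choice; the numeric values are illustrative assumptions, not taken from this PR.

```python
# Illustrative shape math for the userbuffer registration in
# _initialize_tp_comm_ub above; seq_len/micro_bsz/hidden_size are example
# values, not taken from this PR.
seq_len, micro_bsz, hidden_size = 4096, 1, 4096
input_shape = [seq_len * micro_bsz, hidden_size]
print(input_shape)  # [4096, 4096]: the per-microbatch activation buffer TE registers

# TE >= 1.9.0 bootstraps its own process group (bootstrap_backend="nccl"),
# while older versions rely on an MPI process group being created beforehand
# via torch.distributed.new_group(backend="mpi"), as in the else-branch above.
```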
24 changes: 22 additions & 2 deletions internlm/initialize/launch.py
@@ -101,7 +101,9 @@ def args_sanity_check():
gpc.config.parallel.pipeline._add_item("mode", "1F1B")

if "tensor" not in gpc.config.parallel:
gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name))
gpc.config.parallel._add_item(
"tensor", dict(size=1, mode=TensorParallelMode.mtp.name, tp_overlap=False, tp_overlap_cfg=None)
)

if "weight" not in gpc.config.parallel:
gpc.config.parallel._add_item(
@@ -398,7 +400,9 @@ def args_sanity_check():

# set default value for tensor parallel
if isinstance(gpc.config.parallel["tensor"], int):
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name)
gpc.config.parallel["tensor"] = dict(
size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name, tp_overlap=False, tp_overlap_cfg=None
)
if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name
if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name:
@@ -455,6 +459,22 @@ def args_sanity_check():
if gpc.config.model.get("parallel_output", False) is False:
logger.warning("When sequence parallel is enabled, it is recommended to also enable parallel_output")

if gpc.config.parallel["tensor"].get("tp_overlap", None) is None:
gpc.config.parallel["tensor"]["tp_overlap"] = False
elif gpc.config.parallel["tensor"].get("tp_overlap", None) is True:
assert gpc.config.parallel["tensor"].get("mode", None) in [
TensorParallelMode.msp.name,
TensorParallelMode.fsp.name,
], "tp_overlap can be set to true only in msp and fsp mode"

if gpc.config.parallel["tensor"].get("tp_overlap_cfg", None) is None:
gpc.config.parallel["tensor"]["tp_overlap_cfg"] = dict(
tp_comm_overlap_ag=True,
tp_comm_overlap_rs=True,
tp_comm_bulk_wgrad=True,
tp_comm_bulk_dgrad=True,
)

# set default value for weight parallel
if gpc.config.parallel["weight"].get("overlap", None) is None:
gpc.config.parallel["weight"]["overlap"] = False
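
A standalone sketch of the defaulting and validation these checks add, with gpc.config replaced by a plain dict; illustrative only, and the helper name is made up.

```python
# Standalone sketch of the tp_overlap defaulting/validation added above,
# operating on a plain dict instead of gpc.config. Illustrative only.
def normalize_tensor_cfg(tensor_cfg):
    if isinstance(tensor_cfg, int):
        tensor_cfg = dict(size=tensor_cfg, mode="mtp", tp_overlap=False, tp_overlap_cfg=None)
    tensor_cfg.setdefault("mode", "mtp")
    if tensor_cfg.get("tp_overlap") is None:
        tensor_cfg["tp_overlap"] = False
    elif tensor_cfg["tp_overlap"] is True:
        assert tensor_cfg["mode"] in ("msp", "fsp"), "tp_overlap can be set to true only in msp and fsp mode"
    if tensor_cfg.get("tp_overlap_cfg") is None:
        tensor_cfg["tp_overlap_cfg"] = dict(
            tp_comm_overlap_ag=True,
            tp_comm_overlap_rs=True,
            tp_comm_bulk_wgrad=True,
            tp_comm_bulk_dgrad=True,
        )
    return tensor_cfg


print(normalize_tensor_cfg(4))  # int shorthand expands with overlap disabled
print(normalize_tensor_cfg(dict(size=2, mode="msp", tp_overlap=True)))  # gets a default tp_overlap_cfg
```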
29 changes: 24 additions & 5 deletions internlm/model/modeling_internlm2.py
@@ -535,11 +535,30 @@ def load_hf_weights(folder: str, model: nn.Module) -> None:
dim=0,
)[local_rank]
else:
state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.attention.wqkv.weight"),
split_size,
dim=0,
)[local_rank]
key = f"model.layers.{layer_ids}.attention.wqkv.weight"
if key in state_dict:
state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk(
state_dict.pop(key),
split_size,
dim=0,
)[local_rank]
else:
wq = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.attention.wq.weight"),
split_size,
dim=0,
)[local_rank]
wk = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.attention.wk.weight"),
split_size,
dim=0,
)[local_rank]
wv = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.attention.wv.weight"),
split_size,
dim=0,
)[local_rank]
state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.cat([wq, wk, wv], dim=0)
wo_name = "self_attn.o_proj" if is_internlm3 else "attention.wo"
state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.{wo_name}.weight"),
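
A toy, self-contained illustration of the new fallback path: when the checkpoint stores separate wq/wk/wv weights, each is chunked along dim 0 for the local TP rank and the chunks are concatenated into a single wqkv weight. Shapes and the rank are made-up examples; real shapes depend on head counts and the TP world size.

```python
import torch

# Toy illustration of the separate-wq/wk/wv fallback above: chunk each
# projection along dim 0 for the local TP rank, then concatenate into wqkv.
hidden, split_size, local_rank = 8, 2, 0
wq = torch.randn(hidden, hidden)
wk = torch.randn(hidden // 2, hidden)  # e.g. grouped-query attention: fewer kv rows
wv = torch.randn(hidden // 2, hidden)

wq_local = torch.chunk(wq, split_size, dim=0)[local_rank]
wk_local = torch.chunk(wk, split_size, dim=0)[local_rank]
wv_local = torch.chunk(wv, split_size, dim=0)[local_rank]

wqkv_local = torch.cat([wq_local, wk_local, wv_local], dim=0)
print(wqkv_local.shape)  # torch.Size([8, 8]): (8 + 4 + 4) / split_size rows, hidden cols
```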
163 changes: 161 additions & 2 deletions internlm/model/modules/linear.py
@@ -9,6 +9,14 @@

import torch
import torch.distributed as dist

try:
import transformer_engine as te

has_te = True
except (ModuleNotFoundError, ImportError):
has_te = False

from torch import nn

from internlm.accelerator import get_accelerator
@@ -19,6 +27,7 @@
get_parallel_strategies_split_mode,
get_tensor_split_parallel_mode,
)
from internlm.model.modules.utils import is_te_min_version
from internlm.model.ops.linear import (
gmm_backward_op,
gmm_forward_op,
@@ -1009,6 +1018,137 @@ def __init__(
self.full_weight_shape = torch.Size((num_groups, in_features, out_features))


if has_te:

class TELinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
"""

def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
split_mode: str = "none",
tp_comm_buffer_name: str = None,
):
if is_expert:
raise ValueError("Transformer Engine linear layers do not yet support MoE")

# TE returns a zero length Tensor when bias=False and
# return_bias=True. Here we need a single Tensor
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True

extra_kwargs = {"params_dtype": gpc.config.model.dtype}
extra_kwargs["device"] = torch.cuda.current_device()

if gpc.config.parallel["tensor"].get("tp_overlap", False):
if is_te_min_version("1.5.0", check_equality=False):
extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_ag", True
)
if split_mode == "column":
extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_bulk_wgrad", True
)
extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_bulk_dgrad", True
)
elif split_mode == "row":
extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_rs", True
)
else:
raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name

parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
tp_size = gpc.get_world_size(parallel_mode)
tp_group = gpc.get_group(parallel_mode)

super().__init__(
in_features=in_features,
out_features=out_features,
sequence_parallel=gpc.config.parallel.sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=split_mode,
**extra_kwargs,
)

def forward(self, x):
"""Forward."""
_is_first_microbatch = self.is_first_microbatch
x = x.transpose(0, 1)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
out = out.transpose(0, 1)
self.is_first_microbatch = False

return out

class TEColumnParallelLinear(TELinear):
"""
Wrapper for the TELinear layer.
"""

def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
):
super().__init__(
in_features,
out_features,
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
split_mode="column",
tp_comm_buffer_name=tp_comm_buffer_name,
)

class TERowParallelLinear(TELinear):
"""
Wrapper for the TELinear layer.
"""

def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
):
super().__init__(
in_features,
out_features,
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
split_mode="row",
tp_comm_buffer_name=tp_comm_buffer_name,
)

else:
TELinear = ParallelLinearWithCommExt
TEColumnParallelLinear = ColumnParallelLinear
TERowParallelLinear = RowParallelLinear


def new_linear(
name: str,
in_features: int,
@@ -1021,6 +1161,7 @@ def new_linear(
weight_scale: int = 1,
norm_head: bool = False,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
**kwargs,
) -> nn.Linear:

@@ -1057,7 +1198,7 @@ def new_linear(
weight_scale=weight_scale,
norm_head=norm_head,
)
elif split_mode == "column":
elif split_mode == "column" or (split_mode == "tecolumn" and not has_te):
return ColumnParallelLinear(
in_features,
out_features,
@@ -1067,7 +1208,16 @@
dtype,
is_expert,
)
elif split_mode == "row":
elif split_mode == "tecolumn":
return TEColumnParallelLinear(
in_features,
out_features,
bias,
False,
is_expert,
tp_comm_buffer_name,
)
elif split_mode == "row" or (split_mode == "terow" and not has_te):
return RowParallelLinear(
in_features,
out_features,
@@ -1077,6 +1227,15 @@
dtype,
is_expert,
)
elif split_mode == "terow":
return TERowParallelLinear(
in_features,
out_features,
bias,
False,
is_expert,
tp_comm_buffer_name,
)
elif split_mode == "grouped_wp":
return GroupedWPLinear(
in_features,
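
To summarize how the userbuffer keyword arguments are assembled for column- versus row-parallel layers, a standalone sketch with gpc.config replaced by a plain overlap_cfg dict; the helper function is hypothetical, but the ub_* names match those used in TELinear above.

```python
# Standalone sketch of the extra_kwargs assembly in TELinear.__init__ above,
# with gpc.config replaced by a plain overlap_cfg dict. Illustrative only.
def build_ub_kwargs(split_mode: str, overlap_cfg: dict, buffer_name: str) -> dict:
    kwargs = {"ub_overlap_ag": overlap_cfg.get("tp_comm_overlap_ag", True)}
    if split_mode == "column":
        # Column-parallel layers bulk-overlap the wgrad/dgrad GEMMs with comm.
        kwargs["ub_bulk_wgrad"] = overlap_cfg.get("tp_comm_bulk_wgrad", True)
        kwargs["ub_bulk_dgrad"] = overlap_cfg.get("tp_comm_bulk_dgrad", True)
    elif split_mode == "row":
        # Row-parallel layers overlap the reduce-scatter with the GEMM.
        kwargs["ub_overlap_rs"] = overlap_cfg.get("tp_comm_overlap_rs", True)
    kwargs["ub_name"] = buffer_name
    return kwargs


print(build_ub_kwargs("column", dict(tp_comm_bulk_wgrad=False), "qkv"))
print(build_ub_kwargs("row", {}, "proj"))
```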