diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index e3322d3..ea7dce3 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -17,6 +17,7 @@ #include "infini_train/include/nn/parallel/distributed_data_parallel.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/parallel_functional.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/rank.h" #include "infini_train/include/nn/parallel/reduce_op_type.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" @@ -63,6 +64,8 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); +DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size; specifies the number of PP stages."); + // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -106,6 +109,7 @@ void Train(const nn::parallel::Rank &rank) { int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0; + int pp_world_size = global::GetPipelineParallelSize(); if (FLAGS_sequence_parallel) { CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0) @@ -114,9 +118,11 @@ void Train(const nn::parallel::Rank &rank) { int ddp_rank = 0; int tp_rank = 0; + int pp_rank = 0; const ProcessGroup *ddp_pg = nullptr; const ProcessGroup *tp_pg = nullptr; + const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); @@ -134,6 +140,14 @@ void Train(const nn::parallel::Rank &rank) { // NOTE(zbl): Reserved for VocabParallelEmbedding nn::parallel::tp_rank = tp_rank; } + + if (pp_world_size > 1) { + pp_pg = ProcessGroupFactory::Instance()->GetOrCreate( + GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size)); + pp_rank = pp_pg->GetGroupRank(rank.thread_rank()); + + nn::parallel::pp_rank = pp_rank; + } } else { device = FLAGS_device == kDeviceCPU ?
DeviceManager::Instance()->GetDefaultDevice() : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0); @@ -182,8 +196,10 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, rank.thread_rank()); } + auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), - FLAGS_batch_size, ddp_rank, ddp_world_size); + FLAGS_batch_size * num_micro_batches, ddp_rank, ddp_world_size); + std::optional val_loader = std::nullopt; if (!FLAGS_input_val_bin.empty()) { val_loader = DistributedDataLoader( @@ -211,6 +227,15 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.thread_rank() << ": start training"; + if (pp_world_size > 1) { + auto shapes = std::vector>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}}; + + model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, + pp_rank, std::make_shared(optimizer)); + } + + LOG(INFO) << "start training"; + for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { const bool last_step = step == FLAGS_num_iteration; @@ -233,64 +258,80 @@ void Train(const nn::parallel::Rank &rank) { break; } - // model->Train(); - optimizer.ZeroGrad(); - // if we are trying to overfit a single batch, we reset the loader here - if (FLAGS_overfit_single_batch) { - // train_loader.Reset(); - } - float lossf = 0.0f; #ifdef PROFILE_MODE Profiler::Instance().SetTag("Step_" + std::to_string(step)); #endif - for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { - // enable autocast for the current step - infini_train::AutocastGuard autocast_guard(device->Type(), dtype); - // (bs, seq_len), (bs, seq_len) + float lossf = 0.0f; + // model->Train(); + if (pp_world_size == 1) { + optimizer.ZeroGrad(); + + // if we are trying to overfit a single batch, we reset the loader here + if (FLAGS_overfit_single_batch) { + // train_loader.Reset(); + } + + for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { + // enable autocast for the current step + infini_train::AutocastGuard autocast_guard(device->Type(), dtype); + + // (bs, seq_len), (bs, seq_len) + auto [x, y] = *train_iter; + // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below + // TODO(dcj): support dataloader.reset() later + ++train_iter; + x = std::make_shared(x->To(device)); + y = std::make_shared(y->To(device)); + + LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward"; + // (bs, seq_len, vocab_size) + auto logits = model->Forward({x, y})[0]; + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward"; + auto loss = loss_fn->Forward({logits, y})[0]; + loss = loss / grad_accum_steps; + + // disable autocast for the current step (backward is not under autocast) + autocast_guard.Disable(); + + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward"; + if (ddp_world_size > 1) { + function::AllReduce(loss, function::ReduceOpType::kAvg); + } + auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + lossf += static_cast(loss_cpu.DataPtr())[0]; + LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward"; + loss->Backward(); + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; + } + + optimizer.Step(); + } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by 
commenting out the line below // TODO(dcj): support dataloader.reset() later ++train_iter; x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); - LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward"; - // (bs, seq_len, vocab_size) - auto logits = model->Forward({x, y})[0]; - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward"; - auto loss = loss_fn->Forward({logits, y})[0]; - loss = loss / grad_accum_steps; - - // disable autocast for the current step (backward is not under autocast) - autocast_guard.Disable(); - - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward"; - if (ddp_world_size > 1) { - function::AllReduce(loss, function::ReduceOpType::kAvg); - } - auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); - lossf += static_cast(loss_cpu.DataPtr())[0]; - LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward"; - loss->Backward(); - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; - } - optimizer.Step(); + lossf = model->TrainStep({x}, {y}, loss_fn); + } const auto iter_end = std::chrono::high_resolution_clock::now(); const double duration_us = std::chrono::duration(iter_end - iter_start).count(); const double tps = FLAGS_total_batch_size / (duration_us / 1e6); - if (rank.IsMainRank()) { - LOG(ERROR) << std::format( - "step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, DP={}, TP={}, SP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size, - tp_world_size, sp_world_size); + if (rank.thread_rank() == pp_world_size - 1) { + LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, " + "DP={}, TP={}, SP={}, PP={})", + step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, + tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { - if (!tokenizer) { - continue; + if (tokenizer) { + // FIXME(jym): to support PP + CHECK_EQ(pp_world_size, 1); + tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device); } - tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device); } } } @@ -304,7 +345,8 @@ int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); - nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel); + nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, + FLAGS_pipeline_parallel); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index a023316..259439c 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -21,6 +21,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/nn/parallel/utils.h" @@ -176,6 +177,10 @@ Block::Forward(const std::vector> &x) { } GPT2::GPT2(const GPT2Config &config) : config_(config) { + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + auto [is_first_stage, 
is_last_stage, start_layer, end_layer] + = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size); + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); // NOTE(zbl): VocabParallelEmbedding requires vocab_size % tp_size == 0 @@ -184,77 +189,102 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) { CHECK_EQ(config.vocab_size % tp_world_size, 0) << "Vocab size should be divisible by TP world size"; { std::unordered_map> transformer; - transformer[kWTELayerName] = std::make_shared( - config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); - transformer[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); + if (is_first_stage) { + transformer[kWTELayerName] = std::make_shared( + config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); + transformer[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); + } + { std::vector> h; - for (int64_t i = 0; i < config_.n_layer; ++i) { h.push_back(std::make_shared(config_)); } - transformer[kHLayerName] = std::make_shared(std::move(h)); + for (int64_t i = start_layer; i < end_layer; ++i) { h.push_back(std::make_shared(config_)); } + transformer[kHLayerName] = std::make_shared(std::move(h)); + if (is_last_stage) { + transformer[kLnFLayerName] = std::make_shared(std::vector{config_.n_embd}); + } + + modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); + } + if (is_last_stage) { + // don't init this one, we will tie weights + modules_[kLMHeadLayerName] = std::make_shared( + /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, + /*bias=*/false, + // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); } - transformer[kLnFLayerName] = std::make_shared(std::vector{config_.n_embd}); - modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); - } - // don't init this one, we will tie weights - modules_[kLMHeadLayerName] = std::make_shared( - /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, - /*bias=*/false, - // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - // https://paperswithcode.com/method/weight-tying - *mutable_module(kTransformerLayerName) - ->mutable_module(kWTELayerName) - ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName) - = module(kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); + // FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation + if (pp_size == 1) { + // https://paperswithcode.com/method/weight-tying + *mutable_module(kTransformerLayerName) + ->mutable_module(kWTELayerName) + ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName) + = module(kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); + } + } } std::vector> GPT2::Forward(const std::vector> &x) { + int pp_rank = nn::parallel::pp_rank; + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + bool is_first_stage = (pp_rank == 0); + bool is_last_stage = (pp_rank == pp_size - 1); + // (B, T) - auto &idx = x[0]; - const auto device = idx->GetDevice(); + auto x1 = 
x[0]; + const auto device = x1->GetDevice(); - const auto t = idx->Dims()[1]; // T + const auto t = x1->Dims()[1]; // T CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " << config_.block_size; - // (T_local) - // NOTE(zbl): Slice pos sequence when SP is enabled - auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); - auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); - int tp_rank = 0; - if (tp_world_size > 1) { - auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( - nn::parallel::GetTensorParallelProcessGroupName(device->rank().thread_rank())); - tp_rank = tp_group->GetGroupRank(device->rank().thread_rank()); - } - - int64_t t_local = sequence_parallel_enabled ? (t / tp_world_size) : t; - int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0; - auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); // forward the GPT2 model itself auto &transformer = modules_[kTransformerLayerName]; - // (B, T) -> Embedding(V_local, C) -> (B, T, C) - auto tok_emb = transformer->mutable_module(kWTELayerName)->Forward({idx})[0]; - // (T) -> Embedding(T_max, C) -> (T, C) - auto pos_emb = transformer->mutable_module(kWPELayerName)->Forward({pos})[0]; - // (B, T, C) - auto x1 = tok_emb + pos_emb; + if (is_first_stage) { + // (T_local) + // NOTE(zbl): Slice pos sequence when SP is enabled + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); + auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); + int tp_rank = 0; + if (tp_world_size > 1) { + auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( + nn::parallel::GetTensorParallelProcessGroupName(device->rank().thread_rank())); + tp_rank = tp_group->GetGroupRank(device->rank().thread_rank()); + } + int64_t t_local = sequence_parallel_enabled ? (t / tp_world_size) : t; + int64_t start = sequence_parallel_enabled ? 
tp_rank * t_local : 0; + auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); + + // (B, T) -> Embedding(V_local, C) -> (B, T, C) + auto tok_emb = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; + // (T) -> Embedding(T_max, C) -> (T, C) + auto pos_emb = transformer->mutable_module(kWPELayerName)->Forward({pos})[0]; + // (B, T, C) + x1 = tok_emb + pos_emb; + } // (B, T, C) -> transformer -> (B, T, C) - auto x2 = transformer->mutable_module(kHLayerName)->Forward({x1}); - // (B, T, C) -> Layernorm -> (B, T, C) - auto x3 = transformer->mutable_module(kLnFLayerName)->Forward(x2); - - // TODO(dcj): add inference-time mini-optimization - // (B, T, C) -> Linear(C, V) -> (B, T, V) - auto logits = modules_[kLMHeadLayerName]->Forward(x3); - - // (B, T, V_original) - return logits; + auto h_modules = transformer->mutable_module(kHLayerName); + CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; + auto h_layers = std::dynamic_pointer_cast(h_modules); + // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) + for (auto &h : *h_layers) { x1 = h->Forward({x1})[0]; } + + if (is_last_stage) { + // (B, T, C) -> Layernorm -> (B, T, C) + auto x3 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); + + // TODO(dcj): add inference-time mini-optimization + // (B, T, C) -> Linear(C, V) -> (B, T, V) + auto logits = modules_[kLMHeadLayerName]->Forward(x3); + // (B, T, V_original) + return logits; + } + return {x1}; } std::shared_ptr GPT2::FromPretrained(ModelType model_type) { @@ -321,6 +351,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { CHECK_EQ(n_embd % n_head, 0) << "n_embd must be divisible by n_head."; CHECK_EQ(n_head % tp_size, 0) << "n_head must be divisible by TP world size."; + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + auto [is_first_stage, is_last_stage, start_layer, end_layer] + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size); + auto tp_rank = nn::parallel::tp_rank; // calculate xx_size_per_partition const int64_t vpp = model_vocab_size / tp_size; @@ -343,163 +377,274 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { // transformer.wte.weight (also transformer.lm_head.weight) // full: (model_vocab_size, n_embd) // local: (vocab_size_per_partition, n_embd) - auto &transformer_wte_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kWTELayerName, - nn::parallel::VocabParallelEmbedding::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd, - v_start, vpp); + if (is_first_stage) { + auto &transformer_wte_weight + = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kWTELayerName, + nn::parallel::VocabParallelEmbedding::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd, + v_start, vpp); + } else if (pp_size > 1 && is_last_stage) { + auto &lm_head_weight = state_dict[std::format("{}.{}", GPT2::kLMHeadLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ifs.read(reinterpret_cast(lm_head_weight->DataPtr()), lm_head_weight->SizeInBytes()); + } else { + size_t wte_bytes = vocab_size * n_embd * sizeof(float); + ifs.seekg(wte_bytes, std::ios::cur); + } + if (tp_size == 1) { // Skip padded vocab part when TP is not enabled ifs.ignore((padded_vocab_size - vocab_size) * n_embd * sizeof(float)); } - // transformer.wpe.weight 
- auto &transformer_wpe_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kWPELayerName, - nn::Embedding::kParamWeightName)]; - ReadMatrixAllFloat(ifs, static_cast(transformer_wpe_weight->DataPtr()), block_size, n_embd); + + if (is_first_stage) { + // transformer.wpe.weight + auto &transformer_wpe_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, + GPT2::kWPELayerName, nn::Embedding::kParamWeightName)]; + ReadMatrixAllFloat(ifs, static_cast(transformer_wpe_weight->DataPtr()), block_size, n_embd); + } else { + size_t wpe_bytes = block_size * n_embd * sizeof(float); + ifs.seekg(wpe_bytes, std::ios::cur); + } + // transformer.h.{i}.ln_1.weight for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kLn1LayerName, nn::LayerNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kLn1LayerName, + nn::LayerNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_1_w_bytes = n_embd * sizeof(float); + ifs.seekg(ln_1_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.ln_1.bias for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kLn1LayerName, nn::LayerNorm::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kLn1LayerName, + nn::LayerNorm::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_1_b_bytes = n_embd * sizeof(float); + ifs.seekg(ln_1_b_bytes, std::ios::cur); + } } + // transformer.h.{i}.attn.c_attn.weight (ColumnParallelLinear, but actually applies on "rows") for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim, - // i.e. 
[Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T - // However, each tp_rank needs to get [q_i|k_i|v_i].T, so we need to jump and read them respectively - float *dst = static_cast(tensor->DataPtr()); - const int64_t local_C = n_embd / tp_size; - const int64_t rows_all = 3 * n_embd; - const int64_t cols_all = n_embd; - const std::streampos base_pos = ifs.tellg(); - // Read q_i -> write to dst rows of [0 : local_C) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (0 * local_C) * cols_all, - /*rows=*/rows_all, /*cols=*/cols_all, - /*row_start=*/tp_rank * local_C, /*row_cnt=*/local_C); - // Read k_i -> write to dst rows of [local_C : 2*local_C) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (1 * local_C) * cols_all, - /*rows=*/rows_all, /*cols=*/cols_all, - /*row_start=*/n_embd + tp_rank * local_C, /*row_cnt=*/local_C); - // Read v_i -> write to dst rows of [2*local_C : 3*local_C) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (2 * local_C) * cols_all, - /*rows=*/rows_all, /*cols=*/cols_all, - /*row_start=*/2 * n_embd + tp_rank * local_C, /*row_cnt=*/local_C); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCAttnLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim, + // i.e. [Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T + // However, each tp_rank needs to get [q_i|k_i|v_i].T, so we need to jump and read them + // respectively + float *dst = static_cast(tensor->DataPtr()); + const int64_t local_C = n_embd / tp_size; + const int64_t rows_all = 3 * n_embd; + const int64_t cols_all = n_embd; + const std::streampos base_pos = ifs.tellg(); + // Read q_i -> write to dst rows of [0 : local_C) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (0 * local_C) * cols_all, + /*rows=*/rows_all, /*cols=*/cols_all, + /*row_start=*/tp_rank * local_C, /*row_cnt=*/local_C); + // Read k_i -> write to dst rows of [local_C : 2*local_C) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (1 * local_C) * cols_all, + /*rows=*/rows_all, /*cols=*/cols_all, + /*row_start=*/n_embd + tp_rank * local_C, /*row_cnt=*/local_C); + // Read v_i -> write to dst rows of [2*local_C : 3*local_C) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (2 * local_C) * cols_all, + /*rows=*/rows_all, /*cols=*/cols_all, + /*row_start=*/2 * n_embd + tp_rank * local_C, /*row_cnt=*/local_C); + } else { + size_t c_attn_w_bytes = qkv_out * n_embd * sizeof(float); + ifs.seekg(c_attn_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.attn.c_attn.bias (ColumnParallelLinear) for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, - nn::parallel::ColumnParallelLinear::kParamBiasName)]; - // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated - // i.e. 
[Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn] - // However, each tp_rank needs to get [q_i|k_i|v_i], so we need to jump and read them respectively - float *dst = static_cast(tensor->DataPtr()); - const int64_t local_C = n_embd / tp_size; - const int64_t len_all = 3 * n_embd; - const std::streampos base_pos = ifs.tellg(); - // Read q_i - ifs.seekg(base_pos); - ReadVectorShardFloat(ifs, - /*dst=*/dst + (0 * local_C), - /*len=*/len_all, - /*start=*/tp_rank * local_C, /*cnt=*/local_C); - // Read k_i - ifs.seekg(base_pos); - ReadVectorShardFloat(ifs, - /*dst=*/dst + (1 * local_C), - /*len=*/len_all, - /*start=*/n_embd + tp_rank * local_C, /*cnt=*/local_C); - // Read v_i - ifs.seekg(base_pos); - ReadVectorShardFloat(ifs, - /*dst=*/dst + (2 * local_C), - /*len=*/len_all, - /*start=*/2 * n_embd + tp_rank * local_C, /*cnt=*/local_C); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCAttnLayerName, + nn::parallel::ColumnParallelLinear::kParamBiasName)]; + // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated + // i.e. [Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn] + // However, each tp_rank needs to get [q_i|k_i|v_i], so we need to jump and read them + // respectively + float *dst = static_cast(tensor->DataPtr()); + const int64_t local_C = n_embd / tp_size; + const int64_t len_all = 3 * n_embd; + const std::streampos base_pos = ifs.tellg(); + // Read q_i + ifs.seekg(base_pos); + ReadVectorShardFloat(ifs, + /*dst=*/dst + (0 * local_C), + /*len=*/len_all, + /*start=*/tp_rank * local_C, /*cnt=*/local_C); + // Read k_i + ifs.seekg(base_pos); + ReadVectorShardFloat(ifs, + /*dst=*/dst + (1 * local_C), + /*len=*/len_all, + /*start=*/n_embd + tp_rank * local_C, /*cnt=*/local_C); + // Read v_i + ifs.seekg(base_pos); + ReadVectorShardFloat(ifs, + /*dst=*/dst + (2 * local_C), + /*len=*/len_all, + /*start=*/2 * n_embd + tp_rank * local_C, /*cnt=*/local_C); + } else { + size_t c_attn_b_bytes = qkv_out * sizeof(float); + ifs.seekg(c_attn_b_bytes, std::ios::cur); + } } + // transformer.h.{i}.attn.c_proj.weight (RowParallelLinear, but actually applies on "columns") for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp, in_pp); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp, + in_pp); + } else { + size_t c_proj_w_bytes = n_embd * n_embd * sizeof(float); + ifs.seekg(c_proj_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.attn.c_proj.bias (RowParallelLinear, no shard on bias) for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kAttnLayerName, 
CausalSelfAttention::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t c_proj_b_bytes = n_embd * sizeof(float); + ifs.seekg(c_proj_b_bytes, std::ios::cur); + } } + // transformer.h.{i}.ln_2.weight for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kLn2LayerName, nn::LayerNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kLn2LayerName, + nn::LayerNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_2_w_bytes = n_embd * sizeof(float); + ifs.seekg(ln_2_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.ln_2.bias for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kLn2LayerName, nn::LayerNorm::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kLn2LayerName, + nn::LayerNorm::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_2_b_bytes = n_embd * sizeof(float); + ifs.seekg(ln_2_b_bytes, std::ios::cur); + } } + // transformer.h.{i}.mlp.c_fc.weight (ColumnParallelLinear, but actually applies on "rows") for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kMlpLayerName, MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp); + } else { + size_t c_fc_w_bytes = fc_out * n_embd * sizeof(float); + ifs.seekg(c_fc_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.mlp.c_fc.bias (ColumnParallelLinear) for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kMlpLayerName, MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamBiasName)]; - 
ReadVectorShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, fc_start, fc_pp); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; + ReadVectorShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, fc_start, fc_pp); + } else { + size_t c_fc_b_bytes = fc_out * sizeof(float); + ifs.seekg(c_fc_b_bytes, std::ios::cur); + } } + // transformer.h.{i}.mlp.c_proj.weight (RowParallelLinear, but actually applies on "columns") for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kMlpLayerName, MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp, in4_pp); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp, + in4_pp); + } else { + size_t c_proj_w_bytes = fc_out * n_embd * sizeof(float); + ifs.seekg(c_proj_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.mlp.c_proj.bias (RowParallelLinear, no shard on bias) for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kMlpLayerName, MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t c_proj_b_bytes = n_embd * sizeof(float); + ifs.seekg(c_proj_b_bytes, std::ios::cur); + } } - // transformer.ln_f.weight - auto &transformer_ln_f_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kLnFLayerName, - nn::LayerNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_weight->DataPtr()), n_embd); - // transformer.ln_f.bias - auto &transformer_ln_f_bias = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kLnFLayerName, - nn::LayerNorm::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_bias->DataPtr()), n_embd); + if (is_last_stage) { + // transformer.ln_f.weight + auto &transformer_ln_f_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, + GPT2::kLnFLayerName, nn::LayerNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_weight->DataPtr()), n_embd); + // transformer.ln_f.bias + auto &transformer_ln_f_bias = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, + GPT2::kLnFLayerName, nn::LayerNorm::kParamBiasName)]; + ReadVectorAllFloat(ifs, 
static_cast(transformer_ln_f_bias->DataPtr()), n_embd); + } else { + size_t ln_f_w_bytes = n_embd * sizeof(float); + size_t ln_f_b_bytes = n_embd * sizeof(float); + ifs.seekg(ln_f_w_bytes + ln_f_b_bytes, std::ios::cur); + } return local_gpt2; } diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 9dc4cf0..c3df5a9 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -14,6 +14,7 @@ #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/distributed_data_parallel.h" #include "infini_train/include/nn/parallel/parallel_functional.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/rank.h" #include "infini_train/include/nn/parallel/reduce_op_type.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" @@ -62,6 +63,7 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); +DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size; specifies the number of PP stages."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -89,6 +91,7 @@ void Train(const nn::parallel::Rank &rank) { int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0; + int pp_world_size = global::GetPipelineParallelSize(); if (FLAGS_sequence_parallel) { CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0) @@ -97,9 +100,11 @@ void Train(const nn::parallel::Rank &rank) { int ddp_rank = 0; int tp_rank = 0; + int pp_rank = 0; const ProcessGroup *ddp_pg = nullptr; const ProcessGroup *tp_pg = nullptr; + const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); @@ -117,6 +122,14 @@ void Train(const nn::parallel::Rank &rank) { // NOTE(zbl): Reserved for VocabParallelEmbedding nn::parallel::tp_rank = tp_rank; } + + if (pp_world_size > 1) { + pp_pg = ProcessGroupFactory::Instance()->GetOrCreate( + GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size)); + pp_rank = pp_pg->GetGroupRank(rank.thread_rank()); + + nn::parallel::pp_rank = pp_rank; + } } else { device = FLAGS_device == kDeviceCPU ?
DeviceManager::Instance()->GetDefaultDevice() : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0); @@ -163,8 +176,11 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, rank.thread_rank()); } + auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); + DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), - FLAGS_batch_size, ddp_rank, ddp_world_size); + FLAGS_batch_size * num_micro_batches, ddp_rank, ddp_world_size); + std::optional val_loader = std::nullopt; if (!FLAGS_input_val_bin.empty()) { val_loader = DistributedDataLoader( @@ -190,6 +206,13 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.thread_rank() << ": start training"; + if (pp_world_size > 1) { + auto shapes = std::vector>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}}; + + model = std::make_shared( + model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer)); + } + for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { const bool last_step = step == FLAGS_num_iteration; @@ -212,65 +235,81 @@ void Train(const nn::parallel::Rank &rank) { break; } - // model->Train(); - optimizer.ZeroGrad(); - // if we are trying to overfit a single batch, we reset the loader here - if (FLAGS_overfit_single_batch) { - // train_loader.Reset(); - } - float lossf = 0.0f; #ifdef PROFILE_MODE Profiler::Instance().SetTag("Step_" + std::to_string(step)); #endif - for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { - // enable autocast for the current step - infini_train::AutocastGuard autocast_guard(device->Type(), dtype); - // (bs, seq_len), (bs, seq_len) + float lossf = 0.0f; + if (pp_world_size == 1) { + // model->Train(); + optimizer.ZeroGrad(); + + // if we are trying to overfit a single batch, we reset the loader here + if (FLAGS_overfit_single_batch) { + // train_loader.Reset(); + } + + for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { + // enable autocast for the current step + infini_train::AutocastGuard autocast_guard(device->Type(), dtype); + + // (bs, seq_len), (bs, seq_len) + auto [x, y] = *train_iter; + // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below + // TODO(dcj): support dataloader.reset() later + ++train_iter; + x = std::make_shared(x->To(device)); + y = std::make_shared(y->To(device)); + + LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward"; + // (bs, seq_len, vocab_size) + auto logits = model->Forward({x, y})[0]; + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward"; + auto loss = loss_fn->Forward({logits, y})[0]; + loss = loss / grad_accum_steps; + + // disable autocast for the current step (backward is not under autocast) + autocast_guard.Disable(); + + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward"; + if (ddp_world_size > 1) { + function::AllReduce(loss, function::ReduceOpType::kAvg); + } + auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + lossf += static_cast(loss_cpu.DataPtr())[0]; + LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward"; + loss->Backward(); + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; + } + + optimizer.Step(); + } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below // 
TODO(dcj): support dataloader.reset() later ++train_iter; x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); - LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward"; - // (bs, seq_len, vocab_size) - auto logits = model->Forward({x, y})[0]; - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward"; - auto loss = loss_fn->Forward({logits, y})[0]; - loss = loss / grad_accum_steps; - - // disable autocast for the current step (backward is not under autocast) - autocast_guard.Disable(); - - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward"; - if (ddp_world_size > 1) { - function::AllReduce(loss, function::ReduceOpType::kAvg); - } - auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); - lossf += static_cast(loss_cpu.DataPtr())[0]; - LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward"; - loss->Backward(); - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; - } - optimizer.Step(); + lossf = model->TrainStep({x}, {y}, loss_fn); + } const auto iter_end = std::chrono::high_resolution_clock::now(); const double duration_us = std::chrono::duration(iter_end - iter_start).count(); const double tps = FLAGS_total_batch_size / (duration_us / 1e6); - if (rank.IsMainRank()) { - LOG(ERROR) << std::format( - "step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, DP={}, TP={}, SP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size, - tp_world_size, sp_world_size); + if (rank.thread_rank() == pp_world_size - 1) { + LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, " + "DP={}, TP={}, SP={}, PP={})", + step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, + tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { - if (!tokenizer) { - continue; + // FIXME(jym): to support PP + if (tokenizer) { + CHECK_EQ(pp_world_size, 1); + tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device); } - tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device); } } } @@ -284,7 +323,8 @@ int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); - nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel); + nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, + FLAGS_pipeline_parallel); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/llama3/net.cc b/example/llama3/net.cc index 400d8b3..3e7becf 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -22,6 +23,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" @@ -99,8 +101,8 @@ ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false, const infini_train::Device *device - = 
DeviceManager::Instance()->GetDefaultDevice(), - DataType dtype = DataType::kFLOAT32) { + = DeviceManager::Instance()->GetDefaultDevice()) { + DataType dtype = DataType::kFLOAT32; CHECK_GE(dim, 2) << "dim must be >= 2 for slicing"; auto arange = nn::init::Arange(0, dim, dtype, device)->Slice(0, 0, dim, 2); auto freqs = 1.0f / nn::function::Pow(theta, arange / float(dim)); @@ -308,6 +310,7 @@ std::vector> Block::Forward(const std::vector 1 ? x[1] : nullptr; const auto start_pos = x.size() > 2 ? x[2] : nullptr; const auto mask = x.size() > 3 ? x[3] : nullptr; + // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) -> attention -> (bs, seq_len, n_embd) // -> Add -> (bs, seq_len, n_embd) auto x1 = x[0] @@ -323,83 +326,92 @@ std::vector> Block::Forward(const std::vector> transformer; + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + auto [is_first_stage, is_last_stage, start_layer, end_layer] + = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size); + + std::unordered_map> transformer; + if (is_first_stage) { transformer[kWTELayerName] = std::make_shared( config.vocab_size, config.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); - { - std::vector> h; - for (int64_t i = 0; i < config.n_layer; ++i) { h.push_back(std::make_shared(config)); } - transformer[kHLayerName] = std::make_shared(std::move(h)); - } + } + + std::vector> h_local; + for (int64_t i = start_layer; i < end_layer; ++i) { h_local.push_back(std::make_shared(config)); } + transformer[kHLayerName] = std::make_shared(std::move(h_local)); + + if (is_last_stage) { transformer[kLnFLayerName] = std::make_shared(config.n_embd, config.norm_eps); - modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); + // NOTE(zbl): weight-tying is possible but torch script did not do so + modules_[kLMHeadLayerName] = std::make_shared( + /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, + /*bias=*/false, + // NOTE(zbl): each rank would get sharded [B, T, V_local] as logits + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); } - // NOTE(zbl): weight-tying is possible but torch script did not do so - modules_[kLMHeadLayerName] = std::make_shared( - /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, - /*bias=*/false, - // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); } std::vector> LLaMA3::Forward(const std::vector> &x) { + int pp_rank = nn::parallel::pp_rank; + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + bool is_first_stage = (pp_rank == 0); + bool is_last_stage = (pp_rank == pp_size - 1); + // (bs, seq_len) - auto &idx = x[0]; - const auto device = idx->GetDevice(); - const auto t = idx->Dims()[1]; // seq_len + auto x1 = x[0]; + const auto device = x1->GetDevice(); + const auto t = x1->Dims()[1]; // seq_len CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " << config_.block_size; // Init freqs_cis on device only once // TODO(zbl): consider moving this to model construction if (buffers_[kFreqsCisName] == nullptr) { - buffers_[kFreqsCisName] = PrecomputeFreqsCis( - config_.n_embd / config_.n_head, 
config_.block_size * 2, config_.rope_theta, config_.use_scaled_rope, - device, - modules_[kLMHeadLayerName]->parameter(nn::parallel::ColumnParallelLinear::kParamWeightName)->Dtype()); + buffers_[kFreqsCisName] = PrecomputeFreqsCis(config_.n_embd / config_.n_head, config_.block_size * 2, + config_.rope_theta, config_.use_scaled_rope, device); } // forward the LLaMA3 model itself auto &transformer = modules_[kTransformerLayerName]; - // (bs, seq_len) -> Embedding(vocab_size, n_embd) -> (bs, seq_len, n_embd) - auto x1 = transformer->mutable_module(kWTELayerName)->Forward({idx})[0]; + + if (is_first_stage) { + // (bs, seq_len) -> Embedding(vocab_size, n_embd) -> (bs, seq_len, n_embd) + x1 = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; + } // TODO(zbl): dynamic start_pos int64_t start_pos = 0; auto freqs_view = buffers_[kFreqsCisName]->Slice(0, start_pos, start_pos + t, 1); // TODO(lzm): add dtype support for nn::function::Ones later - std::shared_ptr ones = std::make_shared(nn::function::Ones({t, t})->To(idx->GetDevice())); + std::shared_ptr ones = std::make_shared(nn::function::Ones({t, t})->To(x1->GetDevice())); std::shared_ptr mask = nn::function::Triu(ones, 1)->View({1, 1, t, t}); - // TODO(zbl): nn::function::Ones builds tensor in FP32 by default - if (modules_[kLMHeadLayerName]->parameter(nn::parallel::ColumnParallelLinear::kParamWeightName)->Dtype() - == DataType::kBFLOAT16) { - mask = std::make_shared(mask->To(DataType::kBFLOAT16)); - } + std::shared_ptr start_pos_ptr = nullptr; auto h_modules = transformer->mutable_module(kHLayerName); - if (h_modules->type() == nn::ModuleList::kType) { - auto h_layers = std::dynamic_pointer_cast(h_modules); - // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (auto &h : *h_layers) { x1 = h->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; } - } else { - LOG(FATAL) << "Failed to get ModuleList from transformer"; - } + CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; + auto h_layers = std::dynamic_pointer_cast(h_modules); + // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) + for (auto &h : *h_layers) { x1 = h->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; } + + if (is_last_stage) { + // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) + auto x2 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); - // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) - auto x2 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); + // TODO(zbl): add inference-time mini-optimization + // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) + auto logits = modules_[kLMHeadLayerName]->Forward(x2); - // TODO(zbl): add inference-time mini-optimization - // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) - auto logits = modules_[kLMHeadLayerName]->Forward(x2); + // (bs, seq_len, vocab_size) + return logits; + } - // (bs, seq_len, vocab_size) - return logits; + return {x1}; } std::shared_ptr LLaMA3::FromPretrained(ModelType model_type) { @@ -453,6 +465,9 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { .use_scaled_rope = static_cast(use_scaled_rope), .norm_eps = norm_eps, .max_gen_batch_size = max_gen_bs}); + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + auto [is_first_stage, is_last_stage, start_layer, end_layer] + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size); const int tp_size = 
nn::parallel::global::GetTensorParallelSize(); const int tp_rank = nn::parallel::tp_rank; @@ -521,114 +536,168 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // ========== Read Sharded Params ========== // transformer.wte.weight : (vocab_size, n_embd) -> local tp_rank: rows of [v_start : v_start+vpp) - { + if (is_first_stage) { auto &wte = state_dict[std::format("{}.{}.{}", kTransformerLayerName, kWTELayerName, nn::parallel::VocabParallelEmbedding::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(wte->DataPtr()), /*rows=*/vocab_size, /*cols=*/n_embd, /*row_start=*/v_start, /*row_cnt=*/vpp); + } else { + size_t wte_bytes = static_cast(vocab_size) * n_embd * sizeof(float); + ifs.seekg(wte_bytes, std::ios::cur); } // transformer.h.{i}.ln_1.weight : Full version RMSNorm for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i), - Block::kLn1LayerName, RMSNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (i >= start_layer && i < end_layer); + + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + std::to_string(i - start_layer), Block::kLn1LayerName, + RMSNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_1_bytes = n_embd * sizeof(float); + ifs.seekg(ln_1_bytes, std::ios::cur); + } } // transformer.h.{i}.attn.c_attn.weight : ColumnParallelLinear, but actually applies on "rows" // W-qkv should be [Q(=n_embd) | K(=n_kv_head*head_dim) | V(=n_kv_head*head_dim)] × n_embd for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i), Block::kAttnLayerName, - CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; - float *dst = static_cast(tensor->DataPtr()); - const std::streampos base_pos = ifs.tellg(); - - // Q block -> [0 : q_local_rows) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (0 * attn_cols), - /*rows=*/attn_rows_all, /*cols=*/attn_cols, - /*row_start=*/tp_rank * q_local_rows, /*row_cnt=*/q_local_rows); - - // K block -> [q_local_rows : q_local_rows + kv_local_rows) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (q_local_rows * attn_cols), - /*rows=*/attn_rows_all, /*cols=*/attn_cols, - /*row_start=*/q_out_rows + tp_rank * kv_local_rows, /*row_cnt=*/kv_local_rows); - - // V block -> [q_local_rows + kv_local_rows : q_local_rows + 2*kv_local_rows) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + ((q_local_rows + kv_local_rows) * attn_cols), - /*rows=*/attn_rows_all, /*cols=*/attn_cols, - /*row_start=*/q_out_rows + kv_out_rows + tp_rank * kv_local_rows, - /*row_cnt=*/kv_local_rows); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + std::to_string(i - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCAttnLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + + float *dst = static_cast(tensor->DataPtr()); + const std::streampos base_pos = ifs.tellg(); + + // Q block -> [0 : q_local_rows) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (0 * attn_cols), + /*rows=*/attn_rows_all, /*cols=*/attn_cols, + 
/*row_start=*/tp_rank * q_local_rows, /*row_cnt=*/q_local_rows); + + // K block -> [q_local_rows : q_local_rows + kv_local_rows) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (q_local_rows * attn_cols), + /*rows=*/attn_rows_all, /*cols=*/attn_cols, + /*row_start=*/q_out_rows + tp_rank * kv_local_rows, /*row_cnt=*/kv_local_rows); + + // V block -> [q_local_rows + kv_local_rows : q_local_rows + 2*kv_local_rows) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + ((q_local_rows + kv_local_rows) * attn_cols), + /*rows=*/attn_rows_all, /*cols=*/attn_cols, + /*row_start=*/q_out_rows + kv_out_rows + tp_rank * kv_local_rows, + /*row_cnt=*/kv_local_rows); + } else { + size_t qkv_bytes = static_cast(attn_rows_all) * attn_cols * sizeof(float); + ifs.seekg(qkv_bytes, std::ios::cur); + } } // transformer.h.{i}.attn.c_proj.weight : RowParallelLinear, but actually applies on "columns" for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i), Block::kAttnLayerName, - CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/n_embd, /*cols=*/n_embd, - /*col_start=*/tp_rank * in_pp, /*col_cnt=*/in_pp); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + std::to_string(i - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/n_embd, /*cols=*/n_embd, + /*col_start=*/tp_rank * in_pp, /*col_cnt=*/in_pp); + } else { + size_t c_proj_bytes = static_cast(n_embd) * n_embd * sizeof(float); + ifs.seekg(c_proj_bytes, std::ios::cur); + } } // transformer.h.{i}.ln_2.weight : Full version RMSNorm for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i), - Block::kLn2LayerName, RMSNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + std::to_string(i - start_layer), Block::kLn2LayerName, + RMSNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_2_bytes = static_cast(n_embd) * sizeof(float); + ifs.seekg(ln_2_bytes, std::ios::cur); + } } // transformer.h.{i}.mlp.c_fc.weight : ColumnParallelLinear, but actually applies on "rows" for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i), Block::kMlpLayerName, MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/fc_out, /*cols=*/n_embd, - /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + 
ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/fc_out, /*cols=*/n_embd, + /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + } else { + size_t fc_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); + ifs.seekg(fc_bytes, std::ios::cur); + } } // transformer.h.{i}.mlp.c_fc2.weight : ColumnParallelLinear, but actually applies on "rows" for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i), Block::kMlpLayerName, MLP::kCFc2LayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/fc_out, /*cols=*/n_embd, - /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + Block::kMlpLayerName, MLP::kCFc2LayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/fc_out, /*cols=*/n_embd, + /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + } else { + size_t fc2_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); + ifs.seekg(fc2_bytes, std::ios::cur); + } } // transformer.h.{i}.mlp.c_proj.weight : RowParallelLinear, but actually applies on "columns" for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i), Block::kMlpLayerName, MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/n_embd, /*cols=*/fc_out, - /*col_start=*/tp_rank * in_fc_pp, /*col_cnt=*/in_fc_pp); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/n_embd, /*cols=*/fc_out, + /*col_start=*/tp_rank * in_fc_pp, /*col_cnt=*/in_fc_pp); + } else { + size_t c_proj_bytes = static_cast(n_embd) * ffn_hidden * sizeof(float); + ifs.seekg(c_proj_bytes, std::ios::cur); + } } // transformer.ln_f.weight : Full version RMSNorm - { - auto &ln_f - = state_dict[std::format("{}.{}.{}", kTransformerLayerName, kLnFLayerName, RMSNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(ln_f->DataPtr()), n_embd); - } - // lm_head.weight : (vocab_size, n_embd) -> ColumnParallelLinear, but actually applies on "rows" { - auto &lm_head - = state_dict[std::format("{}.{}", kLMHeadLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(lm_head->DataPtr()), - /*rows=*/vocab_size, /*cols=*/n_embd, - /*row_start=*/v_start, /*row_cnt=*/vpp); + if (is_last_stage) { + auto &ln_f + = state_dict[std::format("{}.{}.{}", kTransformerLayerName, kLnFLayerName, RMSNorm::kParamWeightName)]; + auto &lm_head = state_dict[std::format("{}.{}", kLMHeadLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(ln_f->DataPtr()), n_embd); + ReadMatrixRowShardFloat(ifs, static_cast(lm_head->DataPtr()), + /*rows=*/vocab_size, /*cols=*/n_embd, + /*row_start=*/v_start, 
/*row_cnt=*/vpp); + } else { + size_t ln_f_bytes = static_cast(n_embd) * sizeof(float); + size_t lm_head_bytes = static_cast(vocab_size) * n_embd * sizeof(float); + ifs.seekg(ln_f_bytes + lm_head_bytes, std::ios::cur); + } } return llama3; diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index ee3be58..c090dd3 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -37,7 +37,6 @@ class Module : public std::enable_shared_from_this { bool has_parameter(const std::string &name) const; std::shared_ptr *mutable_parameter(const std::string &name); const std::shared_ptr ¶meter(const std::string &name) const; - virtual std::vector> Buffers() const; std::vector> modules(); @@ -48,6 +47,12 @@ class Module : public std::enable_shared_from_this { virtual std::vector> Forward(const std::vector> &input_tensors); + virtual float TrainStep(const std::vector> &input_tensors, + const std::vector> &targets, + const std::shared_ptr &loss_fn) { + return 0.0f; + }; + virtual void To(const Device *device); virtual void To(DataType dtype); diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 0ae7599..1a6e22f 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -26,7 +26,8 @@ class GlobalEnv { public: static GlobalEnv &Instance(); - void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false); + void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, + int pipeline_parallel_size); int world_size() const; @@ -44,6 +45,8 @@ class GlobalEnv { int data_parallel_size() const; + int pipeline_parallel_size() const; + Layout layout() const; private: @@ -65,14 +68,18 @@ class GlobalEnv { int data_parallel_size_ = 1; + int pipeline_parallel_size_ = 1; + mutable std::mutex mutex_; bool initialized_ = false; Layout layout_; }; -inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false) { - GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled); +inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, + int pipeline_parallel_size) { + GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, + pipeline_parallel_size); } inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); } @@ -84,6 +91,7 @@ inline int GetLocalProcRank() { return GlobalEnv::Instance().local_proc_rank(); inline int GetTensorParallelSize() { return GlobalEnv::Instance().tensor_parallel_size(); } inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence_parallel_enabled(); } inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); } +inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); } // Layout Helper Functions inline int GetRankOf(int dp, int tp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, pp); } diff --git a/infini_train/include/nn/parallel/pp/pipeline_parallel.h b/infini_train/include/nn/parallel/pp/pipeline_parallel.h new file mode 100644 index 0000000..842bff4 --- /dev/null +++ b/infini_train/include/nn/parallel/pp/pipeline_parallel.h @@ -0,0 +1,44 @@ +// pipeline_parallel.h +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" 
+#include "infini_train/include/optimizer.h" + +namespace infini_train { +class Tensor; +class Device; +} // namespace infini_train + +namespace infini_train::nn::parallel { +class PipelineStage; +class PipelineSchedule; + +extern thread_local int pp_rank; + +class PipelineParallel : public Module { +public: + PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, + const std::vector> &recv_shape, int rank, + const std::shared_ptr &optimizer); + + float TrainStep(const std::vector> &input, + const std::vector> &target, const std::shared_ptr &loss_fn); + + static std::tuple GetStageInfo(int total_layers, int pp_size); + +private: + int num_stages_ = -1; + int rank_ = -1; + std::shared_ptr pipeline_stage_ = nullptr; + std::shared_ptr schedule_ = nullptr; + + void BuildPipelineStage(const std::shared_ptr &model, const std::shared_ptr &optimizer, + const std::vector> &recv_shape); + + void SetupSchedule(int num_micro_batches); +}; + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_schedule.h b/infini_train/include/nn/parallel/pp/pipeline_schedule.h new file mode 100644 index 0000000..e5f13f8 --- /dev/null +++ b/infini_train/include/nn/parallel/pp/pipeline_schedule.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" + +namespace infini_train { +class Tensor; +} // namespace infini_train + +namespace infini_train::nn::parallel { + +class PipelineStage; + +class PipelineSchedule { +public: + PipelineSchedule(std::shared_ptr stage, int num_stages, int num_micro_batches, int stage_index) + : stage_(std::move(stage)), num_micro_batches_(num_micro_batches), stage_index_(stage_index) {} + + virtual ~PipelineSchedule() = default; + + float Step(std::shared_ptr input, std::shared_ptr target, + const std::shared_ptr &loss_fn); + + virtual float StepMicroBatches(const std::vector> &arg_mbs, + const std::vector> &target_mbs, + const std::shared_ptr &loss_fn) + = 0; + + std::vector> ReceiveFromPrev(); + std::vector> SendToNext(const std::vector> &tensors); + +protected: + int num_micro_batches_ = -1; + int stage_index_ = -1; + std::shared_ptr stage_ = nullptr; +}; + +class ScheduleGPipe : public PipelineSchedule { +public: + ScheduleGPipe(std::shared_ptr stage, int num_stages, int num_micro_batches, int stage_index) + : PipelineSchedule(std::move(stage), num_stages, num_micro_batches, stage_index){}; + + float StepMicroBatches(const std::vector> &arg_mbs, + const std::vector> &target_mbs, + const std::shared_ptr &loss_fn) override; +}; + +class Schedule1F1B : public PipelineSchedule { +public: + Schedule1F1B(std::shared_ptr stage, int num_stages, int num_micro_batches, int stage_index) + : PipelineSchedule(std::move(stage), num_stages, num_micro_batches, stage_index){}; + + float StepMicroBatches(const std::vector> &arg_mbs, + const std::vector> &target_mbs, + const std::shared_ptr &loss_fn) override; +}; + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h new file mode 100644 index 0000000..7d6b408 --- /dev/null +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/optimizer.h" + +namespace infini_train { +class Tensor; +class Device; +} // namespace infini_train + +namespace infini_train::nn::parallel { + +class 
PipelineStage { +public: + PipelineStage(const std::shared_ptr &model, int stage_index, int num_stages, + const std::vector> &recv_shape, std::shared_ptr optimizer); + + std::vector> ForwardOneChunk(const std::vector> &inputs); + + bool IsFirstStage() const { return stage_index_ == 0; } + bool IsLastStage() const { return stage_index_ == num_stages_ - 1; } + int stage_index() const { return stage_index_; } + int prev_rank() const { return prev_rank_; } + int next_rank() const { return next_rank_; } + int num_stages() const { return num_stages_; } + const Device *device() const { return device_; } + const std::vector> &recv_shape() const { return recv_shape_; } + std::shared_ptr optimizer() { return optimizer_; } + +private: + int stage_index_ = -1; + int num_stages_ = -1; + int prev_rank_ = -1; + int next_rank_ = -1; + const Device *device_ = nullptr; + std::shared_ptr model_ = nullptr; + std::shared_ptr optimizer_ = nullptr; + std::vector> recv_shape_; +}; + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/send_recv.h b/infini_train/include/nn/parallel/pp/send_recv.h new file mode 100644 index 0000000..f76f4c7 --- /dev/null +++ b/infini_train/include/nn/parallel/pp/send_recv.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace infini_train { +class Tensor; +class Device; +} // namespace infini_train + +namespace infini_train::nn::parallel { + +std::vector> ISend(const std::vector> &input_tensors, + const Device *target_device, int cur_rank, int peer_rank, + const std::vector> &shape); + +std::vector> IRecv(const std::vector> &outputs, + const Device *src_device, int cur_rank, int peer_rank); +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/utils.h b/infini_train/include/nn/parallel/utils.h index 2610329..3eb3960 100644 --- a/infini_train/include/nn/parallel/utils.h +++ b/infini_train/include/nn/parallel/utils.h @@ -8,7 +8,11 @@ std::string GetDataParallelProcessGroupName(int thread_rank); std::string GetTensorParallelProcessGroupName(int thread_rank); +std::string GetPipelineParallelProcessGroupName(int thread_rank); + std::vector GetDataParallelGroupRanks(int rank); std::vector GetTensorParallelGroupRanks(int rank); + +std::vector GetPipelineParallelGroupRanks(int pp_world_size); } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 3feaab9..8c97b22 100644 --- a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -84,7 +84,8 @@ GlobalEnv &GlobalEnv::Instance() { return instance; } -void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled) { +void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, + int pipeline_parallel_size) { std::lock_guard lock(mutex_); CHECK(!initialized_) << "Repeated initialization of GlobalEnv!"; @@ -98,7 +99,8 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq CHECK_GE(tensor_parallel_size, 1) << "Tensor Parallel size must be >= 1"; tensor_parallel_size_ = tensor_parallel_size; sequence_parallel_enabled_ = sequence_parallel_enabled; - data_parallel_size_ = world_size_ / tensor_parallel_size_; + pipeline_parallel_size_ = pipeline_parallel_size; + data_parallel_size_ = world_size_ / tensor_parallel_size_ / pipeline_parallel_size_; layout_.sizes[DP] = data_parallel_size_; layout_.sizes[TP] = tensor_parallel_size_; @@ -149,6 
+151,11 @@ int GlobalEnv::data_parallel_size() const { return data_parallel_size_; } +int GlobalEnv::pipeline_parallel_size() const { + CHECK(initialized_) << "GlobalEnv is not initialized!"; + return pipeline_parallel_size_; +} + Layout GlobalEnv::layout() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; return layout_; diff --git a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc new file mode 100644 index 0000000..6fffcfb --- /dev/null +++ b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc @@ -0,0 +1,67 @@ +// pipeline_parallel.cc +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" + +#include +#include + +#include "infini_train/include/nn/modules/container.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/pp/pipeline_schedule.h" +#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" +#include "infini_train/include/optimizer.h" + +namespace infini_train::nn::parallel { + +thread_local int pp_rank = 0; + +void PipelineParallel::BuildPipelineStage(const std::shared_ptr &module, + const std::shared_ptr &optimizer, + const std::vector> &recv_shape) { + pipeline_stage_ = std::make_shared(module, rank_, num_stages_, recv_shape, optimizer); +} + +void PipelineParallel::SetupSchedule(int num_micro_batches) { + schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_micro_batches, rank_); +} + +float PipelineParallel::TrainStep(const std::vector> &input, + const std::vector> &target, + const std::shared_ptr &loss_fn) { + std::shared_ptr stage_input; + std::shared_ptr stage_target = target[0]; + if (rank_ == 0) { + stage_input = input[0]; + } + + return schedule_->Step(stage_input, stage_target, loss_fn); +} + +std::tuple PipelineParallel::GetStageInfo(int total_layers, int pp_size) { + int rank = pp_rank; + bool is_first_stage = (pp_rank == 0); + bool is_last_stage = (pp_rank == pp_size - 1); + + int layers_per_stage = total_layers / pp_size; + int remainder = total_layers % pp_size; + int start_layer, end_layer; + if (pp_rank < remainder) { + start_layer = pp_rank * (layers_per_stage + 1); + end_layer = start_layer + layers_per_stage + 1; + } else { + start_layer = pp_rank * layers_per_stage + remainder; + end_layer = start_layer + layers_per_stage; + } + + return {is_first_stage, is_last_stage, start_layer, end_layer}; +} + +PipelineParallel::PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, + const std::vector> &recv_shape, int rank, + const std::shared_ptr &optimizer) + : num_stages_(num_stages), rank_(rank) { + BuildPipelineStage(module, optimizer, recv_shape); + + SetupSchedule(num_micro_batches); +} + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc new file mode 100644 index 0000000..884e24a --- /dev/null +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -0,0 +1,129 @@ +// pipeline_schedule.cc +#include "infini_train/include/nn/parallel/pp/pipeline_schedule.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/grad_mode.h" +#include "infini_train/include/device.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" +#include "infini_train/include/nn/parallel/pp/send_recv.h" +#include 
"infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +float PipelineSchedule::Step(std::shared_ptr input, std::shared_ptr target, + const std::shared_ptr &loss_fn) { + std::vector> micro_batches(num_micro_batches_); + std::vector> target_mbs(num_micro_batches_); + if (stage_->IsFirstStage()) { + { + autograd::NoGradGuard no_grad; + micro_batches = input->Split(input->Dims()[0] / num_micro_batches_); + } + } + + if (stage_->IsLastStage()) { + { + autograd::NoGradGuard no_grad; + target_mbs = target->Split(target->Dims()[0] / num_micro_batches_); + } + } + + const auto &optimizer = stage_->optimizer(); + + optimizer->ZeroGrad(); + + float lossf = StepMicroBatches(micro_batches, target_mbs, loss_fn); + + optimizer->Step(); + + return lossf; +} + +std::vector> PipelineSchedule::ReceiveFromPrev() { + std::vector> recv_tensors; + auto &shapes = stage_->recv_shape(); + for (size_t i = 0; i < shapes.size(); ++i) { + // FIXME(jym): The data type between stages is not float32, which will cause a crash + auto tensor = std::make_shared(shapes[i], DataType::kFLOAT32, stage_->device()); + tensor->set_requires_grad(true); + tensor->set_is_leaf(false); + recv_tensors.push_back(tensor); + } + + return IRecv(recv_tensors, stage_->device(), stage_->stage_index(), stage_->prev_rank()); +} + +std::vector> PipelineSchedule::SendToNext(const std::vector> &tensors) { + return ISend(tensors, stage_->device(), stage_->stage_index(), stage_->next_rank(), stage_->recv_shape()); +} + +float ScheduleGPipe::StepMicroBatches(const std::vector> µbatch_inputs, + const std::vector> µbatch_targets, + const std::shared_ptr &loss_fn) { + const auto n = num_micro_batches_; + if (n == 0) { + return 0.0f; + } + + std::vector>> outputs(n); + + // ======== Forward Pass ======== + for (int mb = 0; mb < n; ++mb) { + std::vector> inputs; + if (stage_->IsFirstStage()) { + inputs = {microbatch_inputs[mb]}; + } else { + inputs = ReceiveFromPrev(); + } + + outputs[mb] = stage_->ForwardOneChunk(inputs); + + if (!stage_->IsLastStage()) { + outputs[mb] = SendToNext(outputs[mb]); + } + } + + // ======== Backward Pass ======== + float total_loss = 0.0f; + if (!stage_->IsLastStage()) { + for (int mb = 0; mb < n; ++mb) { + auto out_tensor = outputs[mb][0]; + + auto gradient = std::make_shared(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice()); + + out_tensor->Backward(gradient); + } + } else { + for (int mb = 0; mb < n; ++mb) { + auto target = microbatch_targets[mb]; + auto output = outputs[mb][0]; + + if (!target || !output) { + LOG(FATAL) << "Output or target is null at mb=" << mb; + } + + auto target_on_device = target->To(output->GetDevice()); + auto loss = loss_fn->Forward({output, std::make_shared(target_on_device)})[0]; + if (!loss) { + LOG(FATAL) << "[ERROR] loss is nullptr at mb = " << mb; + } + + loss = loss / n; + auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + total_loss += static_cast(loss_cpu.DataPtr())[0]; + + loss->Backward(); + } + } + + return total_loss; +} + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc new file mode 100644 index 0000000..f09ae13 --- /dev/null +++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc @@ -0,0 +1,24 @@ +#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" + +#include + +#include "glog/logging.h" + +#include "infini_train/include/device.h" + +namespace infini_train::nn::parallel { + 
+PipelineStage::PipelineStage(const std::shared_ptr &model, int stage_index, int num_stages, + const std::vector> &recv_shape, std::shared_ptr optimizer) + : model_(model), stage_index_(stage_index), num_stages_(num_stages), + prev_rank_(stage_index > 0 ? stage_index - 1 : -1), + next_rank_(stage_index < num_stages - 1 ? stage_index + 1 : -1), recv_shape_(recv_shape), + optimizer_(std::move(optimizer)), + device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(stage_index)) {} + +std::vector> +PipelineStage::ForwardOneChunk(const std::vector> &inputs) { + return model_->Forward(inputs); +} + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc new file mode 100644 index 0000000..31cd065 --- /dev/null +++ b/infini_train/src/nn/parallel/pp/send_recv.cc @@ -0,0 +1,129 @@ +#include "infini_train/include/nn/parallel/pp/send_recv.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/function.h" +#include "infini_train/include/device.h" +#include "infini_train/include/nn/parallel/process_group.h" +#include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +namespace functions { +class ISend : public autograd::Function { +public: + static constexpr char kType[] = "ISendFunction"; + + explicit ISend(const Device *target_device, int cur_rank, int peer_rank, + const std::vector> &shape) + : autograd::Function(kType), target_device_(target_device), cur_rank_(cur_rank), peer_rank_(peer_rank), + shapes_(shape) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const Device *target_device_ = nullptr; + const Device *input_device_ = nullptr; + int cur_rank_ = -1; + int peer_rank_ = -1; + const std::vector> &shapes_; +}; + +class IRecv : public autograd::Function { +public: + static constexpr char kType[] = "IRecvFunction"; + + explicit IRecv(const Device *src_device, int cur_rank, int peer_rank) + : autograd::Function(kType), src_device_(src_device), cur_rank_(cur_rank), peer_rank_(peer_rank) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const Device *src_device_ = nullptr; + const Device *cur_device_ = nullptr; + int cur_rank_ = -1; + int peer_rank_ = -1; +}; + +std::vector> ISend::Forward(const std::vector> &input_tensors) { + const auto &input = input_tensors[0]; + input_device_ = input->GetDevice(); + + auto pp_group = ProcessGroupFactory::Instance()->Get( + GetPipelineParallelProcessGroupName(input_device_->rank().thread_rank())); + + pp_group->NcclSend(input_tensors, peer_rank_); + + std::vector> outputs; + for (auto t : input_tensors) { outputs.push_back(t); } + return outputs; +} + +std::vector> ISend::Backward(const std::vector> &grad_outputs) { + std::vector> recv_tensors; + for (int shape_i = 0; shape_i < shapes_.size(); ++shape_i) { + auto r_tensor = std::make_shared(shapes_[shape_i], DataType::kFLOAT32, input_device_); + recv_tensors.push_back(r_tensor); + } + + auto pp_group = ProcessGroupFactory::Instance()->Get( + GetPipelineParallelProcessGroupName(input_device_->rank().thread_rank())); + + return 
pp_group->NcclRecv(recv_tensors, peer_rank_); +} + +std::vector> IRecv::Forward(const std::vector> &recv_tensors) { + CHECK_NOTNULL(src_device_); + auto pp_group + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(src_device_->rank().thread_rank())); + pp_group->NcclRecv(recv_tensors, peer_rank_); + + std::vector> outputs; + for (auto t : recv_tensors) { + auto t_item = std::make_shared(*t); + t_item->set_requires_grad(true); + outputs.push_back(t); + } + return outputs; +} + +void IRecv::SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) { + if (output_tensors.empty()) { + return; + } + cur_device_ = output_tensors[0]->GetDevice(); +} + +std::vector> IRecv::Backward(const std::vector> &grad_outputs) { + auto pp_group + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(cur_device_->rank().thread_rank())); + return pp_group->NcclSend(grad_outputs, peer_rank_); +} +} // namespace functions + +std::vector> ISend(const std::vector> &input_tensors, + const Device *target_device, int cur_rank, int peer_rank, + const std::vector> &shape) { + auto func = std::make_shared(target_device, cur_rank, peer_rank, shape); + return func->Apply(input_tensors); +} + +std::vector> IRecv(const std::vector> &outputs, + const Device *src_device, int cur_rank, int peer_rank) { + auto func = std::make_shared(src_device, cur_rank, peer_rank); + return func->Apply(outputs); +} +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/utils.cc b/infini_train/src/nn/parallel/utils.cc index 8c4ad8c..4f66188 100644 --- a/infini_train/src/nn/parallel/utils.cc +++ b/infini_train/src/nn/parallel/utils.cc @@ -12,8 +12,18 @@ std::string GetTensorParallelProcessGroupName(int thread_rank) { return "TP" + std::to_string(global::GetGroupId(global::TP, thread_rank)); } +std::string GetPipelineParallelProcessGroupName(int thread_rank) { + return "PP" + std::to_string(global::GetGroupId(global::PP, thread_rank)); +} + std::vector GetDataParallelGroupRanks(int thread_rank) { return global::GetGroupRanks(global::DP, thread_rank); } std::vector GetTensorParallelGroupRanks(int thread_rank) { return global::GetGroupRanks(global::TP, thread_rank); } +std::vector GetPipelineParallelGroupRanks(int pp_world_size) { + std::vector ranks; + ranks.reserve(pp_world_size); + for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); } + return ranks; +} } // namespace infini_train::nn::parallel diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 988cfab..c163ea6 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -20,6 +20,10 @@ SGD::SGD(const std::vector> &params, float learning_rate void SGD::Step() { for (auto param : params_) { + if (!param->grad()) { + LOG(INFO) << "Skipping param with null grad."; + continue; + } auto device = param->GetDevice(); device->SetDevice(); auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"}); diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash old mode 100644 new mode 100755
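
For reviewers who want to sanity-check the stage partitioning without building the project, the sketch below reproduces the contiguous layer split implemented by PipelineParallel::GetStageInfo (and relied on by the FromLLMC seek/skip logic). It is a standalone illustration, not part of this patch: it takes pp_rank as an explicit argument instead of reading the thread-local variable, and the 30-layer / 4-stage numbers are made-up examples.

#include <iostream>
#include <tuple>

// Same arithmetic as PipelineParallel::GetStageInfo: contiguous [start, end) layer
// ranges, with the first (total_layers % pp_size) stages taking one extra layer.
std::tuple<bool, bool, int, int> StageInfo(int total_layers, int pp_size, int pp_rank) {
    const bool is_first_stage = (pp_rank == 0);
    const bool is_last_stage = (pp_rank == pp_size - 1);
    const int layers_per_stage = total_layers / pp_size;
    const int remainder = total_layers % pp_size;
    int start_layer = 0;
    int end_layer = 0;
    if (pp_rank < remainder) {
        start_layer = pp_rank * (layers_per_stage + 1);
        end_layer = start_layer + layers_per_stage + 1;
    } else {
        start_layer = pp_rank * layers_per_stage + remainder;
        end_layer = start_layer + layers_per_stage;
    }
    return {is_first_stage, is_last_stage, start_layer, end_layer};
}

int main() {
    const int total_layers = 30; // hypothetical model depth
    const int pp_size = 4;       // hypothetical number of pipeline stages
    for (int rank = 0; rank < pp_size; ++rank) {
        auto [first, last, start, end] = StageInfo(total_layers, pp_size, rank);
        std::cout << "stage " << rank << ": layers [" << start << ", " << end << ")"
                  << (first ? ", owns wte" : "") << (last ? ", owns ln_f + lm_head" : "") << "\n";
    }
    // Expected split: 8 / 8 / 7 / 7 layers; the first stage additionally loads the embedding,
    // the last stage loads ln_f and lm_head, and every other rank seeks past those bytes.
    return 0;
}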