From 867d47443e34ce4d56b2eff0b88b4563f48cb62f Mon Sep 17 00:00:00 2001 From: JYMiracle305 <604951424@qq.com> Date: Wed, 5 Nov 2025 17:20:03 +0800 Subject: [PATCH 1/3] feat: add pipeline parallel --- example/common/utils.cc | 7 + example/common/utils.h | 1 + example/gpt2/main.cc | 85 +++++++-- example/gpt2/net.cc | 54 ++++++ example/gpt2/net.h | 3 + example/llama3/main.cc | 82 +++++++-- example/llama3/net.cc | 79 ++++++++ example/llama3/net.h | 4 + infini_train/include/nn/modules/container.h | 10 + infini_train/include/nn/modules/module.h | 11 +- infini_train/include/nn/parallel/global.h | 13 +- .../nn/parallel/pp/pipeline_parallel.h | 50 +++++ .../nn/parallel/pp/pipeline_schedule.h | 55 ++++++ .../include/nn/parallel/pp/pipeline_stage.h | 41 +++++ .../include/nn/parallel/pp/send_recv.h | 16 ++ infini_train/include/nn/parallel/utils.h | 2 + infini_train/src/nn/parallel/global.cc | 9 +- .../src/nn/parallel/pp/pipeline_parallel.cc | 108 +++++++++++ .../src/nn/parallel/pp/pipeline_schedule.cc | 171 ++++++++++++++++++ .../src/nn/parallel/pp/pipeline_stage.cc | 32 ++++ infini_train/src/nn/parallel/pp/send_recv.cc | 135 ++++++++++++++ infini_train/src/nn/parallel/utils.cc | 4 + scripts/run_models_and_profile.bash | 0 23 files changed, 945 insertions(+), 27 deletions(-) create mode 100644 infini_train/include/nn/parallel/pp/pipeline_parallel.h create mode 100644 infini_train/include/nn/parallel/pp/pipeline_schedule.h create mode 100644 infini_train/include/nn/parallel/pp/pipeline_stage.h create mode 100644 infini_train/include/nn/parallel/pp/send_recv.h create mode 100644 infini_train/src/nn/parallel/pp/pipeline_parallel.cc create mode 100644 infini_train/src/nn/parallel/pp/pipeline_schedule.cc create mode 100644 infini_train/src/nn/parallel/pp/pipeline_stage.cc create mode 100644 infini_train/src/nn/parallel/pp/send_recv.cc mode change 100644 => 100755 scripts/run_models_and_profile.bash diff --git a/example/common/utils.cc b/example/common/utils.cc index 03cc7aa0..347aa195 100644 --- a/example/common/utils.cc +++ b/example/common/utils.cc @@ -61,4 +61,11 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s ifs.seekg(base + std::streamoff(len * sizeof(float))); } +std::vector GetPipelineParallelGroupRanks(int pp_world_size) { + std::vector ranks; + ranks.reserve(pp_world_size); + for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); } + return ranks; +} + } // namespace infini_train diff --git a/example/common/utils.h b/example/common/utils.h index 5bab3e97..05ff7b36 100644 --- a/example/common/utils.h +++ b/example/common/utils.h @@ -30,4 +30,5 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len); void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt); +std::vector GetPipelineParallelGroupRanks(int rank); } // namespace infini_train diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index e3322d3d..72a2b45f 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -17,6 +17,7 @@ #include "infini_train/include/nn/parallel/distributed_data_parallel.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/parallel_functional.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/rank.h" #include "infini_train/include/nn/parallel/reduce_op_type.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" @@ -63,6 +64,11 @@ DEFINE_int32( "When set > 1, enables data parallelism with 
device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); +DEFINE_bool( + pipeline_parallel, false, + "use pipeline parallelism or not, will always use device=cuda and use all cuda visible devices when set to true"); +DEFINE_uint32(num_microbatches, 4, "the num of microbatches in pipeline parallelism"); + // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -106,6 +112,7 @@ void Train(const nn::parallel::Rank &rank) { int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0; + int pp_world_size = global::GetPipelineParallelSize(); if (FLAGS_sequence_parallel) { CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0) @@ -114,9 +121,11 @@ void Train(const nn::parallel::Rank &rank) { int ddp_rank = 0; int tp_rank = 0; + int pp_rank = 0; const ProcessGroup *ddp_pg = nullptr; const ProcessGroup *tp_pg = nullptr; + const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); @@ -134,6 +143,12 @@ void Train(const nn::parallel::Rank &rank) { // NOTE(zbl): Reserved for VocabParallelEmbedding nn::parallel::tp_rank = tp_rank; } + + if (pp_world_size > 1) { + pp_pg = ProcessGroupFactory::Instance()->GetOrCreate( + GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size)); + pp_rank = pp_pg->GetGroupRank(rank.thread_rank()); + } } else { device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice() : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0); @@ -174,6 +189,10 @@ void Train(const nn::parallel::Rank &rank) { LOG(FATAL) << "Rank " << rank.thread_rank() << ": Datatype " << FLAGS_dtype << " not supported."; } + // TODO(jym): Temporary implementation before 3D parallelism + if (FLAGS_pipeline_parallel) { + ddp_world_size = 1; + } // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions // before wrapping the model with DistributedDataParallel (DDP). 
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors @@ -182,8 +201,17 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, rank.thread_rank()); } - DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), - FLAGS_batch_size, ddp_rank, ddp_world_size); + std::unique_ptr train_loader; + if (FLAGS_pipeline_parallel) { + train_loader = std::make_unique( + std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), + FLAGS_batch_size * pp_world_size); + } else { + train_loader = std::make_unique( + std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size, + ddp_rank, ddp_world_size); + } + std::optional val_loader = std::nullopt; if (!FLAGS_input_val_bin.empty()) { val_loader = DistributedDataLoader( @@ -201,9 +229,13 @@ void Train(const nn::parallel::Rank &rank) { } // TODO(dcj): support more complex optimizer later - auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate); + auto lr = FLAGS_learning_rate; + auto optimizer_factory = [lr](const std::vector> ¶ms) { + return std::make_shared(params, lr); + }; + auto optimizer = optimizer_factory(model->Parameters()); - auto train_iter = train_loader.begin(); + auto train_iter = train_loader->begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast( std::make_shared(model_config.original_vocab_size)) @@ -211,6 +243,19 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.thread_rank() << ": start training"; + if (FLAGS_pipeline_parallel) { + CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0) + << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size) + << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")"; + auto shapes = std::vector>{{(FLAGS_batch_size * pp_world_size) / FLAGS_num_microbatches, + FLAGS_sequence_length, model->GetConfig()["n_embd"]}}; + + model = std::make_shared(model, pp_world_size, FLAGS_num_microbatches, shapes, + pp_rank, optimizer_factory); + } + + LOG(INFO) << "start training"; + for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { const bool last_step = step == FLAGS_num_iteration; @@ -234,7 +279,9 @@ void Train(const nn::parallel::Rank &rank) { } // model->Train(); - optimizer.ZeroGrad(); + if (!FLAGS_pipeline_parallel) { + optimizer->ZeroGrad(); + } // if we are trying to overfit a single batch, we reset the loader here if (FLAGS_overfit_single_batch) { // train_loader.Reset(); @@ -254,6 +301,19 @@ void Train(const nn::parallel::Rank &rank) { ++train_iter; x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); + + if (FLAGS_pipeline_parallel) { + lossf = model->TrainStep({x}, {y}, loss_fn); + + auto loss_tensor = std::make_shared(std::vector{}, DataType::kFLOAT32); + static_cast(loss_tensor->DataPtr())[0] = lossf; + auto loss_device_ptr = std::make_shared(loss_tensor->To(device)); + function::AllReduce(loss_device_ptr, function::ReduceOpType::kMax); + auto loss_copy = loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice()); + lossf = static_cast(loss_copy.DataPtr())[0]; + continue; + } + LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward"; // (bs, seq_len, vocab_size) auto logits = model->Forward({x, y})[0]; @@ -274,17 +334,19 @@ void Train(const nn::parallel::Rank &rank) { loss->Backward(); LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; } - optimizer.Step(); + if (!FLAGS_pipeline_parallel) { + 
optimizer->Step(); + } const auto iter_end = std::chrono::high_resolution_clock::now(); const double duration_us = std::chrono::duration(iter_end - iter_start).count(); const double tps = FLAGS_total_batch_size / (duration_us / 1e6); if (rank.IsMainRank()) { - LOG(ERROR) << std::format( - "step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, DP={}, TP={}, SP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size, - tp_world_size, sp_world_size); + LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, " + "DP={}, TP={}, SP={}, PP={})", + step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, + tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { if (!tokenizer) { @@ -304,7 +366,8 @@ int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); - nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel); + nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, + FLAGS_pipeline_parallel); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index a023316e..6640c026 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -212,6 +212,60 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) { = module(kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); } +class EmbeddingLayer : public nn::Module { + +public: + EmbeddingLayer(std::shared_ptr wte, std::shared_ptr wpe) { + modules_["wte"] = wte; + modules_["wpe"] = wpe; + } + + std::vector> + Forward(const std::vector> &inputs) override { + auto &input_ids = inputs[0]; // (bs, seq_len) + const int seq_len = input_ids->Dims()[1]; + const auto device = input_ids->GetDevice(); + + // position ids: [0, 1, ..., seq_len-1] + auto pos_ids = nn::init::Arange(0, seq_len, infini_train::DataType::kINT64, device); + // (bs, seq_len) -> wte -> (bs, seq_len, n_embd) + auto tok_emb = modules_["wte"]->Forward({input_ids})[0]; + // (seq_len,) -> wpe -> (seq_len, n_embd) + auto pos_emb = modules_["wpe"]->Forward({pos_ids})[0]; + + auto output = tok_emb + pos_emb; // (bs, seq_len, n_embd) + + return {output}; + } +}; + +std::vector> GPT2::GetPipelineLayers() { + auto &transformer = modules_[kTransformerLayerName]; + + std::vector> layers; + + auto embedding_layer = std::make_shared(transformer->mutable_module(kWTELayerName), + transformer->mutable_module(kWPELayerName)); + layers.push_back(embedding_layer); + + auto seq = std::dynamic_pointer_cast(transformer->mutable_module(kHLayerName)); + if (seq) { + for (int idx = 0; idx < seq->size(); ++idx) { + auto sub_module = (*seq)[idx]; + layers.push_back(sub_module); + } + } + + layers.push_back(transformer->mutable_module(kLnFLayerName)); + layers.push_back(modules_[kLMHeadLayerName]); + + return layers; +} + +std::unordered_map GPT2::GetConfig() const { + return {{"n_embd", config_.n_embd}, {"n_head", config_.n_head}}; +} + std::vector> GPT2::Forward(const std::vector> &x) { // (B, T) diff --git a/example/gpt2/net.h b/example/gpt2/net.h index dba5cdfe..548e43a2 100644 --- a/example/gpt2/net.h +++ b/example/gpt2/net.h @@ -89,6 +89,9 @@ class GPT2 : public infini_train::nn::CloneableModule { explicit GPT2(const GPT2Config &config); + std::unordered_map 
GetConfig() const override; + std::vector> GetPipelineLayers() override; + std::vector> Forward(const std::vector> &x) override; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 9dc4cf07..f3795267 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -14,6 +14,7 @@ #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/distributed_data_parallel.h" #include "infini_train/include/nn/parallel/parallel_functional.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/rank.h" #include "infini_train/include/nn/parallel/reduce_op_type.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" @@ -62,6 +63,10 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); +DEFINE_bool( + pipeline_parallel, false, + "use pipeline parallelism or not, will always use device=cuda and use all cuda visible devices when set to true"); +DEFINE_uint32(num_microbatches, 4, "the num of microbatches in pipeline parallelism"); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -89,6 +94,7 @@ void Train(const nn::parallel::Rank &rank) { int ddp_world_size = global::GetDataParallelSize(); int tp_world_size = global::GetTensorParallelSize(); int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0; + int pp_world_size = global::GetPipelineParallelSize(); if (FLAGS_sequence_parallel) { CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0) @@ -97,9 +103,11 @@ void Train(const nn::parallel::Rank &rank) { int ddp_rank = 0; int tp_rank = 0; + int pp_rank = 0; const ProcessGroup *ddp_pg = nullptr; const ProcessGroup *tp_pg = nullptr; + const ProcessGroup *pp_pg = nullptr; if (rank.IsParallel()) { device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); @@ -117,6 +125,12 @@ void Train(const nn::parallel::Rank &rank) { // NOTE(zbl): Reserved for VocabParallelEmbedding nn::parallel::tp_rank = tp_rank; } + + if (pp_world_size > 1) { + pp_pg = ProcessGroupFactory::Instance()->GetOrCreate( + GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size)); + pp_rank = pp_pg->GetGroupRank(rank.thread_rank()); + } } else { device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice() : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0); @@ -155,6 +169,10 @@ void Train(const nn::parallel::Rank &rank) { LOG(FATAL) << "Rank " << rank.thread_rank() << ": Datatype " << FLAGS_dtype << " not supported."; } + // TODO(jym): Temporary implementation before 3D parallelism + if (FLAGS_pipeline_parallel) { + ddp_world_size = 1; + } // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions // before wrapping the model with DistributedDataParallel (DDP). 
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors @@ -163,8 +181,17 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, rank.thread_rank()); } - DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), - FLAGS_batch_size, ddp_rank, ddp_world_size); + std::unique_ptr train_loader; + if (FLAGS_pipeline_parallel) { + train_loader = std::make_unique( + std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), + FLAGS_batch_size * pp_world_size); + } else { + train_loader = std::make_unique( + std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size, + ddp_rank, ddp_world_size); + } + std::optional val_loader = std::nullopt; if (!FLAGS_input_val_bin.empty()) { val_loader = DistributedDataLoader( @@ -181,15 +208,31 @@ void Train(const nn::parallel::Rank &rank) { } // TODO(dcj): support more complex optimizer later - auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate); + auto lr = FLAGS_learning_rate; + auto optimizer_factory = [lr](const std::vector> ¶ms) { + return std::make_shared(params, lr); + }; + auto optimizer = optimizer_factory(model->Parameters()); - auto train_iter = train_loader.begin(); + auto train_iter = train_loader->begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast(std::make_shared()) : std::static_pointer_cast(std::make_shared()); loss_fn->To(device); LOG(INFO) << "Rank " << rank.thread_rank() << ": start training"; + if (FLAGS_pipeline_parallel) { + CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0) + << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size) + << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")"; + + auto shapes = std::vector>{{FLAGS_batch_size * pp_world_size / FLAGS_num_microbatches, + FLAGS_sequence_length, model->GetConfig()["n_embd"]}}; + + model = std::make_shared(model, pp_world_size, FLAGS_num_microbatches, shapes, + pp_rank, optimizer_factory); + } + for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { const bool last_step = step == FLAGS_num_iteration; @@ -213,7 +256,9 @@ void Train(const nn::parallel::Rank &rank) { } // model->Train(); - optimizer.ZeroGrad(); + if (!FLAGS_pipeline_parallel) { + optimizer->ZeroGrad(); + } // if we are trying to overfit a single batch, we reset the loader here if (FLAGS_overfit_single_batch) { // train_loader.Reset(); @@ -233,6 +278,18 @@ void Train(const nn::parallel::Rank &rank) { ++train_iter; x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); + if (FLAGS_pipeline_parallel) { + lossf = model->TrainStep({x}, {y}, loss_fn); + + auto loss_tensor = std::make_shared(std::vector{}, DataType::kFLOAT32); + static_cast(loss_tensor->DataPtr())[0] = lossf; + auto loss_device_ptr = std::make_shared(loss_tensor->To(device)); + function::AllReduce(loss_device_ptr, function::ReduceOpType::kMax); + auto loss_copy = loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice()); + lossf = static_cast(loss_copy.DataPtr())[0]; + continue; + } + LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward"; // (bs, seq_len, vocab_size) auto logits = model->Forward({x, y})[0]; @@ -254,17 +311,19 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; } - optimizer.Step(); + if (!FLAGS_pipeline_parallel) { + optimizer->Step(); + } const auto iter_end = std::chrono::high_resolution_clock::now(); const double duration_us = 
std::chrono::duration(iter_end - iter_start).count(); const double tps = FLAGS_total_batch_size / (duration_us / 1e6); if (rank.IsMainRank()) { - LOG(ERROR) << std::format( - "step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, DP={}, TP={}, SP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size, - tp_world_size, sp_world_size); + LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, " + "DP={}, TP={}, SP={}, PP={})", + step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, + tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { if (!tokenizer) { @@ -284,7 +343,8 @@ int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); - nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel); + nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, + FLAGS_pipeline_parallel); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/llama3/net.cc b/example/llama3/net.cc index 400d8b35..44cabe36 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -308,6 +308,7 @@ std::vector> Block::Forward(const std::vector 1 ? x[1] : nullptr; const auto start_pos = x.size() > 2 ? x[2] : nullptr; const auto mask = x.size() > 3 ? x[3] : nullptr; + // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) -> attention -> (bs, seq_len, n_embd) // -> Add -> (bs, seq_len, n_embd) auto x1 = x[0] @@ -346,6 +347,84 @@ LLaMA3::LLaMA3(const LLaMA3Config &config) : config_(config) { /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); } +class LLaMALayer : public nn::Module { + std::shared_ptr h_layer_; + std::shared_ptr freqs_cis_ = nullptr; // freqs_cis (shape: [max_seq_len, head_dim]) + int64_t start_pos_ = 0; + DataType dtype_; + LLaMA3Config config_; + +public: + LLaMALayer(std::shared_ptr h_layer, DataType dtype, LLaMA3Config config) + : h_layer_(h_layer), dtype_(dtype), config_(config) {} + + void SetFreqsCis(std::shared_ptr freqs) { freqs_cis_ = freqs; } + void SetStartPos(int64_t pos) { start_pos_ = pos; } + + std::vector> Parameters() const override { return h_layer_->Parameters(); } + + std::vector> Forward(const std::vector> &inputs) override { + auto &x = inputs[0]; // (bs, seq_len, n_embd) + const int seq_len = x->Dims()[1]; + const auto device = x->GetDevice(); + + if (freqs_cis_ == nullptr) { + freqs_cis_ = PrecomputeFreqsCis(config_.n_embd / config_.n_head, config_.block_size * 2, config_.rope_theta, + config_.use_scaled_rope, device, dtype_); + } + + // slice freqs_cis: [start_pos:start_pos+seq_len] + auto freqs_view = freqs_cis_->Slice(0, start_pos_, start_pos_ + seq_len, 1); + + std::shared_ptr start_pos_ptr = nullptr; + + // causal mask: (1, 1, seq_len, seq_len) + std::shared_ptr ones = std::make_shared(nn::function::Ones({seq_len, seq_len})->To(device)); + auto mask = nn::function::Triu(ones, 1)->View({1, 1, seq_len, seq_len}); + if (dtype_ == DataType::kBFLOAT16) { + mask = std::make_shared(mask->To(DataType::kBFLOAT16)); + } + + // DecoderLayer: {x, freqs, start_pos, mask} + // std::vector> args = {x, freqs_view, nullptr, mask}; + auto output = h_layer_->Forward({x, freqs_view, start_pos_ptr, mask}); + return output; + } +}; + +std::vector> LLaMA3::GetPipelineLayers() { + 
std::vector> layers; + + auto transformer = modules_[kTransformerLayerName]; + layers.push_back(transformer->mutable_module(kWTELayerName)); + + auto seq = std::dynamic_pointer_cast(transformer->mutable_module(kHLayerName)); + auto dtype = modules_[kLMHeadLayerName]->parameter(nn::Linear::kParamWeightName)->Dtype(); + int idx = 0; + for (auto h : *seq) { + layers.push_back(std::make_shared(h, dtype, config_)); + ++idx; + } + // for (int idx = 0; idx < seq->size(); ++idx) { + // auto wrapped_layer = (*seq)[idx]; + // layers.push_back(std::make_shared(wrapped_layer, dtype, config_)); + // for (int idx = 1; idx < seq->size(); ++idx) { + // auto wrapped_layer = (*seq)[idx]; + // layers.push_back(wrapped_layer); + // } + // } + + layers.push_back(transformer->mutable_module(kLnFLayerName)); + + layers.push_back(modules_[kLMHeadLayerName]); + + return layers; +} + +std::unordered_map LLaMA3::GetConfig() const { + return {{"n_embd", config_.n_embd}, {"n_head", config_.n_head}}; +} + std::vector> LLaMA3::Forward(const std::vector> &x) { // (bs, seq_len) auto &idx = x[0]; diff --git a/example/llama3/net.h b/example/llama3/net.h index ec0199aa..d89907c4 100644 --- a/example/llama3/net.h +++ b/example/llama3/net.h @@ -129,6 +129,10 @@ class LLaMA3 : public infini_train::nn::CloneableModule { explicit LLaMA3(const LLaMA3Config &config); + std::vector> GetPipelineLayers() override; + + std::unordered_map GetConfig() const; + std::vector> Forward(const std::vector> &x) override; diff --git a/infini_train/include/nn/modules/container.h b/infini_train/include/nn/modules/container.h index d1356712..493c1ba8 100644 --- a/infini_train/include/nn/modules/container.h +++ b/infini_train/include/nn/modules/container.h @@ -17,6 +17,16 @@ class Sequential : public CloneableModule { explicit Sequential(std::vector> &&layers); std::vector> Forward(const std::vector> &input_tensors) override; + + size_t size() const { return modules_.size(); } + + std::shared_ptr operator[](size_t idx) const { + auto it = modules_.find(std::to_string(idx)); + if (it != modules_.end()) { + return it->second; + } + return nullptr; + } }; class ModuleDict : public CloneableModule { diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index ee3be588..85fc35f8 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -37,7 +37,6 @@ class Module : public std::enable_shared_from_this { bool has_parameter(const std::string &name) const; std::shared_ptr *mutable_parameter(const std::string &name); const std::shared_ptr ¶meter(const std::string &name) const; - virtual std::vector> Buffers() const; std::vector> modules(); @@ -48,6 +47,16 @@ class Module : public std::enable_shared_from_this { virtual std::vector> Forward(const std::vector> &input_tensors); + virtual std::vector> GetPipelineLayers() { return {}; } + + virtual std::unordered_map GetConfig() const { return {}; } + + virtual float TrainStep(const std::vector> &input_tensors, + const std::vector> &targets, + const std::shared_ptr &loss_fn) { + return 0.0f; + }; + virtual void To(const Device *device); virtual void To(DataType dtype); diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 0ae75995..5c9aa198 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -26,7 +26,8 @@ class GlobalEnv { public: static GlobalEnv &Instance(); - void Init(int threads_per_process, int tensor_parallel_size, 
bool sequence_parallel_enabled = false); + void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false, + bool pipeline_parallel = false); int world_size() const; @@ -44,6 +45,8 @@ class GlobalEnv { int data_parallel_size() const; + int pipeline_parallel_size() const; + Layout layout() const; private: @@ -65,14 +68,17 @@ class GlobalEnv { int data_parallel_size_ = 1; + int pipeline_parallel_size_ = 1; + mutable std::mutex mutex_; bool initialized_ = false; Layout layout_; }; -inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false) { - GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled); +inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false, + bool pipeline_parallel = false) { + GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, pipeline_parallel); } inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); } @@ -84,6 +90,7 @@ inline int GetLocalProcRank() { return GlobalEnv::Instance().local_proc_rank(); inline int GetTensorParallelSize() { return GlobalEnv::Instance().tensor_parallel_size(); } inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence_parallel_enabled(); } inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); } +inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); } // Layout Helper Functions inline int GetRankOf(int dp, int tp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, pp); } diff --git a/infini_train/include/nn/parallel/pp/pipeline_parallel.h b/infini_train/include/nn/parallel/pp/pipeline_parallel.h new file mode 100644 index 00000000..166c6849 --- /dev/null +++ b/infini_train/include/nn/parallel/pp/pipeline_parallel.h @@ -0,0 +1,50 @@ +// pipeline_parallel.h +#pragma once + +#include +#include + +#include "infini_train/include/device.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/include/nn/parallel/pp/pipeline_schedule.h" +#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" + +namespace infini_train::nn::parallel { + +using OptimizerFactory = std::function(const std::vector> ¶ms)>; + +class PipelineParallel : public Module { +public: + PipelineParallel(const std::shared_ptr &model, int num_stages, int num_microbatches, + const std::vector> &recv_shape, int rank, OptimizerFactory optimizer_factory); + + float TrainStep(const std::vector> &input, + const std::vector> &target, const std::shared_ptr &loss_fn); + +private: + int num_stages_; + int rank_; + std::vector devices_; + std::shared_ptr original_model_; + std::shared_ptr pipeline_stage_; + std::shared_ptr schedule_; + + std::vector>> + SplitLayersIntoStages(std::vector> layers); + + void SplitModel(const std::vector> &recv_shape, OptimizerFactory optimizer_factory); + + std::vector> + CreateOptimizers(const std::vector>> &stage_layers, + OptimizerFactory optimizer_factory); + + void BuildPipelineStage(const std::vector>> &stage_layers, + const std::vector> &optimizers, + const std::vector> &recv_shape); + + void SetupSchedule(int num_microbatches); +}; + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_schedule.h b/infini_train/include/nn/parallel/pp/pipeline_schedule.h new file mode 100644 index 
00000000..0e51b51a --- /dev/null +++ b/infini_train/include/nn/parallel/pp/pipeline_schedule.h @@ -0,0 +1,55 @@ +#pragma once +#include +#include + +#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" + +namespace infini_train::nn::parallel { + +class PipelineSchedule { +public: + PipelineSchedule(std::shared_ptr stage, int num_stages, int num_microbatches, int stage_index) + : stage_(std::move(stage)), num_microbatches_(num_microbatches), stage_index_(stage_index) {} + + virtual ~PipelineSchedule() = default; + + float Step(std::shared_ptr input, std::shared_ptr target, const std::shared_ptr &loss_fn); + + virtual float StepMicrobatches(const std::vector> &arg_mbs, + const std::vector> &target_mbs, + const std::shared_ptr &loss_fn) + = 0; + + std::vector> ReceiveFromPrev(); + std::vector> SendToNext(const std::vector> &tensors); + +protected: + int num_microbatches_; + int stage_index_; + std::shared_ptr stage_; + +private: + std::vector> SplitTensor(std::shared_ptr full_inputs); +}; + +class ScheduleGPipe : public PipelineSchedule { +public: + ScheduleGPipe(std::shared_ptr stage, int num_stages, int num_microbatches, int stage_index) + : PipelineSchedule(std::move(stage), num_stages, num_microbatches, stage_index){}; + + float StepMicrobatches(const std::vector> &arg_mbs, + const std::vector> &target_mbs, + const std::shared_ptr &loss_fn) override; +}; + +class Schedule1F1B : public PipelineSchedule { +public: + Schedule1F1B(std::shared_ptr stage, int num_stages, int num_microbatches, int stage_index) + : PipelineSchedule(std::move(stage), num_stages, num_microbatches, stage_index){}; + + float StepMicrobatches(const std::vector> &arg_mbs, + const std::vector> &target_mbs, + const std::shared_ptr &loss_fn) override; +}; + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h new file mode 100644 index 00000000..0f0a1515 --- /dev/null +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +#include "infini_train/include/device.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +class PipelineStage { +public: + PipelineStage(const std::vector> &layers, int stage_index, int num_stages, + const std::vector> &recvShape, std::shared_ptr optim); + + std::vector> ForwardOneChunk(const std::vector> &inputs); + + bool IsFirstStage() const { return stage_index_ == 0; } + bool IsLastStage() const { return stage_index_ == num_stages_ - 1; } + int stage_index() const { return stage_index_; } + int prev_rank() const { return prev_rank_; } + int next_rank() const { return next_rank_; } + int num_stages() const { return num_stages_; } + const Device *device() const { return device_; } + std::vector> recv_shape() const { return recv_shape_; } + std::shared_ptr optimizer() { return optim_; } + +private: + int stage_index_; + int num_stages_; + int prev_rank_; + int next_rank_; + const Device *device_ = nullptr; + std::vector> recv_shape_; + std::vector> layers_; + std::shared_ptr optim_; +}; + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/send_recv.h b/infini_train/include/nn/parallel/pp/send_recv.h new file mode 100644 index 00000000..be5fe47f --- /dev/null +++ b/infini_train/include/nn/parallel/pp/send_recv.h @@ -0,0 +1,16 @@ +#pragma once + 
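+// Point-to-point communication helpers for pipeline parallelism. ISend pushes
+// activations to the next stage and, in its backward pass, receives the matching
+// gradient from that stage; IRecv is the mirror operation. Both are implemented
+// as autograd functions (see send_recv.cc) so that gradients can flow back
+// across stage boundaries during Backward().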
+#include +#include + +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +std::vector> ISend(const std::vector> &input_tensors, + const Device *target_device, int cur_rank, int peer_rank, + std::vector> shape); + +std::vector> IRecv(const std::vector> &outputs, + const Device *src_device, int cur_rank, int peer_rank); +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/utils.h b/infini_train/include/nn/parallel/utils.h index 26103294..aa7a473a 100644 --- a/infini_train/include/nn/parallel/utils.h +++ b/infini_train/include/nn/parallel/utils.h @@ -8,6 +8,8 @@ std::string GetDataParallelProcessGroupName(int thread_rank); std::string GetTensorParallelProcessGroupName(int thread_rank); +std::string GetPipelineParallelProcessGroupName(int thread_rank); + std::vector GetDataParallelGroupRanks(int rank); std::vector GetTensorParallelGroupRanks(int rank); diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 3feaab9c..3f6fcb4f 100644 --- a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -84,7 +84,8 @@ GlobalEnv &GlobalEnv::Instance() { return instance; } -void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled) { +void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, + bool pipeline_parallel) { std::lock_guard lock(mutex_); CHECK(!initialized_) << "Repeated initialization of GlobalEnv!"; @@ -99,6 +100,7 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq tensor_parallel_size_ = tensor_parallel_size; sequence_parallel_enabled_ = sequence_parallel_enabled; data_parallel_size_ = world_size_ / tensor_parallel_size_; + pipeline_parallel_size_ = pipeline_parallel ? 
nthread_per_process : 1; layout_.sizes[DP] = data_parallel_size_; layout_.sizes[TP] = tensor_parallel_size_; @@ -149,6 +151,11 @@ int GlobalEnv::data_parallel_size() const { return data_parallel_size_; } +int GlobalEnv::pipeline_parallel_size() const { + CHECK(initialized_) << "GlobalEnv is not initialized!"; + return pipeline_parallel_size_; +} + Layout GlobalEnv::layout() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; return layout_; diff --git a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc new file mode 100644 index 00000000..20e8ef36 --- /dev/null +++ b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc @@ -0,0 +1,108 @@ +// pipeline_parallel.cc +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" + +#include +#include + +#include "infini_train/include/nn/modules/container.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/pp/pipeline_schedule.h" +#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" +#include "infini_train/include/optimizer.h" + +namespace infini_train::nn::parallel { + +std::vector>> +PipelineParallel::SplitLayersIntoStages(std::vector> layers) { + const int total_layers = layers.size(); + CHECK_GT(total_layers, 0) << "Model has no layers to split!"; + CHECK_GE(num_stages_, 1) << "num_stages must be >= 1"; + CHECK_LE(num_stages_, total_layers) << "num_stages (" << num_stages_ << ") cannot be greater than total layers (" + << total_layers << ")"; + + std::vector>> stages(num_stages_); + int base_layers_per_stage = total_layers / num_stages_; + int remainder = total_layers % num_stages_; + int layer_idx = 0; + + for (int s = 0; s < num_stages_; ++s) { + int layers_in_this_stage = base_layers_per_stage + (s < remainder ? 
1 : 0); + for (int i = 0; i < layers_in_this_stage; ++i) { + auto layer = layers[layer_idx]; + stages[s].emplace_back(layer); + layer_idx++; + } + } + + return stages; +} + +std::vector> +PipelineParallel::CreateOptimizers(const std::vector>> &stage_layers, + OptimizerFactory optimizer_factory) { + std::vector> optims; + optims.reserve(stage_layers.size()); + + for (int s = 0; s < num_stages_; ++s) { + std::vector> params; + for (const auto &layer : stage_layers[s]) { + layer->To(devices_[s]); + auto layer_params = layer->Parameters(); + params.insert(params.end(), layer_params.begin(), layer_params.end()); + } + + auto optim = optimizer_factory(params); + CHECK(optim != nullptr) << "Optimizer factory returned null optimizer for stage " << s; + optims.push_back(std::move(optim)); + } + return optims; +} + +void PipelineParallel::BuildPipelineStage(const std::vector>> &stage_layers, + const std::vector> &optimizers, + const std::vector> &recv_shape) { + pipeline_stage_ + = std::make_shared(stage_layers[rank_], rank_, num_stages_, recv_shape, optimizers[rank_]); +} + +void PipelineParallel::SplitModel(const std::vector> &recv_shape, + OptimizerFactory optimizer_factory) { + auto layers = original_model_->GetPipelineLayers(); + CHECK(!layers.empty()) << "SplitModel: GetPipelineLayers returned empty vector"; + + auto stage_layer = SplitLayersIntoStages(layers); + + auto optimizer = CreateOptimizers(stage_layer, optimizer_factory); + + BuildPipelineStage(stage_layer, optimizer, recv_shape); +} + +void PipelineParallel::SetupSchedule(int num_microbatches) { + schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_microbatches, rank_); +} + +float PipelineParallel::TrainStep(const std::vector> &input, + const std::vector> &target, + const std::shared_ptr &loss_fn) { + std::shared_ptr stage_input; + std::shared_ptr stage_target = target[0]; + if (rank_ == 0) { + stage_input = input[0]; + } + + return schedule_->Step(stage_input, stage_target, loss_fn); +} + +PipelineParallel::PipelineParallel(const std::shared_ptr &model, int num_stages, int num_microbatches, + const std::vector> &recv_shape, int rank, + OptimizerFactory optimizer_factory) + : original_model_(model), devices_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA)), + num_stages_(num_stages), rank_(rank) { + CHECK(!devices_.empty()) << "Devices list is empty"; + + SplitModel(recv_shape, optimizer_factory); + + SetupSchedule(num_microbatches); +} + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc new file mode 100644 index 00000000..34bbae19 --- /dev/null +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -0,0 +1,171 @@ +// pipeline_schedule.cc +#include "infini_train/include/nn/parallel/pp/pipeline_schedule.h" + +#include "glog/logging.h" +#include +#include +#include +#include +#include + +#include "infini_train/include/device.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/parallel_functional.h" +#include "infini_train/include/nn/parallel/pp/send_recv.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +std::vector> PipelineSchedule::SplitTensor(std::shared_ptr full_inputs) { + const auto n = num_microbatches_; + if (n == 1) { + return {full_inputs}; + } + + const auto &first_dims = full_inputs->Dims(); + if (first_dims.empty()) { + LOG(FATAL) << 
"SplitTensor: tensor has no dimensions."; + } + int64_t batch_size = first_dims[0]; + int microbatch_size = batch_size / n; + int remainder = batch_size % n; + + std::vector> micro_batches; + + int start_idx = 0; + int end_idx = 0; + for (int mb = 0; mb < n; ++mb) { + int current_size = microbatch_size + (mb == n - 1 ? remainder : 0); + end_idx = start_idx + current_size; + + if (start_idx < 0 || end_idx > batch_size || start_idx >= end_idx) { + LOG(FATAL) << "Invalid slice range: [%d, %d), batch_size=%ld" << start_idx << end_idx << batch_size; + } + + if (full_inputs->Dims()[0] != batch_size) { + LOG(FATAL) << "SplitTensor: tensor size mismatch on dim 0."; + } + + auto sliced = full_inputs->Slice(0, start_idx, end_idx); + + micro_batches.push_back(sliced); + + start_idx = end_idx; + } + + return micro_batches; +} + +float PipelineSchedule::Step(std::shared_ptr input, std::shared_ptr target, + const std::shared_ptr &loss_fn) { + std::vector> micro_batches(num_microbatches_); + std::vector> target_mbs(num_microbatches_); + if (stage_->IsFirstStage()) { + micro_batches = SplitTensor(input); + } + if (stage_->IsLastStage()) { + target_mbs = SplitTensor(target); + } + + const auto &optim = stage_->optimizer(); + + optim->ZeroGrad(); + + float lossf = StepMicrobatches(micro_batches, target_mbs, loss_fn); + + optim->Step(); + + return lossf; +} + +std::vector> PipelineSchedule::ReceiveFromPrev() { + std::vector> recv_tensors; + if (!stage_->IsFirstStage()) { + auto shapes = stage_->recv_shape(); + for (size_t i = 0; i < shapes.size(); ++i) { + auto tensor = std::make_shared(shapes[i], DataType::kFLOAT32, stage_->device()); + tensor->set_requires_grad(true); + tensor->set_is_leaf(false); + recv_tensors.push_back(tensor); + } + return IRecv(recv_tensors, stage_->device(), stage_->stage_index(), stage_->prev_rank()); + } + return recv_tensors; +} + +std::vector> PipelineSchedule::SendToNext(const std::vector> &tensors) { + if (!stage_->IsLastStage()) { + return ISend(tensors, stage_->device(), stage_->stage_index(), stage_->next_rank(), stage_->recv_shape()); + } + return tensors; +} + +float ScheduleGPipe::StepMicrobatches(const std::vector> µbatch_inputs, + const std::vector> µbatch_targets, + const std::shared_ptr &loss_fn) { + const auto n = num_microbatches_; + if (n == 0) { + return 0.0f; + } + std::vector>> outputs(n); + + // ======== Forward Pass ======== + for (int mb = 0; mb < n; ++mb) { + std::vector> inputs; + + if (stage_->IsFirstStage()) { + inputs = {microbatch_inputs[mb]}; + } else { + inputs = ReceiveFromPrev(); + } + + outputs[mb] = stage_->ForwardOneChunk(inputs); + + for (auto &t : outputs[mb]) { + if (!t) { + t = std::make_shared((std::vector){}, DataType::kFLOAT32, stage_->device()); + } + } + + outputs[mb] = SendToNext(outputs[mb]); + } + + // ======== Backward Pass ======== + float total_loss = 0.0f; + if (!stage_->IsLastStage()) { + for (int mb = 0; mb < n; ++mb) { + auto out_tensor = outputs[mb][0]; + + auto gradient = std::make_shared(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice()); + + out_tensor->Backward(gradient); + } + } else { + for (int mb = 0; mb < n; ++mb) { + auto target = microbatch_targets[mb]; + auto output = outputs[mb][0]; + + if (!target || !output) { + LOG(FATAL) << "Output or target is null at mb=" << mb; + } + + auto target_on_device = target->To(output->GetDevice()); + auto loss = loss_fn->Forward({output, std::make_shared(target_on_device)})[0]; + if (!loss) { + LOG(INFO) << "[ERROR] loss is nullptr at mb = " << mb; + continue; + } + 
+ loss = loss / n; + auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + total_loss += static_cast(loss_cpu.DataPtr())[0]; + + loss->Backward(); + } + } + + return total_loss; +} + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc new file mode 100644 index 00000000..fb6d13be --- /dev/null +++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc @@ -0,0 +1,32 @@ +#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" + +#include "glog/logging.h" + +#include + +#include "infini_train/include/device.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/nn/init.h" + +namespace infini_train::nn::parallel { + +PipelineStage::PipelineStage(const std::vector> &layers, int stage_index, int num_stages, + const std::vector> &recvShape, std::shared_ptr optim) + : stage_index_(stage_index), num_stages_(num_stages), layers_(layers), + prev_rank_(stage_index > 0 ? stage_index - 1 : -1), + next_rank_(stage_index < num_stages - 1 ? stage_index + 1 : -1), recv_shape_(recvShape), optim_(std::move(optim)), + device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(stage_index)) {} + +std::vector> +PipelineStage::ForwardOneChunk(const std::vector> &inputs) { + std::vector> current = inputs; + int i = 0; + for (const auto &layer : layers_) { + current = layer->Forward(current); + ++i; + } + + return current; +} + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc new file mode 100644 index 00000000..d132ba49 --- /dev/null +++ b/infini_train/src/nn/parallel/pp/send_recv.cc @@ -0,0 +1,135 @@ +#include "infini_train/include/nn/parallel/pp/send_recv.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/function.h" +#include "infini_train/include/device.h" +#include "infini_train/include/nn/parallel/process_group.h" +#include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +namespace functions { +class ISend : public autograd::Function { +public: + static constexpr char kType[] = "ISendFunction"; + + explicit ISend(const Device *target_device, int cur_rank, int peer_rank, std::vector> shape) + : autograd::Function(kType), target_device_(target_device), cur_rank_(cur_rank), peer_rank_(peer_rank), + shapes_(shape) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const Device *target_device_ = nullptr; + const Device *input_device_ = nullptr; + int cur_rank_ = -1; + int peer_rank_ = -1; + std::vector> shapes_; +}; + +class IRecv : public autograd::Function { +public: + static constexpr char kType[] = "IRecvFunction"; + + explicit IRecv(const Device *src_device, int cur_rank, int peer_rank) + : autograd::Function(kType), src_device_(src_device), cur_rank_(cur_rank), peer_rank_(peer_rank) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const Device *src_device_ = nullptr; + const Device *cur_device_ 
= nullptr; + int cur_rank_ = -1; + int peer_rank_ = -1; +}; + +std::vector> ISend::Forward(const std::vector> &input_tensors) { + const auto &input = input_tensors[0]; + input_device_ = input->GetDevice(); + + auto pp_group = ProcessGroupFactory::Instance()->Get( + GetPipelineParallelProcessGroupName(input_device_->rank().thread_rank())); + + pp_group->NcclSend(input_tensors, peer_rank_); + + std::vector> outputs; + for (auto t : input_tensors) { outputs.push_back(t); } + return outputs; +} + +void ISend::SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) {} + +std::vector> ISend::Backward(const std::vector> &grad_outputs) { + auto shapes = shapes_; + std::vector> recv_tensors; + for (int shape_i = 0; shape_i < shapes.size(); ++shape_i) { + auto r_tensor = std::make_shared(shapes[shape_i], DataType::kFLOAT32, input_device_); + recv_tensors.push_back(r_tensor); + } + + auto pp_group = ProcessGroupFactory::Instance()->Get( + GetPipelineParallelProcessGroupName(input_device_->rank().thread_rank())); + + return pp_group->NcclRecv(recv_tensors, peer_rank_); +} + +std::vector> IRecv::Forward(const std::vector> &recv_tensors) { + CHECK_NE(src_device_, nullptr) << "src_device_ must be set"; + + auto pp_group + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(src_device_->rank().thread_rank())); + pp_group->NcclRecv(recv_tensors, peer_rank_); + + std::vector> outputs; + for (auto t : recv_tensors) { + auto t_item = std::make_shared(*t); + t_item->set_requires_grad(true); + outputs.push_back(t); + } + return outputs; +} + +void IRecv::SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) { + if (output_tensors.empty()) { + return; + } + cur_device_ = output_tensors[0]->GetDevice(); +} + +std::vector> IRecv::Backward(const std::vector> &grad_outputs) { + auto pp_group + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(cur_device_->rank().thread_rank())); + return pp_group->NcclSend(grad_outputs, peer_rank_); +} +} // namespace functions + +std::vector> ISend(const std::vector> &input_tensors, + const Device *target_device, int cur_rank, int peer_rank, + std::vector> shape) { + auto func = std::make_shared(target_device, cur_rank, peer_rank, shape); + return func->Apply(input_tensors); +} + +std::vector> IRecv(const std::vector> &outputs, + const Device *src_device, int cur_rank, int peer_rank) { + auto func = std::make_shared(src_device, cur_rank, peer_rank); + return func->Apply(outputs); +} +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/utils.cc b/infini_train/src/nn/parallel/utils.cc index 8c4ad8cd..5456e4e7 100644 --- a/infini_train/src/nn/parallel/utils.cc +++ b/infini_train/src/nn/parallel/utils.cc @@ -12,6 +12,10 @@ std::string GetTensorParallelProcessGroupName(int thread_rank) { return "TP" + std::to_string(global::GetGroupId(global::TP, thread_rank)); } +std::string GetPipelineParallelProcessGroupName(int thread_rank) { + return "PP" + std::to_string(global::GetGroupId(global::PP, thread_rank)); +} + std::vector GetDataParallelGroupRanks(int thread_rank) { return global::GetGroupRanks(global::DP, thread_rank); } std::vector GetTensorParallelGroupRanks(int thread_rank) { return global::GetGroupRanks(global::TP, thread_rank); } diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash old mode 100644 new mode 100755 From 962509cd82a4b902fc9dd8f33e06974e9f5368fb Mon Sep 17 00:00:00 2001 From: JYMiracle305 
<604951424@qq.com> Date: Thu, 6 Nov 2025 18:14:32 +0800 Subject: [PATCH 2/3] init pp_world_size accuratly --- example/gpt2/main.cc | 22 +++++++++------------- example/llama3/main.cc | 22 +++++++++------------- example/llama3/net.cc | 8 -------- infini_train/include/nn/parallel/global.h | 7 ++++--- infini_train/src/nn/parallel/global.cc | 6 +++--- 5 files changed, 25 insertions(+), 40 deletions(-) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 72a2b45f..77f7479f 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -64,9 +64,9 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); -DEFINE_bool( - pipeline_parallel, false, - "use pipeline parallelism or not, will always use device=cuda and use all cuda visible devices when set to true"); +DEFINE_uint32( + pipeline_parallel, 1, + "Pipeline Parallel world size, will always use device=cuda and use all cuda visible devices when set to true"); DEFINE_uint32(num_microbatches, 4, "the num of microbatches in pipeline parallelism"); // precision @@ -155,7 +155,7 @@ void Train(const nn::parallel::Rank &rank) { } // calculate gradient accumulation from the desired total batch size and the current run configuration - const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size; + const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * (ddp_world_size * pp_world_size); CHECK_EQ(FLAGS_total_batch_size % tokens_per_fwdbwd, 0); const auto grad_accum_steps = FLAGS_total_batch_size / tokens_per_fwdbwd; LOG(INFO) << "total desired batch size: " << FLAGS_total_batch_size @@ -189,10 +189,6 @@ void Train(const nn::parallel::Rank &rank) { LOG(FATAL) << "Rank " << rank.thread_rank() << ": Datatype " << FLAGS_dtype << " not supported."; } - // TODO(jym): Temporary implementation before 3D parallelism - if (FLAGS_pipeline_parallel) { - ddp_world_size = 1; - } // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions // before wrapping the model with DistributedDataParallel (DDP). 
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors @@ -202,7 +198,7 @@ void Train(const nn::parallel::Rank &rank) { } std::unique_ptr train_loader; - if (FLAGS_pipeline_parallel) { + if (pp_world_size > 1) { train_loader = std::make_unique( std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size * pp_world_size); @@ -243,7 +239,7 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.thread_rank() << ": start training"; - if (FLAGS_pipeline_parallel) { + if (pp_world_size > 1) { CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0) << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size) << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")"; @@ -279,7 +275,7 @@ void Train(const nn::parallel::Rank &rank) { } // model->Train(); - if (!FLAGS_pipeline_parallel) { + if (pp_world_size == 1) { optimizer->ZeroGrad(); } // if we are trying to overfit a single batch, we reset the loader here @@ -302,7 +298,7 @@ void Train(const nn::parallel::Rank &rank) { x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); - if (FLAGS_pipeline_parallel) { + if (pp_world_size > 1) { lossf = model->TrainStep({x}, {y}, loss_fn); auto loss_tensor = std::make_shared(std::vector{}, DataType::kFLOAT32); @@ -335,7 +331,7 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; } - if (!FLAGS_pipeline_parallel) { + if (pp_world_size == 1) { optimizer->Step(); } const auto iter_end = std::chrono::high_resolution_clock::now(); diff --git a/example/llama3/main.cc b/example/llama3/main.cc index f3795267..c983268c 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -63,9 +63,9 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); -DEFINE_bool( - pipeline_parallel, false, - "use pipeline parallelism or not, will always use device=cuda and use all cuda visible devices when set to true"); +DEFINE_uint32( + pipeline_parallel, 1, + "Pipeline Parallel world size, will always use device=cuda and use all cuda visible devices when set to true"); DEFINE_uint32(num_microbatches, 4, "the num of microbatches in pipeline parallelism"); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -137,7 +137,7 @@ void Train(const nn::parallel::Rank &rank) { } // calculate gradient accumulation from the desired total batch size and the current run configuration - const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size; + const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * (ddp_world_size * pp_world_size); CHECK_EQ(FLAGS_total_batch_size % tokens_per_fwdbwd, 0); const auto grad_accum_steps = FLAGS_total_batch_size / tokens_per_fwdbwd; if (rank.IsMainRank()) { @@ -169,10 +169,6 @@ void Train(const nn::parallel::Rank &rank) { LOG(FATAL) << "Rank " << rank.thread_rank() << ": Datatype " << FLAGS_dtype << " not supported."; } - // TODO(jym): Temporary implementation before 3D parallelism - if (FLAGS_pipeline_parallel) { - ddp_world_size = 1; - } // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions // before wrapping the model with DistributedDataParallel (DDP). 
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors @@ -182,7 +178,7 @@ void Train(const nn::parallel::Rank &rank) { } std::unique_ptr train_loader; - if (FLAGS_pipeline_parallel) { + if (pp_world_size > 1) { train_loader = std::make_unique( std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size * pp_world_size); @@ -221,7 +217,7 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.thread_rank() << ": start training"; - if (FLAGS_pipeline_parallel) { + if (pp_world_size > 1) { CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0) << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size) << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")"; @@ -256,7 +252,7 @@ void Train(const nn::parallel::Rank &rank) { } // model->Train(); - if (!FLAGS_pipeline_parallel) { + if (pp_world_size == 1) { optimizer->ZeroGrad(); } // if we are trying to overfit a single batch, we reset the loader here @@ -278,7 +274,7 @@ void Train(const nn::parallel::Rank &rank) { ++train_iter; x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); - if (FLAGS_pipeline_parallel) { + if (pp_world_size > 1) { lossf = model->TrainStep({x}, {y}, loss_fn); auto loss_tensor = std::make_shared(std::vector{}, DataType::kFLOAT32); @@ -311,7 +307,7 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; } - if (!FLAGS_pipeline_parallel) { + if (pp_world_size == 1) { optimizer->Step(); } diff --git a/example/llama3/net.cc b/example/llama3/net.cc index 44cabe36..856ccbf6 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -405,14 +405,6 @@ std::vector> LLaMA3::GetPipelineLayers() { layers.push_back(std::make_shared(h, dtype, config_)); ++idx; } - // for (int idx = 0; idx < seq->size(); ++idx) { - // auto wrapped_layer = (*seq)[idx]; - // layers.push_back(std::make_shared(wrapped_layer, dtype, config_)); - // for (int idx = 1; idx < seq->size(); ++idx) { - // auto wrapped_layer = (*seq)[idx]; - // layers.push_back(wrapped_layer); - // } - // } layers.push_back(transformer->mutable_module(kLnFLayerName)); diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 5c9aa198..4b44e408 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -27,7 +27,7 @@ class GlobalEnv { static GlobalEnv &Instance(); void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false, - bool pipeline_parallel = false); + int pipeline_parallel_size = 1); int world_size() const; @@ -77,8 +77,9 @@ class GlobalEnv { }; inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false, - bool pipeline_parallel = false) { - GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, pipeline_parallel); + int pipeline_parallel_size = 1) { + GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, + pipeline_parallel_size); } inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); } diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 3f6fcb4f..8c97b225 100644 --- a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -85,7 +85,7 @@ GlobalEnv &GlobalEnv::Instance() { } void GlobalEnv::Init(int 
nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - bool pipeline_parallel) { + int pipeline_parallel_size) { std::lock_guard lock(mutex_); CHECK(!initialized_) << "Repeated initialization of GlobalEnv!"; @@ -99,8 +99,8 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq CHECK_GE(tensor_parallel_size, 1) << "Tensor Parallel size must be >= 1"; tensor_parallel_size_ = tensor_parallel_size; sequence_parallel_enabled_ = sequence_parallel_enabled; - data_parallel_size_ = world_size_ / tensor_parallel_size_; - pipeline_parallel_size_ = pipeline_parallel ? nthread_per_process : 1; + pipeline_parallel_size_ = pipeline_parallel_size; + data_parallel_size_ = world_size_ / tensor_parallel_size_ / pipeline_parallel_size_; layout_.sizes[DP] = data_parallel_size_; layout_.sizes[TP] = tensor_parallel_size_; From 09735f24d8f095ef861e7d14eed26d6e9e73e326 Mon Sep 17 00:00:00 2001 From: JYMiracle305 <604951424@qq.com> Date: Sat, 15 Nov 2025 23:10:20 +0800 Subject: [PATCH 3/3] feat: Pipeline parallelism divides the model into chunks during construction --- example/common/utils.cc | 7 - example/common/utils.h | 1 - example/gpt2/main.cc | 145 +++-- example/gpt2/net.cc | 539 ++++++++++-------- example/gpt2/net.h | 3 - example/llama3/main.cc | 144 +++-- example/llama3/net.cc | 368 ++++++------ example/llama3/net.h | 4 - infini_train/include/nn/modules/container.h | 10 - infini_train/include/nn/modules/module.h | 4 - infini_train/include/nn/parallel/global.h | 8 +- .../nn/parallel/pp/pipeline_parallel.h | 46 +- .../nn/parallel/pp/pipeline_schedule.h | 45 +- .../include/nn/parallel/pp/pipeline_stage.h | 27 +- .../include/nn/parallel/pp/send_recv.h | 7 +- infini_train/include/nn/parallel/utils.h | 2 + .../src/nn/parallel/pp/pipeline_parallel.cc | 103 +--- .../src/nn/parallel/pp/pipeline_schedule.cc | 112 ++-- .../src/nn/parallel/pp/pipeline_stage.cc | 24 +- infini_train/src/nn/parallel/pp/send_recv.cc | 22 +- infini_train/src/nn/parallel/utils.cc | 6 + infini_train/src/optimizer.cc | 4 + 22 files changed, 789 insertions(+), 842 deletions(-) diff --git a/example/common/utils.cc b/example/common/utils.cc index 347aa195..03cc7aa0 100644 --- a/example/common/utils.cc +++ b/example/common/utils.cc @@ -61,11 +61,4 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s ifs.seekg(base + std::streamoff(len * sizeof(float))); } -std::vector GetPipelineParallelGroupRanks(int pp_world_size) { - std::vector ranks; - ranks.reserve(pp_world_size); - for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); } - return ranks; -} - } // namespace infini_train diff --git a/example/common/utils.h b/example/common/utils.h index 05ff7b36..5bab3e97 100644 --- a/example/common/utils.h +++ b/example/common/utils.h @@ -30,5 +30,4 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len); void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt); -std::vector GetPipelineParallelGroupRanks(int rank); } // namespace infini_train diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 77f7479f..ea7dce3d 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -64,10 +64,7 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); -DEFINE_uint32( - pipeline_parallel, 1, - 
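// With the GlobalEnv::Init change above, the data-parallel size is derived from the other two axes:
//     data_parallel_size = world_size / tensor_parallel_size / pipeline_parallel_size
// e.g. (illustrative numbers, not values from this patch) world_size = 8 with tensor_parallel = 2 and
// pipeline_parallel = 2 leaves data_parallel_size = 2. The integer division assumes world_size is a
// multiple of tensor_parallel_size * pipeline_parallel_size.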
"Pipeline Parallel world size, will always use device=cuda and use all cuda visible devices when set to true"); -DEFINE_uint32(num_microbatches, 4, "the num of microbatches in pipeline parallelism"); +DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -148,6 +145,8 @@ void Train(const nn::parallel::Rank &rank) { pp_pg = ProcessGroupFactory::Instance()->GetOrCreate( GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size)); pp_rank = pp_pg->GetGroupRank(rank.thread_rank()); + + nn::parallel::pp_rank = pp_rank; } } else { device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice() @@ -155,7 +154,7 @@ void Train(const nn::parallel::Rank &rank) { } // calculate gradient accumulation from the desired total batch size and the current run configuration - const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * (ddp_world_size * pp_world_size); + const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size; CHECK_EQ(FLAGS_total_batch_size % tokens_per_fwdbwd, 0); const auto grad_accum_steps = FLAGS_total_batch_size / tokens_per_fwdbwd; LOG(INFO) << "total desired batch size: " << FLAGS_total_batch_size @@ -197,16 +196,9 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, rank.thread_rank()); } - std::unique_ptr train_loader; - if (pp_world_size > 1) { - train_loader = std::make_unique( - std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), - FLAGS_batch_size * pp_world_size); - } else { - train_loader = std::make_unique( - std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size, - ddp_rank, ddp_world_size); - } + auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); + DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), + FLAGS_batch_size * num_micro_batches, ddp_rank, ddp_world_size); std::optional val_loader = std::nullopt; if (!FLAGS_input_val_bin.empty()) { @@ -225,13 +217,9 @@ void Train(const nn::parallel::Rank &rank) { } // TODO(dcj): support more complex optimizer later - auto lr = FLAGS_learning_rate; - auto optimizer_factory = [lr](const std::vector> ¶ms) { - return std::make_shared(params, lr); - }; - auto optimizer = optimizer_factory(model->Parameters()); + auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate); - auto train_iter = train_loader->begin(); + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? 
std::static_pointer_cast( std::make_shared(model_config.original_vocab_size)) @@ -240,14 +228,10 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.thread_rank() << ": start training"; if (pp_world_size > 1) { - CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0) - << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size) - << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")"; - auto shapes = std::vector>{{(FLAGS_batch_size * pp_world_size) / FLAGS_num_microbatches, - FLAGS_sequence_length, model->GetConfig()["n_embd"]}}; - - model = std::make_shared(model, pp_world_size, FLAGS_num_microbatches, shapes, - pp_rank, optimizer_factory); + auto shapes = std::vector>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}}; + + model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, + pp_rank, std::make_shared(optimizer)); } LOG(INFO) << "start training"; @@ -274,23 +258,55 @@ void Train(const nn::parallel::Rank &rank) { break; } - // model->Train(); - if (pp_world_size == 1) { - optimizer->ZeroGrad(); - } - // if we are trying to overfit a single batch, we reset the loader here - if (FLAGS_overfit_single_batch) { - // train_loader.Reset(); - } - float lossf = 0.0f; #ifdef PROFILE_MODE Profiler::Instance().SetTag("Step_" + std::to_string(step)); #endif - for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { - // enable autocast for the current step - infini_train::AutocastGuard autocast_guard(device->Type(), dtype); - // (bs, seq_len), (bs, seq_len) + float lossf = 0.0f; + // model->Train(); + if (pp_world_size == 1) { + optimizer.ZeroGrad(); + + // if we are trying to overfit a single batch, we reset the loader here + if (FLAGS_overfit_single_batch) { + // train_loader.Reset(); + } + + for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { + // enable autocast for the current step + infini_train::AutocastGuard autocast_guard(device->Type(), dtype); + + // (bs, seq_len), (bs, seq_len) + auto [x, y] = *train_iter; + // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below + // TODO(dcj): support dataloader.reset() later + ++train_iter; + x = std::make_shared(x->To(device)); + y = std::make_shared(y->To(device)); + + LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward"; + // (bs, seq_len, vocab_size) + auto logits = model->Forward({x, y})[0]; + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward"; + auto loss = loss_fn->Forward({logits, y})[0]; + loss = loss / grad_accum_steps; + + // disable autocast for the current step (backward is not under autocast) + autocast_guard.Disable(); + + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward"; + if (ddp_world_size > 1) { + function::AllReduce(loss, function::ReduceOpType::kAvg); + } + auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + lossf += static_cast(loss_cpu.DataPtr())[0]; + LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward"; + loss->Backward(); + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; + } + + optimizer.Step(); + } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below // TODO(dcj): support dataloader.reset() later @@ -298,57 +314,24 @@ void Train(const nn::parallel::Rank &rank) { x = std::make_shared(x->To(device)); y = 
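// How num_micro_batches above is derived (illustrative numbers, not defaults from this patch):
// FLAGS_total_batch_size counts tokens per optimizer step, so with FLAGS_total_batch_size = 16384,
// FLAGS_batch_size = 4, FLAGS_sequence_length = 1024 and ddp_world_size = 1,
//     num_micro_batches = 16384 / (4 * 1024 * 1) = 4.
// The loader is therefore built with batch size FLAGS_batch_size * num_micro_batches = 16 and yields one
// (16, 1024) macro-batch per step, which is split into num_micro_batches micro-batches of FLAGS_batch_size
// sequences each according to the arguments passed to PipelineParallel above; the per-micro-batch
// activations exchanged between stages have shape {FLAGS_batch_size, FLAGS_sequence_length, n_embd},
// which is what `shapes` above describes.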
std::make_shared(y->To(device)); - if (pp_world_size > 1) { - lossf = model->TrainStep({x}, {y}, loss_fn); - - auto loss_tensor = std::make_shared(std::vector{}, DataType::kFLOAT32); - static_cast(loss_tensor->DataPtr())[0] = lossf; - auto loss_device_ptr = std::make_shared(loss_tensor->To(device)); - function::AllReduce(loss_device_ptr, function::ReduceOpType::kMax); - auto loss_copy = loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice()); - lossf = static_cast(loss_copy.DataPtr())[0]; - continue; - } - - LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward"; - // (bs, seq_len, vocab_size) - auto logits = model->Forward({x, y})[0]; - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward"; - auto loss = loss_fn->Forward({logits, y})[0]; - loss = loss / grad_accum_steps; - - // disable autocast for the current step (backward is not under autocast) - autocast_guard.Disable(); - - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward"; - if (ddp_world_size > 1) { - function::AllReduce(loss, function::ReduceOpType::kAvg); - } - auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); - lossf += static_cast(loss_cpu.DataPtr())[0]; - LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward"; - loss->Backward(); - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; - } - - if (pp_world_size == 1) { - optimizer->Step(); + lossf = model->TrainStep({x}, {y}, loss_fn); } const auto iter_end = std::chrono::high_resolution_clock::now(); const double duration_us = std::chrono::duration(iter_end - iter_start).count(); const double tps = FLAGS_total_batch_size / (duration_us / 1e6); - if (rank.IsMainRank()) { + if (rank.thread_rank() == pp_world_size - 1) { LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, " "DP={}, TP={}, SP={}, PP={})", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { - if (!tokenizer) { - continue; + if (tokenizer) { + // FIXME(jym): to support PP + CHECK_EQ(pp_world_size, 1); + tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device); } - tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device); } } } diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index 6640c026..259439c0 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -21,6 +21,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/nn/parallel/utils.h" @@ -176,6 +177,10 @@ Block::Forward(const std::vector> &x) { } GPT2::GPT2(const GPT2Config &config) : config_(config) { + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + auto [is_first_stage, is_last_stage, start_layer, end_layer] + = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size); + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); // NOTE(zbl): VocabParallelEmbedding requires vocab_size % tp_size == 0 @@ -184,131 +189,102 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) { 
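// GetStageInfo (called in the constructor above) is implemented in
// infini_train/src/nn/parallel/pp/pipeline_parallel.cc and is not shown in this patch. The sketch below
// only illustrates the contract the GPT2/LLaMA3 constructors rely on: each pipeline stage owns a
// contiguous [start_layer, end_layer) slice of the transformer blocks, the first stage additionally owns
// the embeddings, and the last stage owns the final norm and lm_head. The even-split-with-remainder
// policy and the explicit pp_rank parameter here are assumptions for illustration, not the library's API.
#include <algorithm>
#include <tuple>

std::tuple<bool, bool, int, int> GetStageInfoSketch(int n_layer, int pp_size, int pp_rank) {
    const int base = n_layer / pp_size;               // blocks every stage gets
    const int rem = n_layer % pp_size;                // leftover blocks go to the earliest stages
    const int start_layer = pp_rank * base + std::min(pp_rank, rem);
    const int end_layer = start_layer + base + (pp_rank < rem ? 1 : 0);
    return {/*is_first_stage=*/pp_rank == 0, /*is_last_stage=*/pp_rank == pp_size - 1, start_layer, end_layer};
}
// e.g. n_layer = 12, pp_size = 4 -> stages 0..3 own blocks [0,3), [3,6), [6,9), [9,12).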
CHECK_EQ(config.vocab_size % tp_world_size, 0) << "Vocab size should be divisible by TP world size"; { std::unordered_map> transformer; - transformer[kWTELayerName] = std::make_shared( - config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); - transformer[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); - { - std::vector> h; - for (int64_t i = 0; i < config_.n_layer; ++i) { h.push_back(std::make_shared(config_)); } - transformer[kHLayerName] = std::make_shared(std::move(h)); + if (is_first_stage) { + transformer[kWTELayerName] = std::make_shared( + config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); + transformer[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); } - transformer[kLnFLayerName] = std::make_shared(std::vector{config_.n_embd}); - modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); - } - // don't init this one, we will tie weights - modules_[kLMHeadLayerName] = std::make_shared( - /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, - /*bias=*/false, - // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - - // https://paperswithcode.com/method/weight-tying - *mutable_module(kTransformerLayerName) - ->mutable_module(kWTELayerName) - ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName) - = module(kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); -} - -class EmbeddingLayer : public nn::Module { - -public: - EmbeddingLayer(std::shared_ptr wte, std::shared_ptr wpe) { - modules_["wte"] = wte; - modules_["wpe"] = wpe; - } - - std::vector> - Forward(const std::vector> &inputs) override { - auto &input_ids = inputs[0]; // (bs, seq_len) - const int seq_len = input_ids->Dims()[1]; - const auto device = input_ids->GetDevice(); - - // position ids: [0, 1, ..., seq_len-1] - auto pos_ids = nn::init::Arange(0, seq_len, infini_train::DataType::kINT64, device); - // (bs, seq_len) -> wte -> (bs, seq_len, n_embd) - auto tok_emb = modules_["wte"]->Forward({input_ids})[0]; - // (seq_len,) -> wpe -> (seq_len, n_embd) - auto pos_emb = modules_["wpe"]->Forward({pos_ids})[0]; - - auto output = tok_emb + pos_emb; // (bs, seq_len, n_embd) - - return {output}; - } -}; - -std::vector> GPT2::GetPipelineLayers() { - auto &transformer = modules_[kTransformerLayerName]; - std::vector> layers; + { + std::vector> h; + for (int64_t i = start_layer; i < end_layer; ++i) { h.push_back(std::make_shared(config_)); } + transformer[kHLayerName] = std::make_shared(std::move(h)); + if (is_last_stage) { + transformer[kLnFLayerName] = std::make_shared(std::vector{config_.n_embd}); + } - auto embedding_layer = std::make_shared(transformer->mutable_module(kWTELayerName), - transformer->mutable_module(kWPELayerName)); - layers.push_back(embedding_layer); + modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); + } + if (is_last_stage) { + // don't init this one, we will tie weights + modules_[kLMHeadLayerName] = std::make_shared( + /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, + /*bias=*/false, + // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + 
/*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } - auto seq = std::dynamic_pointer_cast(transformer->mutable_module(kHLayerName)); - if (seq) { - for (int idx = 0; idx < seq->size(); ++idx) { - auto sub_module = (*seq)[idx]; - layers.push_back(sub_module); + // FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation + if (pp_size == 1) { + // https://paperswithcode.com/method/weight-tying + *mutable_module(kTransformerLayerName) + ->mutable_module(kWTELayerName) + ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName) + = module(kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); } } - - layers.push_back(transformer->mutable_module(kLnFLayerName)); - layers.push_back(modules_[kLMHeadLayerName]); - - return layers; -} - -std::unordered_map GPT2::GetConfig() const { - return {{"n_embd", config_.n_embd}, {"n_head", config_.n_head}}; } std::vector> GPT2::Forward(const std::vector> &x) { + int pp_rank = nn::parallel::pp_rank; + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + bool is_first_stage = (pp_rank == 0); + bool is_last_stage = (pp_rank == pp_size - 1); + // (B, T) - auto &idx = x[0]; - const auto device = idx->GetDevice(); + auto x1 = x[0]; + const auto device = x1->GetDevice(); - const auto t = idx->Dims()[1]; // T + const auto t = x1->Dims()[1]; // T CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " << config_.block_size; - // (T_local) - // NOTE(zbl): Slice pos sequence when SP is enabled - auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); - auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); - int tp_rank = 0; - if (tp_world_size > 1) { - auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( - nn::parallel::GetTensorParallelProcessGroupName(device->rank().thread_rank())); - tp_rank = tp_group->GetGroupRank(device->rank().thread_rank()); - } - - int64_t t_local = sequence_parallel_enabled ? (t / tp_world_size) : t; - int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0; - auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); // forward the GPT2 model itself auto &transformer = modules_[kTransformerLayerName]; - // (B, T) -> Embedding(V_local, C) -> (B, T, C) - auto tok_emb = transformer->mutable_module(kWTELayerName)->Forward({idx})[0]; - // (T) -> Embedding(T_max, C) -> (T, C) - auto pos_emb = transformer->mutable_module(kWPELayerName)->Forward({pos})[0]; - // (B, T, C) - auto x1 = tok_emb + pos_emb; + if (is_first_stage) { + // (T_local) + // NOTE(zbl): Slice pos sequence when SP is enabled + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); + auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); + int tp_rank = 0; + if (tp_world_size > 1) { + auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( + nn::parallel::GetTensorParallelProcessGroupName(device->rank().thread_rank())); + tp_rank = tp_group->GetGroupRank(device->rank().thread_rank()); + } + int64_t t_local = sequence_parallel_enabled ? (t / tp_world_size) : t; + int64_t start = sequence_parallel_enabled ? 
tp_rank * t_local : 0; + auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); + + // (B, T) -> Embedding(V_local, C) -> (B, T, C) + auto tok_emb = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; + // (T) -> Embedding(T_max, C) -> (T, C) + auto pos_emb = transformer->mutable_module(kWPELayerName)->Forward({pos})[0]; + // (B, T, C) + x1 = tok_emb + pos_emb; + } // (B, T, C) -> transformer -> (B, T, C) - auto x2 = transformer->mutable_module(kHLayerName)->Forward({x1}); - // (B, T, C) -> Layernorm -> (B, T, C) - auto x3 = transformer->mutable_module(kLnFLayerName)->Forward(x2); - - // TODO(dcj): add inference-time mini-optimization - // (B, T, C) -> Linear(C, V) -> (B, T, V) - auto logits = modules_[kLMHeadLayerName]->Forward(x3); - - // (B, T, V_original) - return logits; + auto h_modules = transformer->mutable_module(kHLayerName); + CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; + auto h_layers = std::dynamic_pointer_cast(h_modules); + // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) + for (auto &h : *h_layers) { x1 = h->Forward({x1})[0]; } + + if (is_last_stage) { + // (B, T, C) -> Layernorm -> (B, T, C) + auto x3 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); + + // TODO(dcj): add inference-time mini-optimization + // (B, T, C) -> Linear(C, V) -> (B, T, V) + auto logits = modules_[kLMHeadLayerName]->Forward(x3); + // (B, T, V_original) + return logits; + } + return {x1}; } std::shared_ptr GPT2::FromPretrained(ModelType model_type) { @@ -375,6 +351,10 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { CHECK_EQ(n_embd % n_head, 0) << "n_embd must be divisible by n_head."; CHECK_EQ(n_head % tp_size, 0) << "n_head must be divisible by TP world size."; + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + auto [is_first_stage, is_last_stage, start_layer, end_layer] + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size); + auto tp_rank = nn::parallel::tp_rank; // calculate xx_size_per_partition const int64_t vpp = model_vocab_size / tp_size; @@ -397,163 +377,274 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { // transformer.wte.weight (also transformer.lm_head.weight) // full: (model_vocab_size, n_embd) // local: (vocab_size_per_partition, n_embd) - auto &transformer_wte_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kWTELayerName, - nn::parallel::VocabParallelEmbedding::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd, - v_start, vpp); + if (is_first_stage) { + auto &transformer_wte_weight + = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kWTELayerName, + nn::parallel::VocabParallelEmbedding::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd, + v_start, vpp); + } else if (pp_size > 1 && is_last_stage) { + auto &lm_head_weight = state_dict[std::format("{}.{}", GPT2::kLMHeadLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ifs.read(reinterpret_cast(lm_head_weight->DataPtr()), lm_head_weight->SizeInBytes()); + } else { + size_t wte_bytes = vocab_size * n_embd * sizeof(float); + ifs.seekg(wte_bytes, std::ios::cur); + } + if (tp_size == 1) { // Skip padded vocab part when TP is not enabled ifs.ignore((padded_vocab_size - vocab_size) * n_embd * sizeof(float)); } - // transformer.wpe.weight 
- auto &transformer_wpe_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kWPELayerName, - nn::Embedding::kParamWeightName)]; - ReadMatrixAllFloat(ifs, static_cast(transformer_wpe_weight->DataPtr()), block_size, n_embd); + + if (is_first_stage) { + // transformer.wpe.weight + auto &transformer_wpe_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, + GPT2::kWPELayerName, nn::Embedding::kParamWeightName)]; + ReadMatrixAllFloat(ifs, static_cast(transformer_wpe_weight->DataPtr()), block_size, n_embd); + } else { + size_t wpe_bytes = block_size * n_embd * sizeof(float); + ifs.seekg(wpe_bytes, std::ios::cur); + } + // transformer.h.{i}.ln_1.weight for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kLn1LayerName, nn::LayerNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kLn1LayerName, + nn::LayerNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_1_w_bytes = n_embd * sizeof(float); + ifs.seekg(ln_1_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.ln_1.bias for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kLn1LayerName, nn::LayerNorm::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kLn1LayerName, + nn::LayerNorm::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_1_b_bytes = n_embd * sizeof(float); + ifs.seekg(ln_1_b_bytes, std::ios::cur); + } } + // transformer.h.{i}.attn.c_attn.weight (ColumnParallelLinear, but actually applies on "rows") for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim, - // i.e. 
[Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T - // However, each tp_rank needs to get [q_i|k_i|v_i].T, so we need to jump and read them respectively - float *dst = static_cast(tensor->DataPtr()); - const int64_t local_C = n_embd / tp_size; - const int64_t rows_all = 3 * n_embd; - const int64_t cols_all = n_embd; - const std::streampos base_pos = ifs.tellg(); - // Read q_i -> write to dst rows of [0 : local_C) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (0 * local_C) * cols_all, - /*rows=*/rows_all, /*cols=*/cols_all, - /*row_start=*/tp_rank * local_C, /*row_cnt=*/local_C); - // Read k_i -> write to dst rows of [local_C : 2*local_C) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (1 * local_C) * cols_all, - /*rows=*/rows_all, /*cols=*/cols_all, - /*row_start=*/n_embd + tp_rank * local_C, /*row_cnt=*/local_C); - // Read v_i -> write to dst rows of [2*local_C : 3*local_C) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (2 * local_C) * cols_all, - /*rows=*/rows_all, /*cols=*/cols_all, - /*row_start=*/2 * n_embd + tp_rank * local_C, /*row_cnt=*/local_C); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCAttnLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim, + // i.e. [Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T + // However, each tp_rank needs to get [q_i|k_i|v_i].T, so we need to jump and read them + // respectively + float *dst = static_cast(tensor->DataPtr()); + const int64_t local_C = n_embd / tp_size; + const int64_t rows_all = 3 * n_embd; + const int64_t cols_all = n_embd; + const std::streampos base_pos = ifs.tellg(); + // Read q_i -> write to dst rows of [0 : local_C) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (0 * local_C) * cols_all, + /*rows=*/rows_all, /*cols=*/cols_all, + /*row_start=*/tp_rank * local_C, /*row_cnt=*/local_C); + // Read k_i -> write to dst rows of [local_C : 2*local_C) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (1 * local_C) * cols_all, + /*rows=*/rows_all, /*cols=*/cols_all, + /*row_start=*/n_embd + tp_rank * local_C, /*row_cnt=*/local_C); + // Read v_i -> write to dst rows of [2*local_C : 3*local_C) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (2 * local_C) * cols_all, + /*rows=*/rows_all, /*cols=*/cols_all, + /*row_start=*/2 * n_embd + tp_rank * local_C, /*row_cnt=*/local_C); + } else { + size_t c_attn_w_bytes = qkv_out * n_embd * sizeof(float); + ifs.seekg(c_attn_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.attn.c_attn.bias (ColumnParallelLinear) for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, - nn::parallel::ColumnParallelLinear::kParamBiasName)]; - // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated - // i.e. 
[Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn] - // However, each tp_rank needs to get [q_i|k_i|v_i], so we need to jump and read them respectively - float *dst = static_cast(tensor->DataPtr()); - const int64_t local_C = n_embd / tp_size; - const int64_t len_all = 3 * n_embd; - const std::streampos base_pos = ifs.tellg(); - // Read q_i - ifs.seekg(base_pos); - ReadVectorShardFloat(ifs, - /*dst=*/dst + (0 * local_C), - /*len=*/len_all, - /*start=*/tp_rank * local_C, /*cnt=*/local_C); - // Read k_i - ifs.seekg(base_pos); - ReadVectorShardFloat(ifs, - /*dst=*/dst + (1 * local_C), - /*len=*/len_all, - /*start=*/n_embd + tp_rank * local_C, /*cnt=*/local_C); - // Read v_i - ifs.seekg(base_pos); - ReadVectorShardFloat(ifs, - /*dst=*/dst + (2 * local_C), - /*len=*/len_all, - /*start=*/2 * n_embd + tp_rank * local_C, /*cnt=*/local_C); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCAttnLayerName, + nn::parallel::ColumnParallelLinear::kParamBiasName)]; + // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated + // i.e. [Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn] + // However, each tp_rank needs to get [q_i|k_i|v_i], so we need to jump and read them + // respectively + float *dst = static_cast(tensor->DataPtr()); + const int64_t local_C = n_embd / tp_size; + const int64_t len_all = 3 * n_embd; + const std::streampos base_pos = ifs.tellg(); + // Read q_i + ifs.seekg(base_pos); + ReadVectorShardFloat(ifs, + /*dst=*/dst + (0 * local_C), + /*len=*/len_all, + /*start=*/tp_rank * local_C, /*cnt=*/local_C); + // Read k_i + ifs.seekg(base_pos); + ReadVectorShardFloat(ifs, + /*dst=*/dst + (1 * local_C), + /*len=*/len_all, + /*start=*/n_embd + tp_rank * local_C, /*cnt=*/local_C); + // Read v_i + ifs.seekg(base_pos); + ReadVectorShardFloat(ifs, + /*dst=*/dst + (2 * local_C), + /*len=*/len_all, + /*start=*/2 * n_embd + tp_rank * local_C, /*cnt=*/local_C); + } else { + size_t c_attn_b_bytes = qkv_out * sizeof(float); + ifs.seekg(c_attn_b_bytes, std::ios::cur); + } } + // transformer.h.{i}.attn.c_proj.weight (RowParallelLinear, but actually applies on "columns") for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp, in_pp); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp, + in_pp); + } else { + size_t c_proj_w_bytes = n_embd * n_embd * sizeof(float); + ifs.seekg(c_proj_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.attn.c_proj.bias (RowParallelLinear, no shard on bias) for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kAttnLayerName, 
CausalSelfAttention::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t c_proj_b_bytes = n_embd * sizeof(float); + ifs.seekg(c_proj_b_bytes, std::ios::cur); + } } + // transformer.h.{i}.ln_2.weight for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kLn2LayerName, nn::LayerNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kLn2LayerName, + nn::LayerNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_2_w_bytes = n_embd * sizeof(float); + ifs.seekg(ln_2_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.ln_2.bias for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kLn2LayerName, nn::LayerNorm::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, + std::to_string(idx - start_layer), Block::kLn2LayerName, + nn::LayerNorm::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_2_b_bytes = n_embd * sizeof(float); + ifs.seekg(ln_2_b_bytes, std::ios::cur); + } } + // transformer.h.{i}.mlp.c_fc.weight (ColumnParallelLinear, but actually applies on "rows") for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kMlpLayerName, MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp); + } else { + size_t c_fc_w_bytes = fc_out * n_embd * sizeof(float); + ifs.seekg(c_fc_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.mlp.c_fc.bias (ColumnParallelLinear) for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kMlpLayerName, MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamBiasName)]; - 
ReadVectorShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, fc_start, fc_pp); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; + ReadVectorShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, fc_start, fc_pp); + } else { + size_t c_fc_b_bytes = fc_out * sizeof(float); + ifs.seekg(c_fc_b_bytes, std::ios::cur); + } } + // transformer.h.{i}.mlp.c_proj.weight (RowParallelLinear, but actually applies on "columns") for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kMlpLayerName, MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp, in4_pp); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp, + in4_pp); + } else { + size_t c_proj_w_bytes = fc_out * n_embd * sizeof(float); + ifs.seekg(c_proj_w_bytes, std::ios::cur); + } } + // transformer.h.{i}.mlp.c_proj.bias (RowParallelLinear, no shard on bias) for (int idx = 0; idx < n_layer; ++idx) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx), Block::kMlpLayerName, MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (idx >= start_layer && idx < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t c_proj_b_bytes = n_embd * sizeof(float); + ifs.seekg(c_proj_b_bytes, std::ios::cur); + } } - // transformer.ln_f.weight - auto &transformer_ln_f_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kLnFLayerName, - nn::LayerNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_weight->DataPtr()), n_embd); - // transformer.ln_f.bias - auto &transformer_ln_f_bias = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kLnFLayerName, - nn::LayerNorm::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_bias->DataPtr()), n_embd); + if (is_last_stage) { + // transformer.ln_f.weight + auto &transformer_ln_f_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, + GPT2::kLnFLayerName, nn::LayerNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_weight->DataPtr()), n_embd); + // transformer.ln_f.bias + auto &transformer_ln_f_bias = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, + GPT2::kLnFLayerName, nn::LayerNorm::kParamBiasName)]; + ReadVectorAllFloat(ifs, 
static_cast(transformer_ln_f_bias->DataPtr()), n_embd); + } else { + size_t ln_f_w_bytes = n_embd * sizeof(float); + size_t ln_f_b_bytes = n_embd * sizeof(float); + ifs.seekg(ln_f_w_bytes + ln_f_b_bytes, std::ios::cur); + } return local_gpt2; } diff --git a/example/gpt2/net.h b/example/gpt2/net.h index 548e43a2..dba5cdfe 100644 --- a/example/gpt2/net.h +++ b/example/gpt2/net.h @@ -89,9 +89,6 @@ class GPT2 : public infini_train::nn::CloneableModule { explicit GPT2(const GPT2Config &config); - std::unordered_map GetConfig() const override; - std::vector> GetPipelineLayers() override; - std::vector> Forward(const std::vector> &x) override; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index c983268c..c3df5a97 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -63,10 +63,7 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); -DEFINE_uint32( - pipeline_parallel, 1, - "Pipeline Parallel world size, will always use device=cuda and use all cuda visible devices when set to true"); -DEFINE_uint32(num_microbatches, 4, "the num of microbatches in pipeline parallelism"); +DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specifies the number of PP stages."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -130,6 +127,8 @@ void Train(const nn::parallel::Rank &rank) { pp_pg = ProcessGroupFactory::Instance()->GetOrCreate( GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size)); pp_rank = pp_pg->GetGroupRank(rank.thread_rank()); + + nn::parallel::pp_rank = pp_rank; } } else { device = FLAGS_device == kDeviceCPU ?
DeviceManager::Instance()->GetDefaultDevice() @@ -137,7 +136,7 @@ void Train(const nn::parallel::Rank &rank) { } // calculate gradient accumulation from the desired total batch size and the current run configuration - const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * (ddp_world_size * pp_world_size); + const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size; CHECK_EQ(FLAGS_total_batch_size % tokens_per_fwdbwd, 0); const auto grad_accum_steps = FLAGS_total_batch_size / tokens_per_fwdbwd; if (rank.IsMainRank()) { @@ -177,16 +176,10 @@ void Train(const nn::parallel::Rank &rank) { model = std::make_shared(model, rank.thread_rank()); } - std::unique_ptr train_loader; - if (pp_world_size > 1) { - train_loader = std::make_unique( - std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), - FLAGS_batch_size * pp_world_size); - } else { - train_loader = std::make_unique( - std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size, - ddp_rank, ddp_world_size); - } + auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); + + DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), + FLAGS_batch_size * num_micro_batches, ddp_rank, ddp_world_size); std::optional val_loader = std::nullopt; if (!FLAGS_input_val_bin.empty()) { @@ -204,13 +197,9 @@ void Train(const nn::parallel::Rank &rank) { } // TODO(dcj): support more complex optimizer later - auto lr = FLAGS_learning_rate; - auto optimizer_factory = [lr](const std::vector> ¶ms) { - return std::make_shared(params, lr); - }; - auto optimizer = optimizer_factory(model->Parameters()); + auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate); - auto train_iter = train_loader->begin(); + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? 
std::static_pointer_cast(std::make_shared()) : std::static_pointer_cast(std::make_shared()); @@ -218,15 +207,10 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.thread_rank() << ": start training"; if (pp_world_size > 1) { - CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0) - << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size) - << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")"; - - auto shapes = std::vector>{{FLAGS_batch_size * pp_world_size / FLAGS_num_microbatches, - FLAGS_sequence_length, model->GetConfig()["n_embd"]}}; + auto shapes = std::vector>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}}; - model = std::make_shared(model, pp_world_size, FLAGS_num_microbatches, shapes, - pp_rank, optimizer_factory); + model = std::make_shared( + model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer)); } for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { @@ -251,81 +235,81 @@ void Train(const nn::parallel::Rank &rank) { break; } - // model->Train(); - if (pp_world_size == 1) { - optimizer->ZeroGrad(); - } - // if we are trying to overfit a single batch, we reset the loader here - if (FLAGS_overfit_single_batch) { - // train_loader.Reset(); - } - float lossf = 0.0f; #ifdef PROFILE_MODE Profiler::Instance().SetTag("Step_" + std::to_string(step)); #endif - for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { - // enable autocast for the current step - infini_train::AutocastGuard autocast_guard(device->Type(), dtype); - // (bs, seq_len), (bs, seq_len) + float lossf = 0.0f; + if (pp_world_size == 1) { + // model->Train(); + optimizer.ZeroGrad(); + + // if we are trying to overfit a single batch, we reset the loader here + if (FLAGS_overfit_single_batch) { + // train_loader.Reset(); + } + + for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { + // enable autocast for the current step + infini_train::AutocastGuard autocast_guard(device->Type(), dtype); + + // (bs, seq_len), (bs, seq_len) + auto [x, y] = *train_iter; + // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below + // TODO(dcj): support dataloader.reset() later + ++train_iter; + x = std::make_shared(x->To(device)); + y = std::make_shared(y->To(device)); + + LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward"; + // (bs, seq_len, vocab_size) + auto logits = model->Forward({x, y})[0]; + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward"; + auto loss = loss_fn->Forward({logits, y})[0]; + loss = loss / grad_accum_steps; + + // disable autocast for the current step (backward is not under autocast) + autocast_guard.Disable(); + + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward"; + if (ddp_world_size > 1) { + function::AllReduce(loss, function::ReduceOpType::kAvg); + } + auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + lossf += static_cast(loss_cpu.DataPtr())[0]; + LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward"; + loss->Backward(); + LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; + } + + optimizer.Step(); + } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below // TODO(dcj): support dataloader.reset() later ++train_iter; x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); - 
if (pp_world_size > 1) { - lossf = model->TrainStep({x}, {y}, loss_fn); - - auto loss_tensor = std::make_shared(std::vector{}, DataType::kFLOAT32); - static_cast(loss_tensor->DataPtr())[0] = lossf; - auto loss_device_ptr = std::make_shared(loss_tensor->To(device)); - function::AllReduce(loss_device_ptr, function::ReduceOpType::kMax); - auto loss_copy = loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice()); - lossf = static_cast(loss_copy.DataPtr())[0]; - continue; - } - LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward"; - // (bs, seq_len, vocab_size) - auto logits = model->Forward({x, y})[0]; - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward"; - auto loss = loss_fn->Forward({logits, y})[0]; - loss = loss / grad_accum_steps; - - // disable autocast for the current step (backward is not under autocast) - autocast_guard.Disable(); - - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward"; - if (ddp_world_size > 1) { - function::AllReduce(loss, function::ReduceOpType::kAvg); - } - auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); - lossf += static_cast(loss_cpu.DataPtr())[0]; - LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward"; - loss->Backward(); - LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward"; - } - - if (pp_world_size == 1) { - optimizer->Step(); + lossf = model->TrainStep({x}, {y}, loss_fn); } const auto iter_end = std::chrono::high_resolution_clock::now(); const double duration_us = std::chrono::duration(iter_end - iter_start).count(); const double tps = FLAGS_total_batch_size / (duration_us / 1e6); - if (rank.IsMainRank()) { + if (rank.thread_rank() == pp_world_size - 1) { LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, " "DP={}, TP={}, SP={}, PP={})", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { - if (!tokenizer) { - continue; + // FIXME(jym): to support PP + if (tokenizer) { + CHECK_EQ(pp_world_size, 1); + tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device); } - tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device); } } } diff --git a/example/llama3/net.cc b/example/llama3/net.cc index 856ccbf6..3e7becfa 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -22,6 +23,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" @@ -99,8 +101,8 @@ ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false, const infini_train::Device *device - = DeviceManager::Instance()->GetDefaultDevice(), - DataType dtype = DataType::kFLOAT32) { + = DeviceManager::Instance()->GetDefaultDevice()) { + DataType dtype = DataType::kFLOAT32; CHECK_GE(dim, 2) << "dim must be >= 2 for slicing"; auto arange = nn::init::Arange(0, dim, dtype, device)->Slice(0, 0, dim, 2); auto freqs = 1.0f / 
nn::function::Pow(theta, arange / float(dim)); @@ -324,153 +326,92 @@ std::vector> Block::Forward(const std::vector> transformer; + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + auto [is_first_stage, is_last_stage, start_layer, end_layer] + = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size); + + std::unordered_map> transformer; + if (is_first_stage) { transformer[kWTELayerName] = std::make_shared( config.vocab_size, config.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); - { - std::vector> h; - for (int64_t i = 0; i < config.n_layer; ++i) { h.push_back(std::make_shared(config)); } - transformer[kHLayerName] = std::make_shared(std::move(h)); - } - transformer[kLnFLayerName] = std::make_shared(config.n_embd, config.norm_eps); - modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); } - // NOTE(zbl): weight-tying is possible but torch script did not do so - modules_[kLMHeadLayerName] = std::make_shared( - /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, - /*bias=*/false, - // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); -} - -class LLaMALayer : public nn::Module { - std::shared_ptr h_layer_; - std::shared_ptr freqs_cis_ = nullptr; // freqs_cis (shape: [max_seq_len, head_dim]) - int64_t start_pos_ = 0; - DataType dtype_; - LLaMA3Config config_; - -public: - LLaMALayer(std::shared_ptr h_layer, DataType dtype, LLaMA3Config config) - : h_layer_(h_layer), dtype_(dtype), config_(config) {} - - void SetFreqsCis(std::shared_ptr freqs) { freqs_cis_ = freqs; } - void SetStartPos(int64_t pos) { start_pos_ = pos; } - - std::vector> Parameters() const override { return h_layer_->Parameters(); } - std::vector> Forward(const std::vector> &inputs) override { - auto &x = inputs[0]; // (bs, seq_len, n_embd) - const int seq_len = x->Dims()[1]; - const auto device = x->GetDevice(); + std::vector> h_local; + for (int64_t i = start_layer; i < end_layer; ++i) { h_local.push_back(std::make_shared(config)); } + transformer[kHLayerName] = std::make_shared(std::move(h_local)); - if (freqs_cis_ == nullptr) { - freqs_cis_ = PrecomputeFreqsCis(config_.n_embd / config_.n_head, config_.block_size * 2, config_.rope_theta, - config_.use_scaled_rope, device, dtype_); - } - - // slice freqs_cis: [start_pos:start_pos+seq_len] - auto freqs_view = freqs_cis_->Slice(0, start_pos_, start_pos_ + seq_len, 1); - - std::shared_ptr start_pos_ptr = nullptr; - - // causal mask: (1, 1, seq_len, seq_len) - std::shared_ptr ones = std::make_shared(nn::function::Ones({seq_len, seq_len})->To(device)); - auto mask = nn::function::Triu(ones, 1)->View({1, 1, seq_len, seq_len}); - if (dtype_ == DataType::kBFLOAT16) { - mask = std::make_shared(mask->To(DataType::kBFLOAT16)); - } - - // DecoderLayer: {x, freqs, start_pos, mask} - // std::vector> args = {x, freqs_view, nullptr, mask}; - auto output = h_layer_->Forward({x, freqs_view, start_pos_ptr, mask}); - return output; - } -}; - -std::vector> LLaMA3::GetPipelineLayers() { - std::vector> layers; - - auto transformer = modules_[kTransformerLayerName]; - layers.push_back(transformer->mutable_module(kWTELayerName)); - - auto seq = std::dynamic_pointer_cast(transformer->mutable_module(kHLayerName)); - auto dtype = modules_[kLMHeadLayerName]->parameter(nn::Linear::kParamWeightName)->Dtype(); - int idx = 0; - for (auto h : *seq) { 
- layers.push_back(std::make_shared(h, dtype, config_)); - ++idx; + if (is_last_stage) { + transformer[kLnFLayerName] = std::make_shared(config.n_embd, config.norm_eps); + // NOTE(zbl): weight-tying is possible but torch script did not do so + modules_[kLMHeadLayerName] = std::make_shared( + /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, + /*bias=*/false, + // NOTE(zbl): each rank would get sharded [B, T, V_local] as logits + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); } - - layers.push_back(transformer->mutable_module(kLnFLayerName)); - - layers.push_back(modules_[kLMHeadLayerName]); - - return layers; -} - -std::unordered_map LLaMA3::GetConfig() const { - return {{"n_embd", config_.n_embd}, {"n_head", config_.n_head}}; + modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); } std::vector> LLaMA3::Forward(const std::vector> &x) { + int pp_rank = nn::parallel::pp_rank; + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + bool is_first_stage = (pp_rank == 0); + bool is_last_stage = (pp_rank == pp_size - 1); + // (bs, seq_len) - auto &idx = x[0]; - const auto device = idx->GetDevice(); - const auto t = idx->Dims()[1]; // seq_len + auto x1 = x[0]; + const auto device = x1->GetDevice(); + const auto t = x1->Dims()[1]; // seq_len CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " << config_.block_size; // Init freqs_cis on device only once // TODO(zbl): consider moving this to model construction if (buffers_[kFreqsCisName] == nullptr) { - buffers_[kFreqsCisName] = PrecomputeFreqsCis( - config_.n_embd / config_.n_head, config_.block_size * 2, config_.rope_theta, config_.use_scaled_rope, - device, - modules_[kLMHeadLayerName]->parameter(nn::parallel::ColumnParallelLinear::kParamWeightName)->Dtype()); + buffers_[kFreqsCisName] = PrecomputeFreqsCis(config_.n_embd / config_.n_head, config_.block_size * 2, + config_.rope_theta, config_.use_scaled_rope, device); } // forward the LLaMA3 model itself auto &transformer = modules_[kTransformerLayerName]; - // (bs, seq_len) -> Embedding(vocab_size, n_embd) -> (bs, seq_len, n_embd) - auto x1 = transformer->mutable_module(kWTELayerName)->Forward({idx})[0]; + + if (is_first_stage) { + // (bs, seq_len) -> Embedding(vocab_size, n_embd) -> (bs, seq_len, n_embd) + x1 = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; + } // TODO(zbl): dynamic start_pos int64_t start_pos = 0; auto freqs_view = buffers_[kFreqsCisName]->Slice(0, start_pos, start_pos + t, 1); // TODO(lzm): add dtype support for nn::function::Ones later - std::shared_ptr ones = std::make_shared(nn::function::Ones({t, t})->To(idx->GetDevice())); + std::shared_ptr ones = std::make_shared(nn::function::Ones({t, t})->To(x1->GetDevice())); std::shared_ptr mask = nn::function::Triu(ones, 1)->View({1, 1, t, t}); - // TODO(zbl): nn::function::Ones builds tensor in FP32 by default - if (modules_[kLMHeadLayerName]->parameter(nn::parallel::ColumnParallelLinear::kParamWeightName)->Dtype() - == DataType::kBFLOAT16) { - mask = std::make_shared(mask->To(DataType::kBFLOAT16)); - } + std::shared_ptr start_pos_ptr = nullptr; auto h_modules = transformer->mutable_module(kHLayerName); - if (h_modules->type() == nn::ModuleList::kType) { - auto h_layers = std::dynamic_pointer_cast(h_modules); - // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (auto &h : *h_layers) { 
x1 = h->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; } - } else { - LOG(FATAL) << "Failed to get ModuleList from transformer"; - } + CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; + auto h_layers = std::dynamic_pointer_cast(h_modules); + // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) + for (auto &h : *h_layers) { x1 = h->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; } - // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) - auto x2 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); + if (is_last_stage) { + // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) + auto x2 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); - // TODO(zbl): add inference-time mini-optimization - // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) - auto logits = modules_[kLMHeadLayerName]->Forward(x2); + // TODO(zbl): add inference-time mini-optimization + // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) + auto logits = modules_[kLMHeadLayerName]->Forward(x2); - // (bs, seq_len, vocab_size) - return logits; + // (bs, seq_len, vocab_size) + return logits; + } + + return {x1}; } std::shared_ptr LLaMA3::FromPretrained(ModelType model_type) { @@ -524,6 +465,9 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { .use_scaled_rope = static_cast(use_scaled_rope), .norm_eps = norm_eps, .max_gen_batch_size = max_gen_bs}); + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + auto [is_first_stage, is_last_stage, start_layer, end_layer] + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size); const int tp_size = nn::parallel::global::GetTensorParallelSize(); const int tp_rank = nn::parallel::tp_rank; @@ -592,114 +536,168 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // ========== Read Sharded Params ========== // transformer.wte.weight : (vocab_size, n_embd) -> local tp_rank: rows of [v_start : v_start+vpp) - { + if (is_first_stage) { auto &wte = state_dict[std::format("{}.{}.{}", kTransformerLayerName, kWTELayerName, nn::parallel::VocabParallelEmbedding::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(wte->DataPtr()), /*rows=*/vocab_size, /*cols=*/n_embd, /*row_start=*/v_start, /*row_cnt=*/vpp); + } else { + size_t wte_bytes = static_cast(vocab_size) * n_embd * sizeof(float); + ifs.seekg(wte_bytes, std::ios::cur); } // transformer.h.{i}.ln_1.weight : Full version RMSNorm for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i), - Block::kLn1LayerName, RMSNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (i >= start_layer && i < end_layer); + + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + std::to_string(i - start_layer), Block::kLn1LayerName, + RMSNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_1_bytes = n_embd * sizeof(float); + ifs.seekg(ln_1_bytes, std::ios::cur); + } } // transformer.h.{i}.attn.c_attn.weight : ColumnParallelLinear, but actually applies on "rows" // W-qkv should be [Q(=n_embd) | K(=n_kv_head*head_dim) | V(=n_kv_head*head_dim)] × n_embd for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", 
kTransformerLayerName, kHLayerName, std::to_string(i), Block::kAttnLayerName, - CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; - float *dst = static_cast(tensor->DataPtr()); - const std::streampos base_pos = ifs.tellg(); - - // Q block -> [0 : q_local_rows) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (0 * attn_cols), - /*rows=*/attn_rows_all, /*cols=*/attn_cols, - /*row_start=*/tp_rank * q_local_rows, /*row_cnt=*/q_local_rows); - - // K block -> [q_local_rows : q_local_rows + kv_local_rows) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (q_local_rows * attn_cols), - /*rows=*/attn_rows_all, /*cols=*/attn_cols, - /*row_start=*/q_out_rows + tp_rank * kv_local_rows, /*row_cnt=*/kv_local_rows); - - // V block -> [q_local_rows + kv_local_rows : q_local_rows + 2*kv_local_rows) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + ((q_local_rows + kv_local_rows) * attn_cols), - /*rows=*/attn_rows_all, /*cols=*/attn_cols, - /*row_start=*/q_out_rows + kv_out_rows + tp_rank * kv_local_rows, - /*row_cnt=*/kv_local_rows); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + std::to_string(i - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCAttnLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + + float *dst = static_cast(tensor->DataPtr()); + const std::streampos base_pos = ifs.tellg(); + + // Q block -> [0 : q_local_rows) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (0 * attn_cols), + /*rows=*/attn_rows_all, /*cols=*/attn_cols, + /*row_start=*/tp_rank * q_local_rows, /*row_cnt=*/q_local_rows); + + // K block -> [q_local_rows : q_local_rows + kv_local_rows) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (q_local_rows * attn_cols), + /*rows=*/attn_rows_all, /*cols=*/attn_cols, + /*row_start=*/q_out_rows + tp_rank * kv_local_rows, /*row_cnt=*/kv_local_rows); + + // V block -> [q_local_rows + kv_local_rows : q_local_rows + 2*kv_local_rows) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + ((q_local_rows + kv_local_rows) * attn_cols), + /*rows=*/attn_rows_all, /*cols=*/attn_cols, + /*row_start=*/q_out_rows + kv_out_rows + tp_rank * kv_local_rows, + /*row_cnt=*/kv_local_rows); + } else { + size_t qkv_bytes = static_cast(attn_rows_all) * attn_cols * sizeof(float); + ifs.seekg(qkv_bytes, std::ios::cur); + } } // transformer.h.{i}.attn.c_proj.weight : RowParallelLinear, but actually applies on "columns" for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i), Block::kAttnLayerName, - CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/n_embd, /*cols=*/n_embd, - /*col_start=*/tp_rank * in_pp, /*col_cnt=*/in_pp); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + std::to_string(i - start_layer), Block::kAttnLayerName, + CausalSelfAttention::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/n_embd, /*cols=*/n_embd, + /*col_start=*/tp_rank * in_pp, /*col_cnt=*/in_pp); 
+ } else { + size_t c_proj_bytes = static_cast(n_embd) * n_embd * sizeof(float); + ifs.seekg(c_proj_bytes, std::ios::cur); + } } // transformer.h.{i}.ln_2.weight : Full version RMSNorm for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i), - Block::kLn2LayerName, RMSNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, + std::to_string(i - start_layer), Block::kLn2LayerName, + RMSNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + } else { + size_t ln_2_bytes = static_cast(n_embd) * sizeof(float); + ifs.seekg(ln_2_bytes, std::ios::cur); + } } // transformer.h.{i}.mlp.c_fc.weight : ColumnParallelLinear, but actually applies on "rows" for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i), Block::kMlpLayerName, MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/fc_out, /*cols=*/n_embd, - /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/fc_out, /*cols=*/n_embd, + /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + } else { + size_t fc_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); + ifs.seekg(fc_bytes, std::ios::cur); + } } // transformer.h.{i}.mlp.c_fc2.weight : ColumnParallelLinear, but actually applies on "rows" for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i), Block::kMlpLayerName, MLP::kCFc2LayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/fc_out, /*cols=*/n_embd, - /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + Block::kMlpLayerName, MLP::kCFc2LayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/fc_out, /*cols=*/n_embd, + /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + } else { + size_t fc2_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); + ifs.seekg(fc2_bytes, std::ios::cur); + } } // transformer.h.{i}.mlp.c_proj.weight : RowParallelLinear, but actually applies on "columns" for (int i = 0; i < static_cast(n_layer); ++i) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i), Block::kMlpLayerName, MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/n_embd, /*cols=*/fc_out, - 
/*col_start=*/tp_rank * in_fc_pp, /*col_cnt=*/in_fc_pp); + bool owned = (i >= start_layer && i < end_layer); + if (owned) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/n_embd, /*cols=*/fc_out, + /*col_start=*/tp_rank * in_fc_pp, /*col_cnt=*/in_fc_pp); + } else { + size_t c_proj_bytes = static_cast(n_embd) * ffn_hidden * sizeof(float); + ifs.seekg(c_proj_bytes, std::ios::cur); + } } // transformer.ln_f.weight : Full version RMSNorm - { - auto &ln_f - = state_dict[std::format("{}.{}.{}", kTransformerLayerName, kLnFLayerName, RMSNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(ln_f->DataPtr()), n_embd); - } - // lm_head.weight : (vocab_size, n_embd) -> ColumnParallelLinear, but actually applies on "rows" { - auto &lm_head - = state_dict[std::format("{}.{}", kLMHeadLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(lm_head->DataPtr()), - /*rows=*/vocab_size, /*cols=*/n_embd, - /*row_start=*/v_start, /*row_cnt=*/vpp); + if (is_last_stage) { + auto &ln_f + = state_dict[std::format("{}.{}.{}", kTransformerLayerName, kLnFLayerName, RMSNorm::kParamWeightName)]; + auto &lm_head = state_dict[std::format("{}.{}", kLMHeadLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(ln_f->DataPtr()), n_embd); + ReadMatrixRowShardFloat(ifs, static_cast(lm_head->DataPtr()), + /*rows=*/vocab_size, /*cols=*/n_embd, + /*row_start=*/v_start, /*row_cnt=*/vpp); + } else { + size_t ln_f_bytes = static_cast(n_embd) * sizeof(float); + size_t lm_head_bytes = static_cast(vocab_size) * n_embd * sizeof(float); + ifs.seekg(ln_f_bytes + lm_head_bytes, std::ios::cur); + } } return llama3; diff --git a/example/llama3/net.h b/example/llama3/net.h index d89907c4..ec0199aa 100644 --- a/example/llama3/net.h +++ b/example/llama3/net.h @@ -129,10 +129,6 @@ class LLaMA3 : public infini_train::nn::CloneableModule { explicit LLaMA3(const LLaMA3Config &config); - std::vector> GetPipelineLayers() override; - - std::unordered_map GetConfig() const; - std::vector> Forward(const std::vector> &x) override; diff --git a/infini_train/include/nn/modules/container.h b/infini_train/include/nn/modules/container.h index 493c1ba8..d1356712 100644 --- a/infini_train/include/nn/modules/container.h +++ b/infini_train/include/nn/modules/container.h @@ -17,16 +17,6 @@ class Sequential : public CloneableModule { explicit Sequential(std::vector> &&layers); std::vector> Forward(const std::vector> &input_tensors) override; - - size_t size() const { return modules_.size(); } - - std::shared_ptr operator[](size_t idx) const { - auto it = modules_.find(std::to_string(idx)); - if (it != modules_.end()) { - return it->second; - } - return nullptr; - } }; class ModuleDict : public CloneableModule { diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 85fc35f8..c090dd35 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -47,10 +47,6 @@ class Module : public std::enable_shared_from_this { virtual std::vector> Forward(const std::vector> &input_tensors); - virtual std::vector> GetPipelineLayers() { return {}; } - - virtual std::unordered_map GetConfig() const { return {}; } - virtual 
float TrainStep(const std::vector> &input_tensors, const std::vector> &targets, const std::shared_ptr &loss_fn) { diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 4b44e408..1a6e22fd 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -26,8 +26,8 @@ class GlobalEnv { public: static GlobalEnv &Instance(); - void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false, - int pipeline_parallel_size = 1); + void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, + int pipeline_parallel_size); int world_size() const; @@ -76,8 +76,8 @@ class GlobalEnv { Layout layout_; }; -inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled = false, - int pipeline_parallel_size = 1) { +inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, + int pipeline_parallel_size) { GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, pipeline_parallel_size); } diff --git a/infini_train/include/nn/parallel/pp/pipeline_parallel.h b/infini_train/include/nn/parallel/pp/pipeline_parallel.h index 166c6849..842bff45 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_parallel.h +++ b/infini_train/include/nn/parallel/pp/pipeline_parallel.h @@ -4,47 +4,41 @@ #include #include -#include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" -#include "infini_train/include/tensor.h" +#include "infini_train/include/optimizer.h" -#include "infini_train/include/nn/parallel/pp/pipeline_schedule.h" -#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" +namespace infini_train { +class Tensor; +class Device; +} // namespace infini_train namespace infini_train::nn::parallel { +class PipelineStage; +class PipelineSchedule; -using OptimizerFactory = std::function(const std::vector> ¶ms)>; +extern thread_local int pp_rank; class PipelineParallel : public Module { public: - PipelineParallel(const std::shared_ptr &model, int num_stages, int num_microbatches, - const std::vector> &recv_shape, int rank, OptimizerFactory optimizer_factory); + PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, + const std::vector> &recv_shape, int rank, + const std::shared_ptr &optimizer); float TrainStep(const std::vector> &input, - const std::vector> &target, const std::shared_ptr &loss_fn); + const std::vector> &target, const std::shared_ptr &loss_fn); -private: - int num_stages_; - int rank_; - std::vector devices_; - std::shared_ptr original_model_; - std::shared_ptr pipeline_stage_; - std::shared_ptr schedule_; - - std::vector>> - SplitLayersIntoStages(std::vector> layers); + static std::tuple GetStageInfo(int total_layers, int pp_size); - void SplitModel(const std::vector> &recv_shape, OptimizerFactory optimizer_factory); - - std::vector> - CreateOptimizers(const std::vector>> &stage_layers, - OptimizerFactory optimizer_factory); +private: + int num_stages_ = -1; + int rank_ = -1; + std::shared_ptr pipeline_stage_ = nullptr; + std::shared_ptr schedule_ = nullptr; - void BuildPipelineStage(const std::vector>> &stage_layers, - const std::vector> &optimizers, + void BuildPipelineStage(const std::shared_ptr &model, const std::shared_ptr &optimizer, const std::vector> &recv_shape); - void SetupSchedule(int num_microbatches); + void SetupSchedule(int 
num_micro_batches); }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_schedule.h b/infini_train/include/nn/parallel/pp/pipeline_schedule.h index 0e51b51a..e5f13f82 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_schedule.h +++ b/infini_train/include/nn/parallel/pp/pipeline_schedule.h @@ -1,55 +1,60 @@ #pragma once + #include #include -#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" +#include "infini_train/include/nn/modules/module.h" + +namespace infini_train { +class Tensor; +} // namespace infini_train namespace infini_train::nn::parallel { +class PipelineStage; + class PipelineSchedule { public: - PipelineSchedule(std::shared_ptr stage, int num_stages, int num_microbatches, int stage_index) - : stage_(std::move(stage)), num_microbatches_(num_microbatches), stage_index_(stage_index) {} + PipelineSchedule(std::shared_ptr stage, int num_stages, int num_micro_batches, int stage_index) + : stage_(std::move(stage)), num_micro_batches_(num_micro_batches), stage_index_(stage_index) {} virtual ~PipelineSchedule() = default; - float Step(std::shared_ptr input, std::shared_ptr target, const std::shared_ptr &loss_fn); + float Step(std::shared_ptr input, std::shared_ptr target, + const std::shared_ptr &loss_fn); - virtual float StepMicrobatches(const std::vector> &arg_mbs, + virtual float StepMicroBatches(const std::vector> &arg_mbs, const std::vector> &target_mbs, - const std::shared_ptr &loss_fn) + const std::shared_ptr &loss_fn) = 0; std::vector> ReceiveFromPrev(); std::vector> SendToNext(const std::vector> &tensors); protected: - int num_microbatches_; - int stage_index_; - std::shared_ptr stage_; - -private: - std::vector> SplitTensor(std::shared_ptr full_inputs); + int num_micro_batches_ = -1; + int stage_index_ = -1; + std::shared_ptr stage_ = nullptr; }; class ScheduleGPipe : public PipelineSchedule { public: - ScheduleGPipe(std::shared_ptr stage, int num_stages, int num_microbatches, int stage_index) - : PipelineSchedule(std::move(stage), num_stages, num_microbatches, stage_index){}; + ScheduleGPipe(std::shared_ptr stage, int num_stages, int num_micro_batches, int stage_index) + : PipelineSchedule(std::move(stage), num_stages, num_micro_batches, stage_index){}; - float StepMicrobatches(const std::vector> &arg_mbs, + float StepMicroBatches(const std::vector> &arg_mbs, const std::vector> &target_mbs, - const std::shared_ptr &loss_fn) override; + const std::shared_ptr &loss_fn) override; }; class Schedule1F1B : public PipelineSchedule { public: - Schedule1F1B(std::shared_ptr stage, int num_stages, int num_microbatches, int stage_index) - : PipelineSchedule(std::move(stage), num_stages, num_microbatches, stage_index){}; + Schedule1F1B(std::shared_ptr stage, int num_stages, int num_micro_batches, int stage_index) + : PipelineSchedule(std::move(stage), num_stages, num_micro_batches, stage_index){}; - float StepMicrobatches(const std::vector> &arg_mbs, + float StepMicroBatches(const std::vector> &arg_mbs, const std::vector> &target_mbs, - const std::shared_ptr &loss_fn) override; + const std::shared_ptr &loss_fn) override; }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h index 0f0a1515..7d6b4086 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_stage.h +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -3,17 +3,20 @@ #include #include -#include "infini_train/include/device.h" #include 
"infini_train/include/nn/modules/module.h" #include "infini_train/include/optimizer.h" -#include "infini_train/include/tensor.h" + +namespace infini_train { +class Tensor; +class Device; +} // namespace infini_train namespace infini_train::nn::parallel { class PipelineStage { public: - PipelineStage(const std::vector> &layers, int stage_index, int num_stages, - const std::vector> &recvShape, std::shared_ptr optim); + PipelineStage(const std::shared_ptr &model, int stage_index, int num_stages, + const std::vector> &recv_shape, std::shared_ptr optimizer); std::vector> ForwardOneChunk(const std::vector> &inputs); @@ -24,18 +27,18 @@ class PipelineStage { int next_rank() const { return next_rank_; } int num_stages() const { return num_stages_; } const Device *device() const { return device_; } - std::vector> recv_shape() const { return recv_shape_; } - std::shared_ptr optimizer() { return optim_; } + const std::vector> &recv_shape() const { return recv_shape_; } + std::shared_ptr optimizer() { return optimizer_; } private: - int stage_index_; - int num_stages_; - int prev_rank_; - int next_rank_; + int stage_index_ = -1; + int num_stages_ = -1; + int prev_rank_ = -1; + int next_rank_ = -1; const Device *device_ = nullptr; + std::shared_ptr model_ = nullptr; + std::shared_ptr optimizer_ = nullptr; std::vector> recv_shape_; - std::vector> layers_; - std::shared_ptr optim_; }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/send_recv.h b/infini_train/include/nn/parallel/pp/send_recv.h index be5fe47f..f76f4c72 100644 --- a/infini_train/include/nn/parallel/pp/send_recv.h +++ b/infini_train/include/nn/parallel/pp/send_recv.h @@ -3,13 +3,16 @@ #include #include -#include "infini_train/include/tensor.h" +namespace infini_train { +class Tensor; +class Device; +} // namespace infini_train namespace infini_train::nn::parallel { std::vector> ISend(const std::vector> &input_tensors, const Device *target_device, int cur_rank, int peer_rank, - std::vector> shape); + const std::vector> &shape); std::vector> IRecv(const std::vector> &outputs, const Device *src_device, int cur_rank, int peer_rank); diff --git a/infini_train/include/nn/parallel/utils.h b/infini_train/include/nn/parallel/utils.h index aa7a473a..3eb3960d 100644 --- a/infini_train/include/nn/parallel/utils.h +++ b/infini_train/include/nn/parallel/utils.h @@ -13,4 +13,6 @@ std::string GetPipelineParallelProcessGroupName(int thread_rank); std::vector GetDataParallelGroupRanks(int rank); std::vector GetTensorParallelGroupRanks(int rank); + +std::vector GetPipelineParallelGroupRanks(int pp_world_size); } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc index 20e8ef36..6fffcfba 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc @@ -1,7 +1,7 @@ // pipeline_parallel.cc #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" -#include +#include #include #include "infini_train/include/nn/modules/container.h" @@ -12,73 +12,16 @@ namespace infini_train::nn::parallel { -std::vector>> -PipelineParallel::SplitLayersIntoStages(std::vector> layers) { - const int total_layers = layers.size(); - CHECK_GT(total_layers, 0) << "Model has no layers to split!"; - CHECK_GE(num_stages_, 1) << "num_stages must be >= 1"; - CHECK_LE(num_stages_, total_layers) << "num_stages (" << num_stages_ << ") cannot be greater than total layers (" - << 
total_layers << ")"; +thread_local int pp_rank = 0; - std::vector>> stages(num_stages_); - int base_layers_per_stage = total_layers / num_stages_; - int remainder = total_layers % num_stages_; - int layer_idx = 0; - - for (int s = 0; s < num_stages_; ++s) { - int layers_in_this_stage = base_layers_per_stage + (s < remainder ? 1 : 0); - for (int i = 0; i < layers_in_this_stage; ++i) { - auto layer = layers[layer_idx]; - stages[s].emplace_back(layer); - layer_idx++; - } - } - - return stages; -} - -std::vector> -PipelineParallel::CreateOptimizers(const std::vector>> &stage_layers, - OptimizerFactory optimizer_factory) { - std::vector> optims; - optims.reserve(stage_layers.size()); - - for (int s = 0; s < num_stages_; ++s) { - std::vector> params; - for (const auto &layer : stage_layers[s]) { - layer->To(devices_[s]); - auto layer_params = layer->Parameters(); - params.insert(params.end(), layer_params.begin(), layer_params.end()); - } - - auto optim = optimizer_factory(params); - CHECK(optim != nullptr) << "Optimizer factory returned null optimizer for stage " << s; - optims.push_back(std::move(optim)); - } - return optims; -} - -void PipelineParallel::BuildPipelineStage(const std::vector>> &stage_layers, - const std::vector> &optimizers, +void PipelineParallel::BuildPipelineStage(const std::shared_ptr &module, + const std::shared_ptr &optimizer, const std::vector> &recv_shape) { - pipeline_stage_ - = std::make_shared(stage_layers[rank_], rank_, num_stages_, recv_shape, optimizers[rank_]); + pipeline_stage_ = std::make_shared(module, rank_, num_stages_, recv_shape, optimizer); } -void PipelineParallel::SplitModel(const std::vector> &recv_shape, - OptimizerFactory optimizer_factory) { - auto layers = original_model_->GetPipelineLayers(); - CHECK(!layers.empty()) << "SplitModel: GetPipelineLayers returned empty vector"; - - auto stage_layer = SplitLayersIntoStages(layers); - - auto optimizer = CreateOptimizers(stage_layer, optimizer_factory); - - BuildPipelineStage(stage_layer, optimizer, recv_shape); -} - -void PipelineParallel::SetupSchedule(int num_microbatches) { - schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_microbatches, rank_); +void PipelineParallel::SetupSchedule(int num_micro_batches) { + schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_micro_batches, rank_); } float PipelineParallel::TrainStep(const std::vector> &input, @@ -93,16 +36,32 @@ float PipelineParallel::TrainStep(const std::vector> &in return schedule_->Step(stage_input, stage_target, loss_fn); } -PipelineParallel::PipelineParallel(const std::shared_ptr &model, int num_stages, int num_microbatches, - const std::vector> &recv_shape, int rank, - OptimizerFactory optimizer_factory) - : original_model_(model), devices_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA)), - num_stages_(num_stages), rank_(rank) { - CHECK(!devices_.empty()) << "Devices list is empty"; +std::tuple PipelineParallel::GetStageInfo(int total_layers, int pp_size) { + int rank = pp_rank; + bool is_first_stage = (pp_rank == 0); + bool is_last_stage = (pp_rank == pp_size - 1); + + int layers_per_stage = total_layers / pp_size; + int remainder = total_layers % pp_size; + int start_layer, end_layer; + if (pp_rank < remainder) { + start_layer = pp_rank * (layers_per_stage + 1); + end_layer = start_layer + layers_per_stage + 1; + } else { + start_layer = pp_rank * layers_per_stage + remainder; + end_layer = start_layer + layers_per_stage; + } - SplitModel(recv_shape, optimizer_factory); + return 
{is_first_stage, is_last_stage, start_layer, end_layer}; +} + +PipelineParallel::PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, + const std::vector> &recv_shape, int rank, + const std::shared_ptr &optimizer) + : num_stages_(num_stages), rank_(rank) { + BuildPipelineStage(module, optimizer, recv_shape); - SetupSchedule(num_microbatches); + SetupSchedule(num_micro_batches); } } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc index 34bbae19..884e24a7 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -1,119 +1,82 @@ // pipeline_schedule.cc #include "infini_train/include/nn/parallel/pp/pipeline_schedule.h" -#include "glog/logging.h" #include #include -#include -#include #include +#include "glog/logging.h" + +#include "infini_train/include/autograd/grad_mode.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/module.h" -#include "infini_train/include/nn/parallel/parallel_functional.h" +#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" #include "infini_train/include/nn/parallel/pp/send_recv.h" #include "infini_train/include/tensor.h" namespace infini_train::nn::parallel { -std::vector> PipelineSchedule::SplitTensor(std::shared_ptr full_inputs) { - const auto n = num_microbatches_; - if (n == 1) { - return {full_inputs}; - } - - const auto &first_dims = full_inputs->Dims(); - if (first_dims.empty()) { - LOG(FATAL) << "SplitTensor: tensor has no dimensions."; - } - int64_t batch_size = first_dims[0]; - int microbatch_size = batch_size / n; - int remainder = batch_size % n; - - std::vector> micro_batches; - - int start_idx = 0; - int end_idx = 0; - for (int mb = 0; mb < n; ++mb) { - int current_size = microbatch_size + (mb == n - 1 ? 
remainder : 0); - end_idx = start_idx + current_size; - - if (start_idx < 0 || end_idx > batch_size || start_idx >= end_idx) { - LOG(FATAL) << "Invalid slice range: [%d, %d), batch_size=%ld" << start_idx << end_idx << batch_size; - } - - if (full_inputs->Dims()[0] != batch_size) { - LOG(FATAL) << "SplitTensor: tensor size mismatch on dim 0."; - } - - auto sliced = full_inputs->Slice(0, start_idx, end_idx); - - micro_batches.push_back(sliced); - - start_idx = end_idx; - } - - return micro_batches; -} - float PipelineSchedule::Step(std::shared_ptr input, std::shared_ptr target, const std::shared_ptr &loss_fn) { - std::vector> micro_batches(num_microbatches_); - std::vector> target_mbs(num_microbatches_); + std::vector> micro_batches(num_micro_batches_); + std::vector> target_mbs(num_micro_batches_); if (stage_->IsFirstStage()) { - micro_batches = SplitTensor(input); + { + autograd::NoGradGuard no_grad; + micro_batches = input->Split(input->Dims()[0] / num_micro_batches_); + } } + if (stage_->IsLastStage()) { - target_mbs = SplitTensor(target); + { + autograd::NoGradGuard no_grad; + target_mbs = target->Split(target->Dims()[0] / num_micro_batches_); + } } - const auto &optim = stage_->optimizer(); + const auto &optimizer = stage_->optimizer(); - optim->ZeroGrad(); + optimizer->ZeroGrad(); - float lossf = StepMicrobatches(micro_batches, target_mbs, loss_fn); + float lossf = StepMicroBatches(micro_batches, target_mbs, loss_fn); - optim->Step(); + optimizer->Step(); return lossf; } std::vector> PipelineSchedule::ReceiveFromPrev() { std::vector> recv_tensors; - if (!stage_->IsFirstStage()) { - auto shapes = stage_->recv_shape(); - for (size_t i = 0; i < shapes.size(); ++i) { - auto tensor = std::make_shared(shapes[i], DataType::kFLOAT32, stage_->device()); - tensor->set_requires_grad(true); - tensor->set_is_leaf(false); - recv_tensors.push_back(tensor); - } - return IRecv(recv_tensors, stage_->device(), stage_->stage_index(), stage_->prev_rank()); + auto &shapes = stage_->recv_shape(); + for (size_t i = 0; i < shapes.size(); ++i) { + // FIXME(jym): The data type between stages is not float32, which will cause a crash + auto tensor = std::make_shared(shapes[i], DataType::kFLOAT32, stage_->device()); + tensor->set_requires_grad(true); + tensor->set_is_leaf(false); + recv_tensors.push_back(tensor); } - return recv_tensors; + + return IRecv(recv_tensors, stage_->device(), stage_->stage_index(), stage_->prev_rank()); } std::vector> PipelineSchedule::SendToNext(const std::vector> &tensors) { - if (!stage_->IsLastStage()) { - return ISend(tensors, stage_->device(), stage_->stage_index(), stage_->next_rank(), stage_->recv_shape()); - } - return tensors; + return ISend(tensors, stage_->device(), stage_->stage_index(), stage_->next_rank(), stage_->recv_shape()); } -float ScheduleGPipe::StepMicrobatches(const std::vector> µbatch_inputs, +float ScheduleGPipe::StepMicroBatches(const std::vector> µbatch_inputs, const std::vector> µbatch_targets, const std::shared_ptr &loss_fn) { - const auto n = num_microbatches_; + const auto n = num_micro_batches_; if (n == 0) { return 0.0f; } + std::vector>> outputs(n); // ======== Forward Pass ======== for (int mb = 0; mb < n; ++mb) { std::vector> inputs; - if (stage_->IsFirstStage()) { inputs = {microbatch_inputs[mb]}; } else { @@ -122,13 +85,9 @@ float ScheduleGPipe::StepMicrobatches(const std::vector> outputs[mb] = stage_->ForwardOneChunk(inputs); - for (auto &t : outputs[mb]) { - if (!t) { - t = std::make_shared((std::vector){}, DataType::kFLOAT32, stage_->device()); - 
} + if (!stage_->IsLastStage()) { + outputs[mb] = SendToNext(outputs[mb]); } - - outputs[mb] = SendToNext(outputs[mb]); } // ======== Backward Pass ======== @@ -153,8 +112,7 @@ float ScheduleGPipe::StepMicrobatches(const std::vector> auto target_on_device = target->To(output->GetDevice()); auto loss = loss_fn->Forward({output, std::make_shared(target_on_device)})[0]; if (!loss) { - LOG(INFO) << "[ERROR] loss is nullptr at mb = " << mb; - continue; + LOG(FATAL) << "[ERROR] loss is nullptr at mb = " << mb; } loss = loss / n; diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc index fb6d13be..f09ae13a 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_stage.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc @@ -1,32 +1,24 @@ #include "infini_train/include/nn/parallel/pp/pipeline_stage.h" -#include "glog/logging.h" - #include +#include "glog/logging.h" + #include "infini_train/include/device.h" -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/nn/init.h" namespace infini_train::nn::parallel { -PipelineStage::PipelineStage(const std::vector> &layers, int stage_index, int num_stages, - const std::vector> &recvShape, std::shared_ptr optim) - : stage_index_(stage_index), num_stages_(num_stages), layers_(layers), +PipelineStage::PipelineStage(const std::shared_ptr &model, int stage_index, int num_stages, + const std::vector> &recv_shape, std::shared_ptr optimizer) + : model_(model), stage_index_(stage_index), num_stages_(num_stages), prev_rank_(stage_index > 0 ? stage_index - 1 : -1), - next_rank_(stage_index < num_stages - 1 ? stage_index + 1 : -1), recv_shape_(recvShape), optim_(std::move(optim)), + next_rank_(stage_index < num_stages - 1 ? stage_index + 1 : -1), recv_shape_(recv_shape), + optimizer_(std::move(optimizer)), device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(stage_index)) {} std::vector> PipelineStage::ForwardOneChunk(const std::vector> &inputs) { - std::vector> current = inputs; - int i = 0; - for (const auto &layer : layers_) { - current = layer->Forward(current); - ++i; - } - - return current; + return model_->Forward(inputs); } } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc index d132ba49..31cd0651 100644 --- a/infini_train/src/nn/parallel/pp/send_recv.cc +++ b/infini_train/src/nn/parallel/pp/send_recv.cc @@ -1,5 +1,6 @@ #include "infini_train/include/nn/parallel/pp/send_recv.h" +#include #include #include @@ -18,15 +19,13 @@ class ISend : public autograd::Function { public: static constexpr char kType[] = "ISendFunction"; - explicit ISend(const Device *target_device, int cur_rank, int peer_rank, std::vector> shape) + explicit ISend(const Device *target_device, int cur_rank, int peer_rank, + const std::vector> &shape) : autograd::Function(kType), target_device_(target_device), cur_rank_(cur_rank), peer_rank_(peer_rank), shapes_(shape) {} std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; private: @@ -34,7 +33,7 @@ class ISend : public autograd::Function { const Device *input_device_ = nullptr; int cur_rank_ = -1; int peer_rank_ = -1; - std::vector> shapes_; + const std::vector> &shapes_; }; class IRecv : public autograd::Function { @@ -72,14 +71,10 @@ std::vector> 
ISend::Forward(const std::vector> &input_tensors, - const std::vector> &output_tensors) {} - std::vector> ISend::Backward(const std::vector> &grad_outputs) { - auto shapes = shapes_; std::vector> recv_tensors; - for (int shape_i = 0; shape_i < shapes.size(); ++shape_i) { - auto r_tensor = std::make_shared(shapes[shape_i], DataType::kFLOAT32, input_device_); + for (int shape_i = 0; shape_i < shapes_.size(); ++shape_i) { + auto r_tensor = std::make_shared(shapes_[shape_i], DataType::kFLOAT32, input_device_); recv_tensors.push_back(r_tensor); } @@ -90,8 +85,7 @@ std::vector> ISend::Backward(const std::vector> IRecv::Forward(const std::vector> &recv_tensors) { - CHECK_NE(src_device_, nullptr) << "src_device_ must be set"; - + CHECK_NOTNULL(src_device_); auto pp_group = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(src_device_->rank().thread_rank())); pp_group->NcclRecv(recv_tensors, peer_rank_); @@ -122,7 +116,7 @@ std::vector> IRecv::Backward(const std::vector> ISend(const std::vector> &input_tensors, const Device *target_device, int cur_rank, int peer_rank, - std::vector> shape) { + const std::vector> &shape) { auto func = std::make_shared(target_device, cur_rank, peer_rank, shape); return func->Apply(input_tensors); } diff --git a/infini_train/src/nn/parallel/utils.cc b/infini_train/src/nn/parallel/utils.cc index 5456e4e7..4f661880 100644 --- a/infini_train/src/nn/parallel/utils.cc +++ b/infini_train/src/nn/parallel/utils.cc @@ -20,4 +20,10 @@ std::vector GetDataParallelGroupRanks(int thread_rank) { return global::Get std::vector GetTensorParallelGroupRanks(int thread_rank) { return global::GetGroupRanks(global::TP, thread_rank); } +std::vector GetPipelineParallelGroupRanks(int pp_world_size) { + std::vector ranks; + ranks.reserve(pp_world_size); + for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); } + return ranks; +} } // namespace infini_train::nn::parallel diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 988cfab4..c163ea61 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -20,6 +20,10 @@ SGD::SGD(const std::vector> ¶ms, float learning_rate void SGD::Step() { for (auto param : params_) { + if (!param->grad()) { + LOG(INFO) << "Skipping param with null grad."; + continue; + } auto device = param->GetDevice(); device->SetDevice(); auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"});
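
For reference, the stage-partitioning rule introduced by PipelineParallel::GetStageInfo can be checked in isolation: the first (total_layers % pp_size) stages take one extra layer, so e.g. 13 layers on 4 stages split as 4/3/3/3. The following standalone sketch mirrors that logic outside the library (function name and the main() driver are illustrative only):

#include <cassert>
#include <cstdio>
#include <tuple>

// Returns {is_first_stage, is_last_stage, start_layer, end_layer) for one pipeline rank,
// distributing the remainder layers to the earliest stages, as in the patch.
std::tuple<bool, bool, int, int> GetStageInfoSketch(int total_layers, int pp_size, int pp_rank) {
    assert(pp_size > 0 && pp_rank >= 0 && pp_rank < pp_size);
    const bool is_first = (pp_rank == 0);
    const bool is_last = (pp_rank == pp_size - 1);
    const int layers_per_stage = total_layers / pp_size;
    const int remainder = total_layers % pp_size;
    int start = 0, end = 0;
    if (pp_rank < remainder) {
        start = pp_rank * (layers_per_stage + 1);
        end = start + layers_per_stage + 1;
    } else {
        start = pp_rank * layers_per_stage + remainder;
        end = start + layers_per_stage;
    }
    return {is_first, is_last, start, end};
}

int main() {
    for (int r = 0; r < 4; ++r) {
        auto [first, last, s, e] = GetStageInfoSketch(/*total_layers=*/13, /*pp_size=*/4, r);
        std::printf("rank %d: layers [%d, %d) first=%d last=%d\n", r, s, e, first, last);
    }
    return 0;
}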
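
The checkpoint-loading changes in LLaMA3::FromLLMC follow one pattern per tensor: read it when the current stage owns the layer, otherwise seekg() past the same number of bytes so every stage walks the llmc file in lockstep. A minimal sketch of that pattern, assuming a hypothetical flat float32 file ("model.bin") and made-up sizes:

#include <cstdint>
#include <fstream>
#include <vector>

// Reads `count` floats into `dst` when this stage owns the tensor; otherwise advances
// the stream by the same number of bytes so later tensors stay aligned.
void ReadOrSkipFloats(std::ifstream &ifs, float *dst, int64_t count, bool owned) {
    const std::streamoff bytes = static_cast<std::streamoff>(count) * sizeof(float);
    if (owned) {
        ifs.read(reinterpret_cast<char *>(dst), bytes);
    } else {
        ifs.seekg(bytes, std::ios::cur);
    }
}

int main() {
    std::ifstream ifs("model.bin", std::ios::binary); // hypothetical checkpoint
    if (!ifs) { return 0; }
    const int n_layer = 12, start_layer = 3, end_layer = 6, n_embd = 768;
    std::vector<float> owned_norms(static_cast<size_t>(end_layer - start_layer) * n_embd);
    for (int i = 0; i < n_layer; ++i) {
        const bool owned = (i >= start_layer && i < end_layer);
        float *dst = owned ? owned_norms.data() + static_cast<size_t>(i - start_layer) * n_embd : nullptr;
        ReadOrSkipFloats(ifs, dst, n_embd, owned);
    }
    return 0;
}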
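
Finally, the control flow of ScheduleGPipe::StepMicroBatches on a single stage is: run the forward pass for every microbatch (receiving activations from the previous stage and sending outputs to the next one), then run all backward passes. The sketch below only illustrates that ordering; StageInfo and the callbacks are placeholders, not the library's API, and the real schedule additionally scales each microbatch loss by 1/n before backward:

#include <cstdio>
#include <functional>

struct StageInfo {
    int index = 0;
    int num_stages = 1;
    bool IsFirst() const { return index == 0; }
    bool IsLast() const { return index == num_stages - 1; }
};

void RunGPipeStep(const StageInfo &stage, int num_micro_batches,
                  const std::function<void(int)> &forward_one,
                  const std::function<void(int)> &backward_one) {
    // Phase 1: all forwards. A real stage would IRecv activations when !stage.IsFirst()
    // and ISend its outputs when !stage.IsLast() inside forward_one.
    for (int mb = 0; mb < num_micro_batches; ++mb) { forward_one(mb); }
    // Phase 2: all backwards. Deferring them until every forward has run is what makes
    // this GPipe rather than 1F1B.
    for (int mb = 0; mb < num_micro_batches; ++mb) { backward_one(mb); }
}

int main() {
    StageInfo stage{/*index=*/1, /*num_stages=*/4};
    RunGPipeStep(stage, /*num_micro_batches=*/4,
                 [&](int mb) { std::printf("stage %d fwd mb %d\n", stage.index, mb); },
                 [&](int mb) { std::printf("stage %d bwd mb %d\n", stage.index, mb); });
    return 0;
}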