feat: add pipeline parallel #88

base: master
@@ -30,4 +30,5 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len);
void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt);
std::vector<int> GetPipelineParallelGroupRanks(int rank);

Collaborator: Move this to infini_train/src/nn/parallel/utils.h.

} // namespace infini_train
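A minimal sketch of what the moved helper might look like. It follows the call site later in this diff, which passes pp_world_size, and assumes one pipeline group over consecutive thread ranks; the header above declares the parameter as rank, so the real semantics may differ.

```cpp
// Hypothetical implementation for illustration only; per the review comment it
// would live in infini_train/src/nn/parallel/utils.{h,cc}.
#include <vector>

std::vector<int> GetPipelineParallelGroupRanks(int pp_world_size) {
    std::vector<int> ranks;
    ranks.reserve(pp_world_size);
    for (int stage = 0; stage < pp_world_size; ++stage) {
        ranks.push_back(stage); // assumes stage i runs on thread rank i
    }
    return ranks;
}
```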
@@ -17,6 +17,7 @@
#include "infini_train/include/nn/parallel/distributed_data_parallel.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/parallel_functional.h"
#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
#include "infini_train/include/nn/parallel/rank.h"
#include "infini_train/include/nn/parallel/reduce_op_type.h"
#include "infini_train/include/nn/parallel/tensor_parallel.h"
@@ -63,6 +64,10 @@ DEFINE_int32(
    "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(
    pipeline_parallel, 1,
    "Pipeline Parallel world size, will always use device=cuda and use all cuda visible devices when set to true");

Collaborator: PP should now support setting the number of devices it uses, right? Was this help string just not updated?

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
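If the flag is indeed a world size now, the help string the reviewer flags could be reworded along these lines (a suggestion, not text from the PR):

```cpp
// Suggested rewording only: describe the flag as a world size, not a boolean switch.
DEFINE_uint32(pipeline_parallel, 1,
              "Pipeline Parallel world size. When set > 1, enables pipeline parallelism with device=cuda "
              "across that many of the visible CUDA devices.");
```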
@@ -106,6 +111,7 @@ void Train(const nn::parallel::Rank &rank) {
    int ddp_world_size = global::GetDataParallelSize();
    int tp_world_size = global::GetTensorParallelSize();
    int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0;
    int pp_world_size = global::GetPipelineParallelSize();

    if (FLAGS_sequence_parallel) {
        CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0)
@@ -114,9 +120,11 @@ void Train(const nn::parallel::Rank &rank) {
    int ddp_rank = 0;
    int tp_rank = 0;
    int pp_rank = 0;

    const ProcessGroup *ddp_pg = nullptr;
    const ProcessGroup *tp_pg = nullptr;
    const ProcessGroup *pp_pg = nullptr;

    if (rank.IsParallel()) {
        device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
@@ -134,6 +142,14 @@ void Train(const nn::parallel::Rank &rank) {
            // NOTE(zbl): Reserved for VocabParallelEmbedding
            nn::parallel::tp_rank = tp_rank;
        }

        if (pp_world_size > 1) {
            pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
                GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
            pp_rank = pp_pg->GetGroupRank(rank.thread_rank());

            nn::parallel::pp_rank = pp_rank;
        }
    } else {
        device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
                                            : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
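GetPipelineParallelProcessGroupName is not shown in this diff; a sketch of what such a naming helper could look like, assuming a single pipeline group shared by all ranks (hypothetical, not the PR's implementation):

```cpp
#include <string>

// Illustration only: every rank belonging to the same pipeline must resolve to
// the same process-group name so that GetOrCreate returns the shared group.
std::string GetPipelineParallelProcessGroupName(int thread_rank) {
    (void)thread_rank; // unused under the single-pipeline assumption
    return "pp_group_0";
}
```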
@@ -182,8 +198,10 @@ void Train(const nn::parallel::Rank &rank) {
        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
    }

    auto num_microbatches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);

Collaborator: Naming: num_micro_batches.

    DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
-                                      FLAGS_batch_size, ddp_rank, ddp_world_size);
+                                      FLAGS_batch_size * num_microbatches, ddp_rank, ddp_world_size);

Collaborator: This needs a dedicated code path for PP. In the non-PP case, num_microbatches equals grad_accum_steps, and under gradient accumulation the non-PP path only reads batch_size samples at a time, relying on multiple passes of the outer loop to accumulate gradients.
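A sketch of the separation the review comment asks for (hypothetical restructuring using the suggested name num_micro_batches; identifiers are taken from the diff, not from the actual fix): only the PP path enlarges the per-iteration read, while the non-PP path keeps loading batch_size per micro step.

```cpp
// Hypothetical restructuring, following the review comment above.
const auto num_micro_batches =
    FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
// PP consumes all micro-batches inside a single pipeline step, so it loads them together;
// the non-PP path relies on the outer grad_accum_steps loop and reads one micro-batch at a time.
const int64_t samples_per_iteration
    = pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size;
DistributedDataLoader train_loader(
    std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
    samples_per_iteration, ddp_rank, ddp_world_size);
```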
    std::optional<DistributedDataLoader> val_loader = std::nullopt;
    if (!FLAGS_input_val_bin.empty()) {
        val_loader = DistributedDataLoader(
@@ -201,7 +219,11 @@ void Train(const nn::parallel::Rank &rank) {
    }

    // TODO(dcj): support more complex optimizer later
-   auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
+   auto lr = FLAGS_learning_rate;
+   auto optimizer_factory = [lr](const std::vector<std::shared_ptr<Tensor>> &params) {
+       return std::make_shared<optimizers::SGD>(params, lr);
+   };
+   auto optimizer = optimizer_factory(model->Parameters());

Collaborator: Why not initialize the optimizer directly here?
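One plausible answer to the question above (an assumption, not confirmed by the PR): the factory is there for the pipeline-parallel wrapper created further down, which presumably has to build an optimizer over its own stage's parameter slice rather than reuse one constructed over the full model. Roughly:

```cpp
// Illustration of what the wrapper might do internally (hypothetical):
// after the pipeline split, Parameters() on this rank covers only the local stage,
// so each stage constructs its own optimizer from the factory.
auto stage_params = model->Parameters();      // per-stage parameter slice (assumed)
auto stage_optimizer = optimizer_factory(stage_params);
stage_optimizer->ZeroGrad();                  // stepping is then driven by the pipeline schedule
```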
    auto train_iter = train_loader.begin();
    std::shared_ptr<nn::Module> loss_fn
@@ -211,6 +233,15 @@ void Train(const nn::parallel::Rank &rank) {
    loss_fn->To(device);
    LOG(INFO) << "Rank " << rank.thread_rank() << ": start training";

    if (pp_world_size > 1) {
        auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}};

        model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_microbatches, shapes,
                                                                 pp_rank, optimizer_factory);
    }
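For context on the arguments above: num_microbatches tells the wrapper how many slices each loaded batch is cut into, and shapes presumably describes the (batch_size, sequence_length, n_embd) activation exchanged between stages per micro-batch. A standalone sketch of the slicing idea (hypothetical helper, not PipelineParallel's actual internals): the loader delivers batch_size * num_microbatches samples per iteration, which the wrapper would cut back into micro-batches of batch_size.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Returns (offset, length) pairs over the leading batch dimension, one per micro-batch.
// Assumes total_samples divides evenly; the real schedule inside PipelineParallel may differ.
std::vector<std::pair<int64_t, int64_t>> SplitIntoMicroBatches(int64_t total_samples, int64_t num_micro_batches) {
    std::vector<std::pair<int64_t, int64_t>> slices;
    const int64_t per_micro = total_samples / num_micro_batches;
    for (int64_t i = 0; i < num_micro_batches; ++i) {
        slices.emplace_back(i * per_micro, per_micro);
    }
    return slices;
}
```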
    LOG(INFO) << "start training";

    for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
        const bool last_step = step == FLAGS_num_iteration;
@@ -233,58 +264,78 @@ void Train(const nn::parallel::Rank &rank) {
            break;
        }

        float lossf = 0.0f;
        // model->Train();
        if (pp_world_size == 1) {
-           optimizer.ZeroGrad();
+           optimizer->ZeroGrad();

            // if we are trying to overfit a single batch, we reset the loader here
            if (FLAGS_overfit_single_batch) {
                // train_loader.Reset();
            }

#ifdef PROFILE_MODE
            Profiler::Instance().SetTag("Step_" + std::to_string(step));
#endif
            for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
                // enable autocast for the current step
                infini_train::AutocastGuard autocast_guard(device->Type(), dtype);

                // (bs, seq_len), (bs, seq_len)
                auto [x, y] = *train_iter;
                // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
                // TODO(dcj): support dataloader.reset() later
                ++train_iter;
                x = std::make_shared<Tensor>(x->To(device));
                y = std::make_shared<Tensor>(y->To(device));

                LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
                // (bs, seq_len, vocab_size)
                auto logits = model->Forward({x, y})[0];
                LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
                auto loss = loss_fn->Forward({logits, y})[0];
                loss = loss / grad_accum_steps;

                // disable autocast for the current step (backward is not under autocast)
                autocast_guard.Disable();

                LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
                if (ddp_world_size > 1) {
                    function::AllReduce(loss, function::ReduceOpType::kAvg);
                }
                auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
                lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
                LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
                loss->Backward();
                LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
            }

-           optimizer.Step();
+           optimizer->Step();
        } else {
            auto [x, y] = *train_iter;
            ++train_iter;
            x = std::make_shared<Tensor>(x->To(device));
            y = std::make_shared<Tensor>(y->To(device));

            lossf = model->TrainStep({x}, {y}, loss_fn);
            auto loss_tensor = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kFLOAT32);
            static_cast<float *>(loss_tensor->DataPtr())[0] = lossf;
            auto loss_device_ptr = std::make_shared<Tensor>(loss_tensor->To(device));
            function::AllReduce(loss_device_ptr, function::ReduceOpType::kMax);
Collaborator: Why take the max here instead of simply using the loss from the last stage?
            auto loss_copy = loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice());
            lossf = static_cast<const float *>(loss_copy.DataPtr())[0];
        }
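On the review question above, a sketch of the "take the last stage's loss" alternative. It assumes the last pipeline stage is the one that computes the real loss and that the process group exposes some broadcast-style collective; neither assumption is shown in this diff.

```cpp
// Hypothetical alternative, for illustration only.
const int last_stage = pp_world_size - 1; // assumes stage i maps to pipeline group rank i
if (pp_rank == last_stage) {
    static_cast<float *>(loss_tensor->DataPtr())[0] = lossf; // only this rank holds the real loss
}
// pp_pg->Broadcast(loss_tensor, /*root=*/last_stage);       // assumed API; would replace the kMax all-reduce
```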
        const auto iter_end = std::chrono::high_resolution_clock::now();
        const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
        const double tps = FLAGS_total_batch_size / (duration_us / 1e6);

        if (rank.IsMainRank()) {
-           LOG(ERROR) << std::format(
-               "step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, DP={}, TP={}, SP={})",
-               step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size,
-               tp_world_size, sp_world_size);
+           LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
+                                     "DP={}, TP={}, SP={}, PP={})",
+                                     step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
+                                     tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);

            if ((step + 1) % FLAGS_freq_generate_txt == 0) {
                if (!tokenizer) {
@@ -304,7 +355,8 @@ int main(int argc, char *argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, true);
    google::InitGoogleLogging(argv[0]);

-   nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel);
+   nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
+                                    FLAGS_pipeline_parallel);

    LOG(INFO) << nn::parallel::global::ProcessGroupOverview();
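For reference, the call above implies InitAllEnv gained a fourth argument. A sketch of the presumed declaration; parameter names, types, and the enclosing namespace are inferred from the call site and flags, not taken from the actual header.

```cpp
#include <cstdint>

namespace infini_train::nn::parallel::global {
// Presumed post-PR signature; illustration only.
void InitAllEnv(int nthread_per_process, uint32_t tensor_parallel_size, bool sequence_parallel_enabled,
                uint32_t pipeline_parallel_size);
} // namespace infini_train::nn::parallel::global
```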
Comment: Move this function into infini_train/src/nn/parallel/utils.cc, consistent with ddp/tp.