7 changes: 7 additions & 0 deletions example/common/utils.cc
@@ -61,4 +61,11 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s
ifs.seekg(base + std::streamoff(len * sizeof(float)));
}

std::vector<int> GetPipelineParallelGroupRanks(int pp_world_size) {
Collaborator: Move this function into infini_train/src/nn/parallel/utils.cc, consistent with ddp/tp. (A sketch of the relocated helper follows this file's diff.)

std::vector<int> ranks;
ranks.reserve(pp_world_size);
for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); }
return ranks;
}

} // namespace infini_train
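For reference, a minimal sketch of the relocated helper; the namespace and file placement under infini_train/src/nn/parallel/ are assumptions based on the comment above, not the repository's confirmed layout.

```cpp
// infini_train/src/nn/parallel/utils.cc (hypothetical placement)
#include <numeric>
#include <vector>

namespace infini_train::nn::parallel {

// The pipeline-parallel group is simply ranks [0, pp_world_size).
std::vector<int> GetPipelineParallelGroupRanks(int pp_world_size) {
    std::vector<int> ranks(pp_world_size);
    std::iota(ranks.begin(), ranks.end(), 0);
    return ranks;
}

} // namespace infini_train::nn::parallel
```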
1 change: 1 addition & 0 deletions example/common/utils.h
@@ -30,4 +30,5 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len);

void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt);

std::vector<int> GetPipelineParallelGroupRanks(int rank);
Collaborator: Move this declaration into infini_train/src/nn/parallel/utils.h as well.

} // namespace infini_train
130 changes: 91 additions & 39 deletions example/gpt2/main.cc
@@ -17,6 +17,7 @@
#include "infini_train/include/nn/parallel/distributed_data_parallel.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/parallel_functional.h"
#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
#include "infini_train/include/nn/parallel/rank.h"
#include "infini_train/include/nn/parallel/reduce_op_type.h"
#include "infini_train/include/nn/parallel/tensor_parallel.h"
@@ -63,6 +64,10 @@ DEFINE_int32(
"When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(
pipeline_parallel, 1,
"Pipeline Parallel world size, will always use device=cuda and use all cuda visible devices when set to true");
Collaborator (quoting "use all cuda visible devices when set to true"): PP should now support setting the number of devices to use, right? Was this help text just not updated?

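If the flag does honor a configurable world size now, the help text could mirror the data-parallel flag above. A suggested wording only, since the intended semantics are not confirmed in this thread:

```cpp
DEFINE_uint32(pipeline_parallel, 1,
              "Pipeline Parallel world size. When set > 1, enables pipeline parallelism with "
              "device=cuda on the specified number of visible CUDA devices.");
```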

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");

@@ -106,6 +111,7 @@ void Train(const nn::parallel::Rank &rank) {
int ddp_world_size = global::GetDataParallelSize();
int tp_world_size = global::GetTensorParallelSize();
int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0;
int pp_world_size = global::GetPipelineParallelSize();

if (FLAGS_sequence_parallel) {
CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0)
@@ -114,9 +120,11 @@

int ddp_rank = 0;
int tp_rank = 0;
int pp_rank = 0;

const ProcessGroup *ddp_pg = nullptr;
const ProcessGroup *tp_pg = nullptr;
const ProcessGroup *pp_pg = nullptr;

if (rank.IsParallel()) {
device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
@@ -134,6 +142,14 @@
// NOTE(zbl): Reserved for VocabParallelEmbedding
nn::parallel::tp_rank = tp_rank;
}

if (pp_world_size > 1) {
pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
pp_rank = pp_pg->GetGroupRank(rank.thread_rank());

nn::parallel::pp_rank = pp_rank;
}
} else {
device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
: DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
@@ -182,8 +198,10 @@ void Train(const nn::parallel::Rank &rank) {
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
}

auto num_microbatches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
Collaborator: Naming: num_micro_batches. The same applies to the llama example.

DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
FLAGS_batch_size, ddp_rank, ddp_world_size);
FLAGS_batch_size * num_microbatches, ddp_rank, ddp_world_size);
Collaborator: This needs a dedicated case for PP. In the non-PP path, num_microbatches equals grad_accum_steps, and under gradient accumulation the non-PP path only reads batch_size samples per micro step, accumulating gradients across iterations of the outer loop. (A sketch follows below.)

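To illustrate the point above, a minimal sketch (not the PR's code) of how the loader size could be special-cased. All names come from this diff; whether pp_world_size > 1 is the right switch is an assumption:

```cpp
// Sketch only: size the training loader differently for PP vs. plain gradient accumulation.
const int64_t loader_batch_size = (pp_world_size > 1)
                                      ? FLAGS_batch_size * num_microbatches // PP: one fetch feeds a full step
                                      : FLAGS_batch_size;                   // non-PP: outer loop accumulates over grad_accum_steps
DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
                                   loader_batch_size, ddp_rank, ddp_world_size);
```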

std::optional<DistributedDataLoader> val_loader = std::nullopt;
if (!FLAGS_input_val_bin.empty()) {
val_loader = DistributedDataLoader(
Expand All @@ -201,7 +219,11 @@ void Train(const nn::parallel::Rank &rank) {
}

// TODO(dcj): support more complex optimizer later
auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
auto lr = FLAGS_learning_rate;
auto optimizer_factory = [lr](const std::vector<std::shared_ptr<Tensor>> &params) {
return std::make_shared<optimizers::SGD>(params, lr);
};
auto optimizer = optimizer_factory(model->Parameters());
Collaborator: Why not initialize the optimizer directly here?

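A guess at the rationale, based only on how the factory is used further down (it is handed to the PipelineParallel wrapper): each pipeline stage presumably needs its own optimizer over just the parameters it owns, which a single pre-built optimizer cannot provide. Sketch under that assumption; the per-stage member names are illustrative, not the actual API:

```cpp
// Sketch only. If PipelineParallel splits the model into per-stage modules, a prebuilt
// optimizer over model->Parameters() would reference parameters a stage does not own,
// so the wrapper likely invokes the factory once per stage, e.g. (hypothetical names):
//   stage_optimizers_.push_back(optimizer_factory(stage_module->Parameters()));
// The non-PP path then simply calls the factory once, exactly as the diff does:
auto optimizer = optimizer_factory(model->Parameters());
```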

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
@@ -211,6 +233,15 @@
loss_fn->To(device);
LOG(INFO) << "Rank " << rank.thread_rank() << ": start training";

if (pp_world_size > 1) {
auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_microbatches, shapes,
pp_rank, optimizer_factory);
}

LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
const bool last_step = step == FLAGS_num_iteration;

@@ -233,58 +264,78 @@ void Train(const nn::parallel::Rank &rank) {
break;
}

// model->Train();
optimizer.ZeroGrad();
// if we are trying to overfit a single batch, we reset the loader here
if (FLAGS_overfit_single_batch) {
// train_loader.Reset();
}
float lossf = 0.0f;
// model->Train();
if (pp_world_size == 1) {
optimizer->ZeroGrad();

// if we are trying to overfit a single batch, we reset the loader here
if (FLAGS_overfit_single_batch) {
// train_loader.Reset();
}

#ifdef PROFILE_MODE
Profiler::Instance().SetTag("Step_" + std::to_string(step));
Profiler::Instance().SetTag("Step_" + std::to_string(step));
#endif
for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
// enable autocast for the current step
infini_train::AutocastGuard autocast_guard(device->Type(), dtype);
for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
// enable autocast for the current step
infini_train::AutocastGuard autocast_guard(device->Type(), dtype);

// (bs, seq_len), (bs, seq_len)
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
// (bs, seq_len, vocab_size)
auto logits = model->Forward({x, y})[0];
LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
auto loss = loss_fn->Forward({logits, y})[0];
loss = loss / grad_accum_steps;

// disable autocast for the current step (backward is not under autocast)
autocast_guard.Disable();

LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
if (ddp_world_size > 1) {
function::AllReduce(loss, function::ReduceOpType::kAvg);
}
auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
loss->Backward();
LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
}

// (bs, seq_len), (bs, seq_len)
optimizer->Step();
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));
LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
// (bs, seq_len, vocab_size)
auto logits = model->Forward({x, y})[0];
LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
auto loss = loss_fn->Forward({logits, y})[0];
loss = loss / grad_accum_steps;

// disable autocast for the current step (backward is not under autocast)
autocast_guard.Disable();

LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
if (ddp_world_size > 1) {
function::AllReduce(loss, function::ReduceOpType::kAvg);
}
auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
loss->Backward();
LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
}
optimizer.Step();

lossf = model->TrainStep({x}, {y}, loss_fn);
auto loss_tensor = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kFLOAT32);
static_cast<float *>(loss_tensor->DataPtr())[0] = lossf;
auto loss_device_ptr = std::make_shared<Tensor>(loss_tensor->To(device));
function::AllReduce(loss_device_ptr, function::ReduceOpType::kMax);
Collaborator: Why take the max here instead of directly using the loss from the last stage? (See the sketch below.)

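A guess at what the max is doing here, plus a sketch of the alternative the comment suggests. Assumption (not confirmed by the PR): only the last stage's TrainStep returns a non-zero loss, so AllReduce with kMax is simply a way to propagate that single value to every rank, including the logging rank. If only the final stage needs to report, the collective could be skipped entirely:

```cpp
// Sketch only: report the loss from the last pipeline stage instead of AllReduce(kMax).
// Assumes lossf is meaningful only on the last stage; uses only names from this diff.
if (pp_rank == pp_world_size - 1) {
    LOG(INFO) << "step " << step + 1 << " | train loss " << lossf;
}
```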
auto loss_copy = loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice());
lossf = static_cast<const float *>(loss_copy.DataPtr())[0];
}
const auto iter_end = std::chrono::high_resolution_clock::now();
const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
const double tps = FLAGS_total_batch_size / (duration_us / 1e6);

if (rank.IsMainRank()) {
LOG(ERROR) << std::format(
"step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, DP={}, TP={}, SP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size,
tp_world_size, sp_world_size);
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
"DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
if (!tokenizer) {
@@ -304,7 +355,8 @@ int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);

nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel);
nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
FLAGS_pipeline_parallel);

LOG(INFO) << nn::parallel::global::ProcessGroupOverview();
