Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -106,6 +107,7 @@ const std::unordered_map<std::string, GPT2::ModelType> kStrToModelType = {
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -211,8 +213,8 @@ void Train(const nn::parallel::Rank &rank) {
model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(),
std::dynamic_pointer_cast<GPT2>(model)->GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{
.use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
Expand All @@ -224,7 +226,8 @@ void Train(const nn::parallel::Rank &rank) {
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer,
.zero_stage = FLAGS_zero_stage};
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
}

Expand Down
9 changes: 6 additions & 3 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -88,6 +89,7 @@ constexpr char kDtypeBF16[] = "bfloat16";
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -190,8 +192,8 @@ void Train(const nn::parallel::Rank &rank) {
model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(),
std::dynamic_pointer_cast<LLaMA3>(model)->GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{
.use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
Expand All @@ -204,7 +206,8 @@ void Train(const nn::parallel::Rank &rank) {
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.

auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer,
.zero_stage = FLAGS_zero_stage};
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ class DistributedDataParallelConfig {
// In this case, grad reduce is triggered immediately when a grad is ready or till all grads are ready.
bool overlap_grad_reduce = true;

// ZeRO-DP Stage for memory optimization (Only take effects when use_distributed_optimizer=true)
// ZeRO-1: Optimizer states partitioning, by default
// ZeRO-2: Gradients partitioning
// ZeRO-3: Parameters partitioning
int zero_stage = 1;

// Whether to overlap parameter all-gather with forward compute.
bool overlap_param_gather = true;

Expand All @@ -59,7 +65,7 @@ class DistributedDataParallelConfig {
// Maximum number of parameters in each ParamAndGradBucket.
// NOTE(zbl): This is distinct from DDP Reducer's MB-based bucket caps.
// TODO(zbl): To unify the definition of bucket_size argument for users
size_t bucket_size_in_elements = 40000000;
size_t bucket_size_in_elements = 1000000;

// Whether to pad bucket sizes to improve NCCL bus bandwidth utilization.
bool pad_buckets_for_high_nccl_busbw = false;
Expand Down
26 changes: 24 additions & 2 deletions infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ namespace infini_train::nn::parallel {
class ParamAndGradBucket {
public:
ParamAndGradBucket(const std::vector<std::shared_ptr<Tensor>> &params, const std::shared_ptr<Tensor> &param_data,
const std::shared_ptr<Tensor> &grad_data, size_t offset, size_t num_elements_unpadded,
float gradient_scaling_factor, size_t bucket_id);
DataType param_dtype, const std::shared_ptr<Tensor> &grad_data, DataType grad_dtype,
size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id);

size_t bucket_id() const { return bucket_id_; }

Expand All @@ -33,6 +33,10 @@ class ParamAndGradBucket {

const std::shared_ptr<Tensor> &grad_data() const { return grad_data_; }

DataType param_dtype() const { return param_dtype_; }

DataType grad_dtype() const { return grad_dtype_; }

size_t offset() const { return offset_; }

size_t num_elements_unpadded() const { return num_elements_unpadded_; }
Expand All @@ -49,6 +53,8 @@ class ParamAndGradBucket {
std::vector<std::shared_ptr<Tensor>> params_;
std::shared_ptr<Tensor> param_data_;
std::shared_ptr<Tensor> grad_data_;
DataType param_dtype_;
DataType grad_dtype_;

size_t offset_ = 0;
size_t num_elements_unpadded_ = 0;
Expand All @@ -73,6 +79,11 @@ class ParamAndGradBucketGroup {
// Start grad reduce
void StartGradSync();

// Accumulate a parameter grad into bucket buffer
    // ZeRO-2: Use this function to take over autograd::AccumulateGrad::Backward
void AccumulateParamGrad(const std::shared_ptr<Tensor> &parameter, const std::shared_ptr<Tensor> &grad,
bool overwrite, float learning_rate);

// Wait for gradient reduce to complete
void FinishGradSync();

Expand All @@ -87,6 +98,9 @@ class ParamAndGradBucketGroup {

const std::vector<std::shared_ptr<ParamAndGradBucket>> &buckets() const { return buckets_; }

// ZeRO-2: Get a bucket's local grad shard buffer
std::shared_ptr<Tensor> GetLocalGradShardBuffer(size_t bucket_idx) const;

const DistributedDataParallelConfig &config() const { return ddp_config_; }

private:
Expand All @@ -98,12 +112,20 @@ class ParamAndGradBucketGroup {

std::unordered_set<Tensor *> params_;
std::unordered_set<Tensor *> params_with_grad_;
// Tensor -> (Bucket, Bucket Index)
std::unordered_map<Tensor *, std::pair<std::shared_ptr<ParamAndGradBucket>, size_t>> param_to_bucket_;

// TODO(zbl): Implement CoalescedWork for aggregate works
// According to Megatron-LM's _coalescing_manager
std::vector<std::shared_ptr<Work>> grad_reduce_work_list_;
std::vector<size_t> grad_reduce_bucket_indices_;
std::vector<std::shared_ptr<Work>> param_gather_work_list_;

// ZeRO-2: persistent grad shard buffers and temporary full grad buffers
std::vector<std::shared_ptr<Tensor>> grad_shard_buffer_list_;
std::vector<std::shared_ptr<Tensor>> temp_full_grad_buffer_list_;
std::vector<bool> temp_full_grad_initialized_;

std::shared_ptr<ParamAndGradBucketGroup> next_param_gather_bucket_group_ = nullptr;

std::vector<std::vector<std::shared_ptr<Tensor>>> param_buffer_shard_list_;
Expand Down
8 changes: 8 additions & 0 deletions infini_train/include/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
std::shared_ptr<autograd::AccumulateGrad> grad_accumulator();
void ResetAccumulator();

// ZeRO-2: Use this function to take over AccumulateGrad::Backward
using GradAccumulateBypass
= std::function<bool(const std::shared_ptr<Tensor> &grad_output, bool overwrite, float learning_rate)>;
GradAccumulateBypass grad_accumulate_bypass();
void SetGradAccumulateBypass(GradAccumulateBypass);

void RegisterPostAccumulateGradHook(std::shared_ptr<autograd::PostAccumulateGradHook> hook);

autograd::PostAccumulateGradHook *post_accumulate_grad_hook() const;
Expand All @@ -241,6 +247,8 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
// a strong reference to the accumulator to manage its lifetime.
std::shared_ptr<autograd::AccumulateGrad> grad_accumulator_ = nullptr;
std::shared_ptr<autograd::PostAccumulateGradHook> post_accumulate_grad_hook_ = nullptr;
// ZeRO-2: Use this function to take over AccumulateGrad::Backward
GradAccumulateBypass grad_accumulate_bypass_ = nullptr;

bool grad_overwrite_once_ = false;
};
Expand Down
10 changes: 9 additions & 1 deletion infini_train/src/autograd/accumulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,16 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
device->SetDevice();

if (grad_output) {
const bool overwrite = tensor_->ConsumeGradOverwriteFlag();
// ZeRO-2: Use a bypass function to perform grad accumulation in temp full grad buffer
auto bypass = tensor_->grad_accumulate_bypass();
if (bypass && bypass(grad_output, overwrite, learning_rate_)) {
tensor_->ResetAccumulator();
return {};
}

if (grad) {
if (tensor_->ConsumeGradOverwriteFlag()) {
if (overwrite) {
// If the tensor is marked to overwrite its current grad on next grad update
// See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()`
// NOTE(zbl): must copy, cannot change grad buffer address
Expand Down
40 changes: 39 additions & 1 deletion infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> mod
const DistributedDataParallelConfig ddp_config)
: ddp_config_(ddp_config),
ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(thread_rank))) {
CHECK(ddp_config_.zero_stage >= 1 && ddp_config_.zero_stage <= 3)
<< "DistributedDataParallel: zero_stage must be in 1/2/3.";
if (ddp_config_.zero_stage >= 3) {
LOG(FATAL) << "DistributedDataParallel: ZeRO-3 is not implemented yet.";
}
if (!ddp_config_.use_distributed_optimizer && ddp_config_.zero_stage >= 1) {
LOG(WARNING) << "DistributedDataParallel: zero_stage is ignored because "
"use_distributed_optimizer is false.";
ddp_config_.zero_stage = 1;
}

for (auto &param : module->Parameters()) {
auto device = param->GetDevice();
CHECK_EQ(device->Index(), thread_rank) << "All parameters must be on the same device as the module";
Expand Down Expand Up @@ -79,6 +90,7 @@ void DistributedDataParallel::BuildParamAndGradBuffers() {
continue;
}

// At this point, zero_stage is already aligned with use_distributed_optimizer.
auto buffer = std::make_shared<ParamAndGradBuffer>(param_list, param_dtype, grad_dtype, ddp_pg_, ddp_config_);

param_grad_buffers_.push_back(buffer);
Expand Down Expand Up @@ -112,6 +124,32 @@ void DistributedDataParallel::BuildParamAndGradBuffers() {
}

void DistributedDataParallel::RegisterBackwardHooks() {
if (ddp_config_.zero_stage >= 2) {
auto &module = modules_.at(kModuleName);
for (auto &param : module->Parameters()) {
if (!param->requires_grad()) {
continue;
}
auto it = param_to_bucket_group_.find(param.get());
if (it == param_to_bucket_group_.end()) {
continue;
}
std::weak_ptr<ParamAndGradBucketGroup> weak_group = it->second;
param->SetGradAccumulateBypass(
[weak_group, param](const std::shared_ptr<Tensor> &grad_output, bool overwrite, float learning_rate) {
if (auto group = weak_group.lock()) {
group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate);
if (group->config().overlap_grad_reduce) {
group->RegisterGradReady(param);
}
return true;
}
return false;
});
}
return;
}

class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook {
public:
DDPPostAccumulateHook(DistributedDataParallel *ddp, const std::weak_ptr<Tensor> param)
Expand Down Expand Up @@ -143,7 +181,7 @@ void DistributedDataParallel::OnGradReady(const std::shared_ptr<Tensor> &param)
auto it = param_to_bucket_group_.find(param.get());
if (it != param_to_bucket_group_.end()) {
CHECK(param->requires_grad());
if (ddp_config_.overlap_grad_reduce) {
if (ddp_config_.overlap_grad_reduce && (ddp_config_.zero_stage < 2)) {
CHECK(param->grad()) << "param.grad being None is not safe when overlap_grad_reduce is True";
}

Expand Down
17 changes: 14 additions & 3 deletions infini_train/src/nn/parallel/ddp/distributed_optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
shard_params_.clear();

for (const auto &group : bucket_groups_) {
for (const auto &bucket : group->buckets()) {
const bool use_grad_shard = group->config().zero_stage >= 2;
const auto &buckets = group->buckets();
for (size_t bucket_idx = 0; bucket_idx < buckets.size(); ++bucket_idx) {
const auto &bucket = buckets[bucket_idx];

auto bucket_param = bucket->param_data();
auto bucket_grad = bucket->grad_data();
auto bucket_grad = use_grad_shard ? group->GetLocalGradShardBuffer(bucket_idx) : bucket->grad_data();

CHECK(bucket_param) << "DistributedOptimizer requires param buffer.";
CHECK(bucket_grad) << "DistributedOptimizer requires grad buffer.";
Expand All @@ -65,7 +68,9 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
CHECK_GT(piece_numel, 0);

const size_t param_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_param->Dtype());
const size_t grad_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_grad->Dtype());
// Adjust the offset since bucket_grad is already the shard of grad under ZeRO-2.
auto offset = use_grad_shard ? (local_start - bucket_shard_start) : local_start;
size_t grad_piece_offset_bytes = offset * kDataTypeToSize.at(bucket_grad->Dtype());

auto param_piece = std::make_shared<Tensor>(*bucket_param, param_piece_offset_bytes,
std::vector<int64_t>{static_cast<int64_t>(piece_numel)});
Expand All @@ -74,6 +79,12 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
std::vector<int64_t>{static_cast<int64_t>(piece_numel)});

param_piece->set_grad(grad_piece);
// if (use_grad_shard) {
// // NOTE(zbl): Under ZeRO-2, param->grad() is the shard of grad, not the full grad.
// // The binding is done in the constructor of DistributedOptimizer.
// // The value of param->grad() is not updated until backward has finished.
// param->set_grad(grad_piece);
// }
shard_params_.push_back(param_piece);
}
}
Expand Down
Loading