From 4a2527fb42e37a3a79732b2bb2a22e94abed4449 Mon Sep 17 00:00:00 2001
From: bolunz
Date: Wed, 4 Feb 2026 15:11:29 +0800
Subject: [PATCH] feat: Support ZeRO-2 based on DistributedOptimizer

---
 example/gpt2/main.cc                               |   9 +-
 example/llama3/main.cc                             |   9 +-
 .../ddp/distributed_data_parallel_config.h         |   8 +-
 .../nn/parallel/ddp/param_and_grad_buffer.h        |  26 ++-
 infini_train/include/tensor.h                      |   8 +
 infini_train/src/autograd/accumulate.cc            |  10 +-
 .../parallel/ddp/distributed_data_parallel.cc      |  40 +++-
 .../nn/parallel/ddp/distributed_optimizer.cc       |  17 +-
 .../nn/parallel/ddp/param_and_grad_buffer.cc       | 179 ++++++++++++++++--
 infini_train/src/tensor.cc                         |   7 +
 scripts/test_config.json                           | 110 +++++++++++
 11 files changed, 391 insertions(+), 32 deletions(-)

diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc
index 36707163..959ecfa8 100644
--- a/example/gpt2/main.cc
+++ b/example/gpt2/main.cc
@@ -52,6 +52,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
 // optimization
 DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
 DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
+DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only takes effect when use_distributed_optimizer=true)");
 // evaluation
 DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
 DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -106,6 +107,7 @@ const std::unordered_map kStrToModelType = {
 DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
 DEFINE_validator(device,
                  [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
+DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; });
 
 void Train(const nn::parallel::Rank &rank) {
     using namespace nn::parallel;
@@ -211,8 +213,8 @@ void Train(const nn::parallel::Rank &rank) {
             model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(),
             std::dynamic_pointer_cast(model)->GetChunkSize());
         if (ddp_world_size > 1) {
-            auto ddp_config
-                = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
+            auto ddp_config = DistributedDataParallelConfig{
+                .use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage};
             auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
                 (*mutable_chunks)[chunk_id] = std::make_shared(mutable_chunks->at(chunk_id),
@@ -224,7 +226,8 @@ void Train(const nn::parallel::Rank &rank) {
         // before wrapping the model with DistributedDataParallel (DDP).
         // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
         // are created during the conversion.
-        auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
+        auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer,
+                                                        .zero_stage = FLAGS_zero_stage};
         model = std::make_shared(model, rank.thread_rank(), ddp_config);
     }
 
diff --git a/example/llama3/main.cc b/example/llama3/main.cc
index 421d6679..0cd2bce7 100644
--- a/example/llama3/main.cc
+++ b/example/llama3/main.cc
@@ -51,6 +51,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
 // optimization
 DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
 DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
+DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only takes effect when use_distributed_optimizer=true)");
 // evaluation
 DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
 DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -88,6 +89,7 @@ constexpr char kDtypeBF16[] = "bfloat16";
 DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
 DEFINE_validator(device,
                  [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
+DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; });
 
 void Train(const nn::parallel::Rank &rank) {
     using namespace nn::parallel;
@@ -190,8 +192,8 @@ void Train(const nn::parallel::Rank &rank) {
             model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(),
             std::dynamic_pointer_cast(model)->GetChunkSize());
         if (ddp_world_size > 1) {
-            auto ddp_config
-                = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
+            auto ddp_config = DistributedDataParallelConfig{
+                .use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage};
             auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
                 (*mutable_chunks)[chunk_id] = std::make_shared(mutable_chunks->at(chunk_id),
@@ -204,7 +206,8 @@ void Train(const nn::parallel::Rank &rank) {
         // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
         // are created during the conversion.
 
-        auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
+        auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer,
+                                                        .zero_stage = FLAGS_zero_stage};
         model = std::make_shared(model, rank.thread_rank(), ddp_config);
     }
 
diff --git a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h
index b0f50b21..5fe274c6 100644
--- a/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h
+++ b/infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h
@@ -40,6 +40,12 @@ class DistributedDataParallelConfig {
     // In this case, grad reduce is triggered immediately when a grad is ready or till all grads are ready.
     bool overlap_grad_reduce = true;
 
+    // ZeRO-DP stage for memory optimization (only takes effect when use_distributed_optimizer=true)
+    // ZeRO-1: optimizer state partitioning (default)
+    // ZeRO-2: gradient partitioning
+    // ZeRO-3: parameter partitioning
+    int zero_stage = 1;
+
     // Whether to overlap parameter all-gather with forward compute.
     bool overlap_param_gather = true;
 
@@ -59,7 +65,7 @@ class DistributedDataParallelConfig {
     // Maximum number of parameters in each ParamAndGradBucket.
     // NOTE(zbl): This is distinct from DDP Reducer's MB-based bucket caps.
     // TODO(zbl): To unify the definition of bucket_size argument for users
-    size_t bucket_size_in_elements = 40000000;
+    size_t bucket_size_in_elements = 1000000;
 
     // Whether to pad bucket sizes to improve NCCL bus bandwidth utilization.
     bool pad_buckets_for_high_nccl_busbw = false;
diff --git a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h
index c83fe9a5..8ae86678 100644
--- a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h
+++ b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h
@@ -22,8 +22,8 @@ namespace infini_train::nn::parallel {
 class ParamAndGradBucket {
 public:
     ParamAndGradBucket(const std::vector<std::shared_ptr<Tensor>> &params, const std::shared_ptr<Tensor> &param_data,
-                       const std::shared_ptr<Tensor> &grad_data, size_t offset, size_t num_elements_unpadded,
-                       float gradient_scaling_factor, size_t bucket_id);
+                       DataType param_dtype, const std::shared_ptr<Tensor> &grad_data, DataType grad_dtype,
+                       size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id);
 
     size_t bucket_id() const { return bucket_id_; }
 
@@ -33,6 +33,10 @@ class ParamAndGradBucket {
 
     const std::shared_ptr<Tensor> &grad_data() const { return grad_data_; }
 
+    DataType param_dtype() const { return param_dtype_; }
+
+    DataType grad_dtype() const { return grad_dtype_; }
+
     size_t offset() const { return offset_; }
 
     size_t num_elements_unpadded() const { return num_elements_unpadded_; }
@@ -49,6 +53,8 @@ class ParamAndGradBucket {
     std::vector<std::shared_ptr<Tensor>> params_;
     std::shared_ptr<Tensor> param_data_;
     std::shared_ptr<Tensor> grad_data_;
+    DataType param_dtype_;
+    DataType grad_dtype_;
 
     size_t offset_ = 0;
     size_t num_elements_unpadded_ = 0;
@@ -73,6 +79,11 @@ class ParamAndGradBucketGroup {
     // Start grad reduce
     void StartGradSync();
 
+    // Accumulate a parameter grad into the bucket buffer
+    // ZeRO-2: Use this function to take over autograd::AccumulateGrad::Backward
+    void AccumulateParamGrad(const std::shared_ptr<Tensor> &parameter, const std::shared_ptr<Tensor> &grad,
+                             bool overwrite, float learning_rate);
+
     // Wait for gradient reduce to complete
     void FinishGradSync();
 
@@ -87,6 +98,9 @@ class ParamAndGradBucketGroup {
 
     const std::vector<std::shared_ptr<ParamAndGradBucket>> &buckets() const { return buckets_; }
 
+    // ZeRO-2: Get a bucket's local grad shard buffer
+    std::shared_ptr<Tensor> GetLocalGradShardBuffer(size_t bucket_idx) const;
+
     const DistributedDataParallelConfig &config() const { return ddp_config_; }
 
 private:
@@ -98,12 +112,20 @@ class ParamAndGradBucketGroup {
     std::unordered_set<Tensor *> params_;
     std::unordered_set<Tensor *> params_with_grad_;
+    // Tensor -> (Bucket, Bucket Index)
+    std::unordered_map<Tensor *, std::pair<std::shared_ptr<ParamAndGradBucket>, size_t>> param_to_bucket_;
 
     // TODO(zbl): Implement CoalescedWork for aggregate works
     // According to Megatron-LM's _coalescing_manager
     std::vector> grad_reduce_work_list_;
+    std::vector<size_t> grad_reduce_bucket_indices_;
     std::vector> param_gather_work_list_;
 
+    // ZeRO-2: persistent grad shard buffers and temporary full grad buffers
+    std::vector<std::shared_ptr<Tensor>> grad_shard_buffer_list_;
+    std::vector<std::shared_ptr<Tensor>> temp_full_grad_buffer_list_;
+    std::vector<bool> temp_full_grad_initialized_;
+
     std::shared_ptr next_param_gather_bucket_group_ = nullptr;
 
     std::vector>> param_buffer_shard_list_;
diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h
index 6ff3fa64..c2527530 100644
--- a/infini_train/include/tensor.h
+++ b/infini_train/include/tensor.h
@@ -227,6 +227,12 @@ class Tensor : public std::enable_shared_from_this {
     std::shared_ptr grad_accumulator();
     void ResetAccumulator();
 
+    // ZeRO-2: Use this function to take over AccumulateGrad::Backward
+    using GradAccumulateBypass
+        = std::function<bool(const std::shared_ptr<Tensor> &grad_output, bool overwrite, float learning_rate)>;
+    GradAccumulateBypass grad_accumulate_bypass();
+    void SetGradAccumulateBypass(GradAccumulateBypass);
+
     void RegisterPostAccumulateGradHook(std::shared_ptr hook);
     autograd::PostAccumulateGradHook *post_accumulate_grad_hook() const;
 
@@ -241,6 +247,8 @@ class Tensor : public std::enable_shared_from_this {
     // a strong reference to the accumulator to manage its lifetime.
     std::shared_ptr grad_accumulator_ = nullptr;
     std::shared_ptr post_accumulate_grad_hook_ = nullptr;
+    // ZeRO-2: Use this function to take over AccumulateGrad::Backward
+    GradAccumulateBypass grad_accumulate_bypass_ = nullptr;
 
     bool grad_overwrite_once_ = false;
 };
diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc
index def9cad8..d2ef7e87 100644
--- a/infini_train/src/autograd/accumulate.cc
+++ b/infini_train/src/autograd/accumulate.cc
@@ -25,8 +25,16 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
     device->SetDevice();
 
     if (grad_output) {
+        const bool overwrite = tensor_->ConsumeGradOverwriteFlag();
+        // ZeRO-2: Use a bypass function to perform grad accumulation in the temp full grad buffer
+        auto bypass = tensor_->grad_accumulate_bypass();
+        if (bypass && bypass(grad_output, overwrite, learning_rate_)) {
+            tensor_->ResetAccumulator();
+            return {};
+        }
+
         if (grad) {
-            if (tensor_->ConsumeGradOverwriteFlag()) {
+            if (overwrite) {
                 // If the tensor is marked to overrite its current grad on next grad update
                 // See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()`
                 // NOTE(zbl): must copy, cannot change grad buffer address
diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc
index 35a73a23..fc331581 100644
--- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc
+++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc
@@ -24,6 +24,17 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr module,
                                                  const DistributedDataParallelConfig ddp_config)
     : ddp_config_(ddp_config),
       ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(thread_rank))) {
+    CHECK(ddp_config_.zero_stage >= 1 && ddp_config_.zero_stage <= 3)
+        << "DistributedDataParallel: zero_stage must be in 1/2/3.";
+    if (ddp_config_.zero_stage >= 3) {
+        LOG(FATAL) << "DistributedDataParallel: ZeRO-3 is not implemented yet.";
+    }
+    if (!ddp_config_.use_distributed_optimizer && ddp_config_.zero_stage > 1) {
+        LOG(WARNING) << "DistributedDataParallel: zero_stage is ignored because "
+                        "use_distributed_optimizer is false.";
+        ddp_config_.zero_stage = 1;
+    }
+
     for (auto &param : module->Parameters()) {
         auto device = param->GetDevice();
         CHECK_EQ(device->Index(), thread_rank) << "All parameters must be on the same device as the module";
@@ -79,6 +90,7 @@ void DistributedDataParallel::BuildParamAndGradBuffers() {
             continue;
         }
 
+        // At this point,
zero_stage is already aligned with use_distributed_optimizer.
         auto buffer = std::make_shared<ParamAndGradBuffer>(param_list, param_dtype, grad_dtype, ddp_pg_, ddp_config_);
         param_grad_buffers_.push_back(buffer);
 
@@ -112,6 +124,32 @@
 }
 
 void DistributedDataParallel::RegisterBackwardHooks() {
+    if (ddp_config_.zero_stage >= 2) {
+        auto &module = modules_.at(kModuleName);
+        for (auto &param : module->Parameters()) {
+            if (!param->requires_grad()) {
+                continue;
+            }
+            auto it = param_to_bucket_group_.find(param.get());
+            if (it == param_to_bucket_group_.end()) {
+                continue;
+            }
+            std::weak_ptr<ParamAndGradBucketGroup> weak_group = it->second;
+            param->SetGradAccumulateBypass(
+                [weak_group, param](const std::shared_ptr<Tensor> &grad_output, bool overwrite, float learning_rate) {
+                    if (auto group = weak_group.lock()) {
+                        group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate);
+                        if (group->config().overlap_grad_reduce) {
+                            group->RegisterGradReady(param);
+                        }
+                        return true;
+                    }
+                    return false;
+                });
+        }
+        return;
+    }
+
     class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook {
     public:
         DDPPostAccumulateHook(DistributedDataParallel *ddp, const std::weak_ptr<Tensor> param)
@@ -143,7 +181,7 @@ void DistributedDataParallel::OnGradReady(const std::shared_ptr<Tensor> &param) {
     auto it = param_to_bucket_group_.find(param.get());
     if (it != param_to_bucket_group_.end()) {
         CHECK(param->requires_grad());
-        if (ddp_config_.overlap_grad_reduce) {
+        if (ddp_config_.overlap_grad_reduce && (ddp_config_.zero_stage < 2)) {
             CHECK(param->grad()) << "param.grad being None is not safe when overlap_grad_reduce is True";
         }
diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc
index 55e5800b..48fd7103 100644
--- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc
+++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc
@@ -35,10 +35,13 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
     shard_params_.clear();
 
     for (const auto &group : bucket_groups_) {
-        for (const auto &bucket : group->buckets()) {
+        const bool use_grad_shard = group->config().zero_stage >= 2;
+        const auto &buckets = group->buckets();
+        for (size_t bucket_idx = 0; bucket_idx < buckets.size(); ++bucket_idx) {
+            const auto &bucket = buckets[bucket_idx];
             auto bucket_param = bucket->param_data();
-            auto bucket_grad = bucket->grad_data();
+            auto bucket_grad = use_grad_shard ? group->GetLocalGradShardBuffer(bucket_idx) : bucket->grad_data();
 
             CHECK(bucket_param) << "DistributedOptimizer requires param buffer.";
             CHECK(bucket_grad) << "DistributedOptimizer requires grad buffer.";
@@ -65,7 +68,9 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
                 CHECK_GT(piece_numel, 0);
 
                 const size_t param_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_param->Dtype());
-                const size_t grad_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_grad->Dtype());
+                // Adjust the offset: under ZeRO-2, bucket_grad already holds only this rank's shard of the grad.
+                auto offset = use_grad_shard ? (local_start - bucket_shard_start) : local_start;
+                size_t grad_piece_offset_bytes = offset * kDataTypeToSize.at(bucket_grad->Dtype());
 
                 auto param_piece = std::make_shared<Tensor>(*bucket_param, param_piece_offset_bytes,
                                                             std::vector<int64_t>{static_cast<int64_t>(piece_numel)});
@@ -74,6 +79,12 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
                                                            std::vector<int64_t>{static_cast<int64_t>(piece_numel)});
 
                 param_piece->set_grad(grad_piece);
+                // if (use_grad_shard) {
+                //     // NOTE(zbl): Under ZeRO-2, param->grad() is the shard of the grad, not the full grad.
+                //     // The binding is done in the constructor of DistributedOptimizer;
+                //     // param->grad() is not updated until backward has finished.
+                //     param->set_grad(grad_piece);
+                // }
                 shard_params_.push_back(param_piece);
             }
         }
diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc
index 56984ca0..442dd801 100644
--- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc
+++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc
@@ -6,6 +6,7 @@
 
 #include "glog/logging.h"
 
+#include "infini_train/include/dispatcher.h"
 #include "infini_train/include/nn/modules/module.h"
 #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h"
 #include "infini_train/include/nn/parallel/global.h"
@@ -53,12 +54,12 @@ std::vector<std::shared_ptr<Tensor>> ShardBuffer(const std::shared_ptr<Tensor> &b
 } // namespace
 
 ParamAndGradBucket::ParamAndGradBucket(const std::vector<std::shared_ptr<Tensor>> &params,
-                                       const std::shared_ptr<Tensor> &param_data,
-                                       const std::shared_ptr<Tensor> &grad_data, size_t offset,
+                                       const std::shared_ptr<Tensor> &param_data, DataType param_dtype,
+                                       const std::shared_ptr<Tensor> &grad_data, DataType grad_dtype, size_t offset,
                                        size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id)
-    : bucket_id_(bucket_id), params_(std::move(params)), param_data_(std::move(param_data)),
-      grad_data_(std::move(grad_data)), offset_(offset), num_elements_unpadded_(num_elements_unpadded),
-      gradient_scaling_factor_(gradient_scaling_factor) {
+    : bucket_id_(bucket_id), params_(std::move(params)), param_data_(std::move(param_data)), param_dtype_(param_dtype),
+      grad_data_(std::move(grad_data)), grad_dtype_(grad_dtype), offset_(offset),
+      num_elements_unpadded_(num_elements_unpadded), gradient_scaling_factor_(gradient_scaling_factor) {
     size_t current_offset = 0;
     for (const auto &param : params_) {
        auto numel = param->NumElements();
@@ -97,8 +98,12 @@ ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector<std::shared_ptr<ParamAndGradBucket>
-    for (const auto &bucket : buckets_) {
-        for (const auto &param : bucket->params()) {
-            params_.insert(param.get());
-        }
+    for (size_t bucket_idx = 0; bucket_idx < buckets_.size(); ++bucket_idx) {
+        const auto &bucket = buckets_[bucket_idx];
+        for (const auto &param : bucket->params()) {
+            params_.insert(param.get());
+            param_to_bucket_[param.get()] = {bucket, bucket_idx};
+        }
     }
 
     if (rank_in_collective_pg_ == -1) {
         auto param = *params_.begin();
@@ -109,15 +114,40 @@ ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector<std::shared_ptr<ParamAndGradBucket>
     }
 
+    if (ddp_config_.zero_stage >= 2) {
+        grad_shard_buffer_list_.resize(buckets_.size());
+        temp_full_grad_buffer_list_.resize(buckets_.size());
+        temp_full_grad_initialized_.resize(buckets_.size(), false);
+        for (size_t i = 0; i < buckets_.size(); ++i) {
+            auto bucket = buckets_[i];
+            CHECK(bucket->param_data()) << "ParamAndGradBucketGroup: param buffer required for ZeRO-2.";
+            const size_t bucket_numel = bucket->param_data()->NumElements();
+            if (bucket_numel == 0) {
+                continue;
+            }
+            CHECK_EQ(bucket_numel % collective_pg_size_, 0);
+            const size_t shard_numel = bucket_numel / collective_pg_size_;
+            auto param = bucket->params().front();
+            grad_shard_buffer_list_[i] = AllocateFlatBuffer(shard_numel, bucket->grad_dtype(), param->GetDevice());
+        }
+    }
 }
 
 void ParamAndGradBucketGroup::Reset() {
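+    // Per-iteration reset: the bookkeeping below is cleared at the start of each step;
+    // under ZeRO-2 the temporary full-grad buffers are additionally dropped so that
+    // AccumulateParamGrad() re-creates and re-zeroes them lazily on the next backward.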
     params_with_grad_.clear();
     grad_reduce_work_list_.clear();
+    grad_reduce_bucket_indices_.clear();
     param_gather_work_list_.clear();
 
     is_last_microbatch_ = true;
     grad_reduce_dispatched_ = false;
     param_gather_dispatched_ = false;
+
+    if (ddp_config_.zero_stage >= 2) {
+        std::fill(temp_full_grad_buffer_list_.begin(), temp_full_grad_buffer_list_.end(), nullptr);
+        std::fill(temp_full_grad_initialized_.begin(), temp_full_grad_initialized_.end(), false);
+    }
 }
 
 void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr<Tensor> &parameter) {
@@ -148,6 +178,69 @@ void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr<Tensor> &parameter) {
     }
 }
 
+void ParamAndGradBucketGroup::AccumulateParamGrad(const std::shared_ptr<Tensor> &parameter,
+                                                  const std::shared_ptr<Tensor> &grad, bool overwrite,
+                                                  float learning_rate) {
+    if (ddp_config_.zero_stage < 2) {
+        LOG(FATAL) << "ParamAndGradBucketGroup: AccumulateParamGrad called when ZeRO-2 is disabled.";
+        return;
+    }
+    if (!grad || !parameter) {
+        return;
+    }
+
+    auto it = param_to_bucket_.find(parameter.get());
+    if (it == param_to_bucket_.end()) {
+        return;
+    }
+    auto bucket = it->second.first;
+    const size_t bucket_idx = it->second.second;
+
+    size_t param_start_in_bucket = 0, param_end_in_bucket = 0;
+    auto found = bucket->GetTensorLocInBucket(parameter, param_start_in_bucket, param_end_in_bucket);
+    if (!found) {
+        return;
+    }
+
+    if (!temp_full_grad_buffer_list_[bucket_idx]) {
+        CHECK(bucket->param_data()) << "ParamAndGradBucketGroup: param buffer required for ZeRO-2.";
+        const size_t bucket_numel = bucket->param_data()->NumElements();
+        if (bucket_numel == 0) {
+            return;
+        }
+        temp_full_grad_buffer_list_[bucket_idx]
+            = AllocateFlatBuffer(bucket_numel, bucket->grad_dtype(), parameter->GetDevice());
+        temp_full_grad_initialized_[bucket_idx] = false;
+    }
+
+    if (!temp_full_grad_initialized_[bucket_idx]) {
+        temp_full_grad_buffer_list_[bucket_idx]->Fill(0.0f);
+        temp_full_grad_initialized_[bucket_idx] = true;
+    }
+
+    const size_t offset_bytes = param_start_in_bucket * kDataTypeToSize.at(bucket->grad_dtype());
+    auto bucket_grad_view
+        = std::make_shared<Tensor>(*temp_full_grad_buffer_list_[bucket_idx], offset_bytes, parameter->Dims());
+
+    if (overwrite) {
+        bucket_grad_view->CopyFrom(*grad);
+    } else {
+        auto kernel = Dispatcher::Instance().GetKernel({parameter->GetDevice()->Type(), "AccumulateGrad"});
+        kernel.Call(grad, learning_rate, bucket_grad_view);
+    }
+}
+
+std::shared_ptr<Tensor> ParamAndGradBucketGroup::GetLocalGradShardBuffer(size_t bucket_idx) const {
+    if (ddp_config_.zero_stage < 2) {
+        LOG(WARNING) << "ParamAndGradBucketGroup: GetLocalGradShardBuffer called when ZeRO-2 is disabled.";
+        return nullptr;
+    }
+    if (bucket_idx >= grad_shard_buffer_list_.size()) {
+        return nullptr;
+    }
+    return grad_shard_buffer_list_[bucket_idx];
+}
+
 void ParamAndGradBucketGroup::StartGradSync() {
     if (!collective_pg_) {
         LOG(FATAL) << "ParamAndGradBucketGroup: StartGradSync() called with null collective_pg_.";
     }
@@ -175,6 +268,20 @@ void ParamAndGradBucketGroup::StartGradSync() {
 
     for (auto i = 0; i < buckets_.size(); ++i) {
         auto bucket = buckets_[i];
+
+        if (ddp_config_.zero_stage >= 2) {
+            auto full_grad_buffer = temp_full_grad_buffer_list_[i];
+            if (!full_grad_buffer) {
+                continue;
+            }
+            CHECK(grad_shard_buffer_list_[i]) << "ParamAndGradBucketGroup: grad shard buffer missing.";
+            auto local_data_view = grad_shard_buffer_list_[i];
+            grad_reduce_work_list_.push_back(
+                collective_pg_->ReduceScatter(local_data_view, full_grad_buffer, reduce_op, async_op));
+            grad_reduce_bucket_indices_.push_back(i);
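+            // ReduceScatter sums the full-bucket gradient across DP ranks and leaves
+            // only this rank's shard in grad_shard_buffer_list_[i]; the temporary
+            // full buffer is released in FinishGradSync() once the work completes.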
continue; + } + std::shared_ptr grad_buffer = bucket->grad_data(); if (!grad_buffer) { continue; @@ -201,6 +308,10 @@ void ParamAndGradBucketGroup::FinishGradSync() { StartGradSync(); } + if (params_with_grad_.empty()) { + return; + } + if (!ddp_config_.overlap_grad_reduce) { // Assume reduce ops are synced and no work needs to be resolved grad_reduce_work_list_.clear(); @@ -212,6 +323,20 @@ void ParamAndGradBucketGroup::FinishGradSync() { << "ParamAndGradBucketGroup: Communication call has not been issued for this bucket(" << params_with_grad_.size() << "/" << params_.size() << " params have grad available)"; + if (ddp_config_.zero_stage >= 2) { + for (size_t idx = 0; idx < grad_reduce_work_list_.size(); ++idx) { + auto &work = grad_reduce_work_list_[idx]; + work->WaitNonBlocking(); + const size_t bucket_idx = grad_reduce_bucket_indices_[idx]; + temp_full_grad_buffer_list_[bucket_idx].reset(); + temp_full_grad_initialized_[bucket_idx] = false; + } + grad_reduce_work_list_.clear(); + grad_reduce_bucket_indices_.clear(); + grad_reduce_dispatched_ = false; + return; + } + for (auto work : grad_reduce_work_list_) { work->WaitNonBlocking(); } grad_reduce_work_list_.clear(); grad_reduce_dispatched_ = false; @@ -400,7 +525,11 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) // No param buffer needed if optimzer is not distributed param_buffer_.reset(); } - grad_buffer_ = AllocateFlatBuffer(numel_, grad_dtype, device); + if (ddp_config_.zero_stage >= 2) { + grad_buffer_.reset(); + } else { + grad_buffer_ = AllocateFlatBuffer(numel_, grad_dtype, device); + } LOG(INFO) << "ParamAndGradBuffer: numel_unpadded=" << numel_unpadded_ << ", numel (padded)=" << numel_; @@ -425,14 +554,17 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) bucket_param_view = GetBufferView(param_buffer_, start_index, std::vector{static_cast(end_index - start_index)}); } - std::shared_ptr bucket_grad_view = GetBufferView( - grad_buffer_, start_index, std::vector{static_cast(end_index - start_index)}); + std::shared_ptr bucket_grad_view; + if (grad_buffer_) { + bucket_grad_view = GetBufferView(grad_buffer_, start_index, + std::vector{static_cast(end_index - start_index)}); + } // FIXME(zbl): Use default for now float gradient_scaling_factor = 1.0f; - auto bucket - = std::make_shared(bucket_params, bucket_param_view, bucket_grad_view, start_index, - num_elements_unpadded, gradient_scaling_factor, bucket_id); + auto bucket = std::make_shared(bucket_params, bucket_param_view, param_dtype, + bucket_grad_view, grad_dtype, start_index, + num_elements_unpadded, gradient_scaling_factor, bucket_id); for (auto param : bucket_params) { CHECK(param_bucket_map_.find(param.get()) == param_bucket_map_.end()) @@ -455,8 +587,11 @@ void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) param->SetData(*param_buffer_, param_start_index * kDataTypeToSize.at(param_buffer_->Dtype()), true); } - auto grad_view = GetBufferView(grad_buffer_, param_start_index, param->Dims()); - param->set_grad(grad_view); + std::shared_ptr grad_view; + if (grad_buffer_) { + grad_view = GetBufferView(grad_buffer_, param_start_index, param->Dims()); + param->set_grad(grad_view); + } // Save grad view for each params --i; grads_[i] = grad_view; @@ -507,7 +642,9 @@ void ParamAndGradBuffer::Reset(bool need_rebind) { if (!need_rebind) { grad_buffer_->Fill(0.f); } - need_rebind_grad_views_ = need_rebind; + // NOTE(zbl): Under ZeRO-2, param->grad() is the shard of grad, not the full 
+    need_rebind_grad_views_ = need_rebind && (ddp_config_.zero_stage < 2);
 }
 
 void ParamAndGradBuffer::RebindGradViews() {
         return;
     }
 
+    if (!grad_buffer_) {
+        return;
+    }
+
     CHECK_EQ(params_.size(), grads_.size());
     for (size_t i = 0; i < params_.size(); ++i) {
-        params_[i]->set_grad(grads_[i]);
-        params_[i]->MarkGradOverwriteOnNextAccum();
+        if (grads_[i]) {
+            params_[i]->set_grad(grads_[i]);
+            params_[i]->MarkGradOverwriteOnNextAccum();
+        }
     }
 
     need_rebind_grad_views_ = false;
diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc
index 05257953..586479fc 100644
--- a/infini_train/src/tensor.cc
+++ b/infini_train/src/tensor.cc
@@ -664,6 +664,13 @@ void Tensor::ResetAccumulator() {
     }
 }
 
+Tensor::GradAccumulateBypass Tensor::grad_accumulate_bypass() {
+    CHECK(grad_accumulator_) << "grad_accumulate_bypass() should only be called on leaf tensors";
+    return grad_accumulate_bypass_;
+}
+
+void Tensor::SetGradAccumulateBypass(GradAccumulateBypass bypass) { grad_accumulate_bypass_ = std::move(bypass); }
+
 void Tensor::RegisterPostAccumulateGradHook(std::shared_ptr hook) {
     CHECK(requires_grad_) << "cannot register a hook on a tensor that doesn't require gradient";
 
diff --git a/scripts/test_config.json b/scripts/test_config.json
index 84f4fedd..fe47e657 100644
--- a/scripts/test_config.json
+++ b/scripts/test_config.json
@@ -72,6 +72,18 @@
                 "use_distributed_optimizer": true
             }
         },
+        {
+            "id": "3_distopt_zero2",
+            "args": {
+                "dtype": "float32",
+                "nthread_per_process": 8,
+                "num_iteration": 10,
+                "batch_size": 10,
+                "total_batch_size": 5120,
+                "use_distributed_optimizer": true,
+                "zero_stage": 2
+            }
+        },
         {
             "id": "3_bfloat16",
             "args": {
@@ -93,6 +105,18 @@
                 "use_distributed_optimizer": true
             }
         },
+        {
+            "id": "3_bfloat16_distopt_zero2",
+            "args": {
+                "dtype": "bfloat16",
+                "nthread_per_process": 8,
+                "num_iteration": 10,
+                "batch_size": 10,
+                "total_batch_size": 5120,
+                "use_distributed_optimizer": true,
+                "zero_stage": 2
+            }
+        },
         {
             "id": "4",
             "args": {
@@ -116,6 +140,19 @@
                 "use_distributed_optimizer": true
             }
         },
+        {
+            "id": "4_distopt_zero2",
+            "args": {
+                "dtype": "float32",
+                "nthread_per_process": 8,
+                "num_iteration": 10,
+                "batch_size": 40,
+                "total_batch_size": 5120,
+                "tensor_parallel": 4,
+                "use_distributed_optimizer": true,
+                "zero_stage": 2
+            }
+        },
         {
             "id": "4_bfloat16",
             "args": {
@@ -139,6 +176,19 @@
                 "use_distributed_optimizer": true
             }
         },
+        {
+            "id": "4_bfloat16_distopt_zero2",
+            "args": {
+                "dtype": "bfloat16",
+                "nthread_per_process": 8,
+                "num_iteration": 10,
+                "batch_size": 40,
+                "total_batch_size": 5120,
+                "tensor_parallel": 4,
+                "use_distributed_optimizer": true,
+                "zero_stage": 2
+            }
+        },
         {
             "id": "5",
             "args": {
@@ -164,6 +214,20 @@
                 "use_distributed_optimizer": true
             }
         },
+        {
+            "id": "5_distopt_zero2",
+            "args": {
+                "dtype": "float32",
+                "nthread_per_process": 8,
+                "num_iteration": 10,
+                "batch_size": 40,
+                "total_batch_size": 5120,
+                "tensor_parallel": 4,
+                "sequence_parallel": true,
+                "use_distributed_optimizer": true,
+                "zero_stage": 2
+            }
+        },
         {
             "id": "5_bfloat16",
             "args": {
@@ -189,6 +253,20 @@
                 "use_distributed_optimizer": true
             }
         },
+        {
+            "id": "5_bfloat16_distopt_zero2",
+            "args": {
+                "dtype": "bfloat16",
+                "nthread_per_process": 8,
+                "num_iteration": 10,
+                "batch_size": 40,
+                "total_batch_size": 5120,
+                "tensor_parallel": 4,
+                "sequence_parallel": true,
+                "use_distributed_optimizer": true,
+                "zero_stage": 2
+            }
+        },
         {
             "id": "6",
"args": { @@ -264,6 +342,22 @@ "use_distributed_optimizer": true } }, + { + "id": "8_distopt_zero2", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "use_distributed_optimizer": true, + "zero_stage": 2 + } + }, { "id": "8_bfloat16", "args": { @@ -292,6 +386,22 @@ "virtual_pipeline_parallel": 2, "use_distributed_optimizer": true } + }, + { + "id": "8_bfloat16_distopt_zero2", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "use_distributed_optimizer": true, + "zero_stage": 2 + } } ] }