
feat: Add LoRA (Low-Rank Adaptation) support for efficient model fine-tuning#108

Open
chen2021673 wants to merge 5 commits into master from add_lora

Conversation

Contributor

chen2021673 commented Feb 12, 2026

Summary

Added LoRA (Low-Rank Adaptation) support for parameter-efficient fine-tuning. This feature significantly reduces the number of trainable parameters through low-rank decomposition, enabling efficient fine-tuning of large models.
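
The low-rank decomposition replaces a full weight update with the product of two small matrices: y = W·x + (alpha/rank)·B·(A·x), where A is [rank × in] and B is [out × rank]. A minimal standalone sketch of that forward computation, using plain vectors rather than the InfiniTrain Tensor API (all names here are illustrative):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal LoRA forward sketch: y = W*x + (alpha/rank) * B * (A * x).
// Matrices are stored row-major as flat vectors; this is illustrative
// only and does not use the InfiniTrain Tensor API.
std::vector<float> MatVec(const std::vector<float> &m, const std::vector<float> &x,
                          std::size_t rows, std::size_t cols) {
    std::vector<float> y(rows, 0.0f);
    for (std::size_t i = 0; i < rows; ++i) {
        for (std::size_t j = 0; j < cols; ++j) {
            y[i] += m[i * cols + j] * x[j];
        }
    }
    return y;
}

std::vector<float> LoRAForward(const std::vector<float> &W,  // [out, in] frozen base weight
                               const std::vector<float> &A,  // [rank, in], trainable
                               const std::vector<float> &B,  // [out, rank], trainable
                               const std::vector<float> &x,  // [in]
                               std::size_t out, std::size_t in, std::size_t rank, float alpha) {
    auto base = MatVec(W, x, out, in);    // W * x
    auto ax = MatVec(A, x, rank, in);     // A * x  -> [rank]
    auto bax = MatVec(B, ax, out, rank);  // B * (A * x)  -> [out]
    const float scale = alpha / static_cast<float>(rank);
    for (std::size_t i = 0; i < out; ++i) {
        base[i] += scale * bax[i];
    }
    return base;
}
```

With B zero-initialized (the usual LoRA init), the output reduces to the frozen base output, so training starts from the pretrained model exactly.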

Changes

New Features

LoRA Infrastructure (infini_train/include/nn/lora/):

  • lora_config.h/cc - LoRA configuration (rank, alpha, dropout)
  • lora_linear.h/cc - LoRA linear layer wrapper
  • lora_model.h/cc - Multi-LoRA layer management
  • lora_parallel_linear.h/cc - Tensor parallelism support
  • lora_utils.h/cc - Utility functions

Tests:

  • test/lora/test_lora.cc - Unit tests

Documentation:

  • docs/lora_usage.md - Usage documentation

Examples:

  • example/gpt2/main.cc - Added LoRA training example

Build:

  • CMakeLists.txt - Added test_lora build target
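
As a rough picture of what a configuration like lora_config.h might hold, here is a hypothetical sketch covering the (rank, alpha, dropout) fields named above plus the comma-separated target-module parsing one of the commits mentions. Field names, defaults, and the parsing helper are assumptions, not the PR's actual API:

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical LoRA configuration sketch. The rank/alpha/dropout fields
// mirror the PR description; the defaults and the target-module parser
// are assumptions for illustration only.
struct LoRAConfigSketch {
    int rank = 8;
    float alpha = 16.0f;
    float dropout = 0.0f;
    std::vector<std::string> target_modules;
};

// Split a comma-separated list such as "c_attn,c_proj" into module names.
std::vector<std::string> ParseTargetModules(const std::string &csv) {
    std::vector<std::string> out;
    std::stringstream ss(csv);
    std::string item;
    while (std::getline(ss, item, ',')) {
        if (!item.empty()) {
            out.push_back(item);
        }
    }
    return out;
}
```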

Test Result

Accuracy:
[screenshot]
Performance:
[screenshot]
llama3 output comparison:
[screenshot]

[screenshot]

chen2021673 and others added 5 commits February 12, 2026 09:11
- Add LoRA module infrastructure with configurable rank, alpha, dropout
- Implement LoRALinear wrapper for seamless integration with Linear layers
- Support tensor parallelism via LoRAParallelLinear
- Add LoRAModel utility for managing multiple LoRA layers
- Integrate LoRA configuration and utilities
- Add GPT2 example demonstrating LoRA fine-tuning
- Include comprehensive usage documentation and test suite

Co-Authored-By: Claude Opus 4.6 <[email protected]>
- Refactor LoRA config construction with proper target module parsing
- Add GetLoRAModel for in-place LoRA layer injection
- Fix DDP reducer to correctly handle LoRA parameters
- Fix RowParallel/ColumnParallel LoRA input handling to match base module behavior
- Add shape-based defensive checks for TP/SP consistency
- Move TP/SP communication helper function declarations to utils.h
- Move getter implementations from header to .cc file
- Add unit test for SaveLoRAWeights/LoadLoRAWeights functionality

Co-Authored-By: Claude Opus 4.6 <[email protected]>

std::vector<std::shared_ptr<Tensor>> LoRAModel::TrainableParameters() const { return GetLoRAParameters(base_model_); }

std::vector<std::shared_ptr<Tensor>> LoRAModel::Parameters() const { return base_model_->Parameters(); }
Contributor
In principle, the behavior of base_model's Parameters() has already been rewritten here (call stack: FreezeBaseModel -> GetLoRAParameters -> Module::LoRAParameters()). It would be best to add a NOTE explaining that the behavior differs from the naive Module::Parameters().

// 3. Add LoRA contribution to base output
// Both should now have the same sequence dimension
auto output = base_output->Add(scaled_lora);

Contributor

Compared with ColumnParallelLinear::Forward, this seems to be missing an allgather step: GatherFromTPRegionFunc().
It probably does not affect the correctness of the few existing test cases, because by default a RowParallelLinear follows, in which case gather_output=false.
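
The behavior the review describes can be simulated standalone (plain vectors standing in for tensors; GatherFromTPRegion below is an illustrative stand-in for the PR's GatherFromTPRegionFunc, not its real signature): each TP rank owns a column shard of the output, and with gather_output=true the layer must concatenate per-rank partitions before returning.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Standalone sketch of the gather_output behavior the review points out.
// Each "rank" owns a partition of the output features; GatherFromTPRegion
// is a stand-in for an allgather that concatenates partitions in rank order.
std::vector<float> GatherFromTPRegion(const std::vector<std::vector<float>> &partitions) {
    std::vector<float> full;
    for (const auto &p : partitions) {
        full.insert(full.end(), p.begin(), p.end());
    }
    return full;
}

// Forward of a column-parallel layer on one simulated rank: returns either
// the local partition (gather_output=false, a RowParallelLinear follows) or
// the gathered full output (gather_output=true).
std::vector<float> ColumnParallelForward(const std::vector<std::vector<float>> &all_partitions,
                                         std::size_t my_rank, bool gather_output) {
    if (!gather_output) {
        return all_partitions[my_rank];  // downstream RowParallelLinear consumes the shard
    }
    return GatherFromTPRegion(all_partitions);  // allgather across TP ranks
}
```

Skipping the gather only stays correct while every consumer expects the sharded layout, which is exactly the gather_output=false case the reviewer notes.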

// Freeze base model parameters
FreezeBaseModel(base_model_);

LOG(INFO) << "LoRAModel created with rank=" << config_.rank << ", alpha=" << config_.alpha;
Contributor

Shouldn't base_model_ be registered in modules_ (DDP does this with modules_[kModuleName] = std::move(module);)? Otherwise To(dtype) or To(device) may misbehave. As it stands, it seems base_model must complete To(dtype)/To(device) before being wrapped in LoRAModel; wrapping first and calling To() afterwards would leave the frozen params unaffected.
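
The failure mode the review describes can be shown with a minimal stand-in module (names here are illustrative, not the InfiniTrain Module API): a To()-style traversal can only reach submodules that are registered in modules_, so an unregistered base model's frozen parameters never move.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Minimal sketch of why base_model_ should be registered in modules_:
// To() recurses only over registered submodules. FakeModule is a
// hypothetical stand-in, not the InfiniTrain Module class.
struct FakeModule {
    std::vector<std::string> param_devices{"cpu"};                // stand-in for parameter state
    std::map<std::string, std::shared_ptr<FakeModule>> modules_;  // registered submodules

    void To(const std::string &device) {
        for (auto &d : param_devices) {
            d = device;
        }
        for (auto &[name, child] : modules_) {
            child->To(device);  // only registered children are reached
        }
    }
};
```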

// LoRA A: [rank, in_features] - replicated across TP ranks (implemented as Linear)
// LoRA B: [out_features_per_partition, rank] - sharded like base weight (implemented as ColumnParallelLinear with
// gather_output)
class LoRAColumnParallelLinear : public nn::CloneableModule<LoRAColumnParallelLinear> {
Contributor

It feels like this could inherit from the original ColumnParallelLinear, which would save re-declaring the base-class member definitions and getters.
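
The suggested structure might look roughly like this (heavily simplified, hypothetical class shapes; the real classes carry weights, forward logic, and TP state):

```cpp
#include <cassert>
#include <cstddef>

// Sketch of the review's suggestion: derive from the existing parallel
// layer so shared members and getters are inherited rather than
// re-declared. Both classes are illustrative stand-ins.
class ColumnParallelLinearSketch {
public:
    ColumnParallelLinearSketch(std::size_t in, std::size_t out_per_partition)
        : in_features_(in), out_features_per_partition_(out_per_partition) {}
    std::size_t in_features() const { return in_features_; }
    std::size_t out_features_per_partition() const { return out_features_per_partition_; }

protected:
    std::size_t in_features_;
    std::size_t out_features_per_partition_;
};

// Only the LoRA-specific state is added; the base getters come for free.
class LoRAColumnParallelLinearSketch : public ColumnParallelLinearSketch {
public:
    LoRAColumnParallelLinearSketch(std::size_t in, std::size_t out_per_partition, std::size_t rank)
        : ColumnParallelLinearSketch(in, out_per_partition), rank_(rank) {}
    std::size_t rank() const { return rank_; }

private:
    std::size_t rank_;
};
```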

// Weight shape: [out_features, in_features_per_partition]
// LoRA A: [rank, in_features_per_partition] - sharded like base weight (implemented as RowParallelLinear with
// input_is_parallel)
// LoRA B: [out_features, rank] - replicated (implemented as Linear)
class LoRARowParallelLinear : public nn::CloneableModule<LoRARowParallelLinear> {
Contributor

Same here.

continue;
}

if (type == Linear::kType) {
Contributor

From this point on, this file has quite a few of these three-way if checks where the code differs only in the class name; a more elegant approach would be worth adopting.
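
One way the repeated branches could be collapsed is a type-name to factory lookup table. A self-contained sketch (WrappedModule and the registry entries are illustrative stand-ins, not the PR's actual wrapper types):

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>

// Sketch of replacing three near-identical `if (type == X::kType)` branches
// with a type-name -> factory table. Names are illustrative stand-ins for
// the PR's Linear / ColumnParallelLinear / RowParallelLinear wrappers.
struct WrappedModule {
    std::string kind;
};

using WrapperFactory = std::function<std::unique_ptr<WrappedModule>()>;

const std::map<std::string, WrapperFactory> &WrapperRegistry() {
    static const std::map<std::string, WrapperFactory> registry = {
        {"Linear",
         [] { return std::make_unique<WrappedModule>(WrappedModule{"LoRALinear"}); }},
        {"ColumnParallelLinear",
         [] { return std::make_unique<WrappedModule>(WrappedModule{"LoRAColumnParallelLinear"}); }},
        {"RowParallelLinear",
         [] { return std::make_unique<WrappedModule>(WrappedModule{"LoRARowParallelLinear"}); }},
    };
    return registry;
}

// A single lookup replaces the chained ifs; unknown types return nullptr.
std::unique_ptr<WrappedModule> WrapWithLoRA(const std::string &type) {
    auto it = WrapperRegistry().find(type);
    return it == WrapperRegistry().end() ? nullptr : it->second();
}
```

Adding a fourth layer type then means adding one registry entry instead of a fourth if branch.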
