5 changes: 2 additions & 3 deletions tests/model/test_qwen3_5.py
@@ -9,7 +9,7 @@
import torch.distributed as dist
from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config
from xtuner.v1.loss.ce_loss import CELossConfig
from xtuner.v1.model.moe.moe import SequenceContext
from xtuner.v1.model.moe.moe import SequenceContext, MTPConfig
from xtuner.v1.utils.test_utils import init_data_mesh
from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig
from xtuner.v1.config import FSDPConfig
@@ -233,6 +233,7 @@ def test_save_hf_with_mtp(self, device, sp_size):

with torch.device("meta"):
model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
model_cfg.text_config.mtp_config = MTPConfig(num_layers=1)
qwen3vl_model = model_cfg.build().to(torch.bfloat16)

fsdp_config = FSDPConfig(cpu_offload=False)
@@ -262,8 +263,6 @@ def test_save_hf_with_mtp(self, device, sp_size):

# Verify all original HF weights are preserved correctly
for key in origin_index["weight_map"].keys():
if "mtp" in key:
continue # TODO: remove this after MTP is implemented
origin_safetensor_name = origin_index["weight_map"][key]
saved_safetensor_name = saved_index["weight_map"][key]

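For reference, enabling MTP on the Qwen3.5 MoE config follows the pattern exercised by `test_save_hf_with_mtp` above; a minimal sketch assembled from that test, assuming `num_layers` is the only `MTPConfig` field that needs to be set:

```python
# Minimal sketch based on test_save_hf_with_mtp above; it assumes
# MTPConfig(num_layers=...) is the only field required to enable MTP.
import torch

from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config
from xtuner.v1.model.moe.moe import MTPConfig

with torch.device("meta"):
    model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
    model_cfg.text_config.mtp_config = MTPConfig(num_layers=1)  # attach one MTP layer
    qwen3vl_model = model_cfg.build().to(torch.bfloat16)
```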
4 changes: 4 additions & 0 deletions tests/train/test_trainer.py
@@ -112,6 +112,7 @@ def create_pg(self, device):

@patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
@prepare
def test_save_hf_interval(self):
"""Test save_hf is called at correct intervals during training."""
@@ -184,6 +185,7 @@ def test_save_hf_interval(self):

@patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
@prepare
def test_save_checkpoint_interval(self):
self.create_pg(DEVICE)
@@ -258,6 +260,7 @@ def test_save_checkpoint_interval(self):

@patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
@prepare
def test_resume(self):
self.create_pg(DEVICE)
@@ -738,6 +741,7 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1


@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
def test_resume_and_load_checkpoint_cfg(tmp_path: Path):
# 0. prepare environment
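The new `@patch` decorators stub out `Trainer._prepare_model_input` so the `FakeEngine` never receives real batches. The stubbing pattern in isolation (hypothetical `Worker`/`prepare_input` names; only `unittest.mock` is real here):

```python
# Standalone illustration of the stubbing pattern used in the tests above.
# Worker and prepare_input are hypothetical; only unittest.mock is real.
from unittest.mock import Mock, patch


class Worker:
    def prepare_input(self):
        raise RuntimeError("expensive data pipeline")

    def step(self):
        return self.prepare_input()


@patch.object(Worker, "prepare_input", Mock(return_value=[]))
def run_step():
    # While the patch is active, step() returns the stubbed empty list
    # instead of running the real data pipeline.
    assert Worker().step() == []


run_step()
```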
3 changes: 3 additions & 0 deletions xtuner/v1/loss/__init__.py
@@ -11,6 +11,7 @@
ZLossContext,
ZLossKwargs,
)
from .mtp_loss import MTPLossContext
from .rl_loss import LogProbConfig, LogProbContext


@@ -29,6 +30,8 @@
"BaseLossConfig",
"BaseLossContext",
"BaseLossKwargs",
"LMHeadLossContext",
"MTPLossContext",
"LogProbConfig",
"LogProbContext",
]
92 changes: 92 additions & 0 deletions xtuner/v1/loss/mtp_loss.py
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch.distributed.device_mesh import DeviceMesh

from xtuner.v1.loss.ce_loss import CELossConfig, CELossKwargs, LMHeadLossContext
from xtuner.v1.module.mtp.utils import roll_packed_tensor
from xtuner.v1.utils.device import get_device


DEVICE = get_device()


class MTPLossKwargs(CELossKwargs):
    """Keyword arguments for MTP loss computation.

    Inherits all fields from CELossKwargs. The ``shifted_labels`` field is
    expected to be pre-rolled by ``MTPLossConfig.build()`` before this object
    is constructed, so no additional fields are required.

    Args:
        shifted_labels (torch.Tensor): The shifted and rolled labels for MTP
            loss computation.
        loss_weight (torch.Tensor | None): Per-token loss weight.
    """


class MTPLossConfig(CELossConfig):
    """Loss configuration for Multi-Token Prediction (MTP).

    Extends ``CELossConfig`` with an ``mtp_depth`` field that controls how many
    additional positions the labels are rolled during ``build()``. This class
    is intended for internal use by the model and is not exposed to users.

    Args:
        mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses
            ``mtp_depth=1`` (shift=-1 on top of the existing label shift).
    """

    mtp_depth: int

    @property
    def loss_ctx_cls(self) -> type["MTPLossContext"]:
        return MTPLossContext

    @property
    def _loss_kwargs_cls(self) -> type["MTPLossKwargs"]:
        return MTPLossKwargs

    def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext | None":
        """Build MTPLossContext from data dict.

        Rolls ``shifted_labels`` by ``-mtp_depth`` positions (per-sequence,
        respecting packed-sequence boundaries) before constructing the loss
        context. The roll is performed on the full sequence prior to any
        sequence-parallel split so that boundary positions and ``cu_seq_lens``
        are always consistent.

        Args:
            data (dict): Data dict containing loss-related fields.
                Required keys: ``shifted_labels``, ``seq_ctx``.
            sp_mesh (DeviceMesh | None): Sequence parallel mesh.

        Returns:
            MTPLossContext | None: Built loss context, or ``None`` if
                ``shifted_labels`` is not present in ``data``.
        """
        if "shifted_labels" not in data:
            return None

        shifted_labels = data["shifted_labels"]
        cu_seq_lens = data["seq_ctx"].cu_seq_lens_k

        rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=-100)

        loss_kwargs = MTPLossKwargs(shifted_labels=rolled).to(DEVICE)
        if sp_mesh is not None and sp_mesh.size() > 1:
            loss_kwargs = loss_kwargs.sp_split(sp_mesh)

        return MTPLossContext(self, loss_kwargs)


class MTPLossContext(LMHeadLossContext):
    """Loss context for Multi-Token Prediction (MTP).

    Inherits all computation logic from ``LMHeadLossContext``. The label
    rolling is handled upstream in ``MTPLossConfig.build()``, so no override
    is needed here.

    Args:
        loss_cfg (MTPLossConfig): The MTP loss configuration.
        loss_kwargs (MTPLossKwargs): Pre-rolled keyword arguments for loss
            computation.
    """
3 changes: 2 additions & 1 deletion xtuner/v1/model/__init__.py
@@ -23,7 +23,7 @@
from .dense.qwen3 import Qwen3Dense0P6BConfig, Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig
from .moe.deepseek_v3 import DeepSeekV3Config
from .moe.gpt_oss import GptOss21BA3P6Config, GptOss117BA5P8Config, GptOssConfig
from .moe.moe import BalancingLossConfig, MoE, MoEModelOutputs, ZLossConfig
from .moe.moe import BalancingLossConfig, MoE, MoEConfig, MoEModelOutputs, ZLossConfig
from .moe.qwen3 import Qwen3MoE30BA3Config, Qwen3MoEConfig, Qwen3MoEFoPEConfig


@@ -87,6 +87,7 @@ def get_model_config_from_hf(model_path: Path):
"get_model_config",
"get_model_config_from_hf",
"MoE",
"MoEConfig",
"MoEModelOutputs",
"BalancingLossConfig",
"ZLossConfig",
3 changes: 3 additions & 0 deletions xtuner/v1/model/base.py
@@ -1521,6 +1521,9 @@ def _load_fused_hf_param(
continue
_loaded_tensor.append(weight.to(local_tensor.device))

if not _loaded_tensor:
return missing_keys

if not hf_keys:
# fp8 pad
assert self.config.float8_cfg is not None
6 changes: 3 additions & 3 deletions xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py
@@ -1,4 +1,4 @@
from xtuner.v1.model.base import TransformerConfig
from xtuner.v1.model.moe.moe import MoEConfig
from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig
from xtuner.v1.utils import get_logger

@@ -19,7 +19,7 @@ class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig):
class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
vision_config: Qwen3_5_VisionConfig
projector_config: Qwen3_5_ProjectorConfig
text_config: TransformerConfig
text_config: MoEConfig

image_token_id: int = 248056
video_token_id: int = 248057
@@ -30,4 +30,4 @@ class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig):
vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig()
projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig()
text_config: TransformerConfig = Qwen3_5_VLTextMoE35BA3BConfig()
text_config: MoEConfig = Qwen3_5_VLTextMoE35BA3BConfig()
24 changes: 16 additions & 8 deletions xtuner/v1/model/dense/qwen3vl_text.py
@@ -1,9 +1,10 @@
import re

import torch
import torch.nn.functional as F

from xtuner.v1.data_proto import SequenceContext
from xtuner.v1.loss import CELossContext
from xtuner.v1.loss import BaseLossContext
from xtuner.v1.model.base import ModelOutputs

from .qwen3 import Qwen3Dense, Qwen3Dense4BConfig, Qwen3Dense8BConfig
@@ -34,10 +35,10 @@ def _deepstack_process(
hidden_states[visual_pos_masks, :] = local_this
return hidden_states

def forward(
def forward( # type: ignore[override]
self,
seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch
loss_ctx: CELossContext,
loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None,
) -> ModelOutputs:
input_ids = seq_ctx.input_ids
position_ids = seq_ctx.position_ids
@@ -78,11 +79,18 @@ def forward(

hidden_states = self.norm(hidden_states)

loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx)
output["loss"] = loss
output["logits"] = logits
output["extra_info"] = extra_info
return ModelOutputs(**output) # type: ignore[typeddict-item]
if loss_ctx is None:
# Inference mode
logits = F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias)
output["logits"] = logits
else:
# Training mode
loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) # type: ignore[call-overload]
output["loss"] = loss
output["logits"] = logits
output["extra_info"] = extra_info

return ModelOutputs(**output)


class Qwen3VLTextDense4BConfig(Qwen3Dense4BConfig):
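With the reworked signature, the same text model serves both code paths; a hedged usage sketch (the surrounding `model`, `seq_ctx`, and `lm_loss_ctx` construction is elided, and the `{"lm": ...}` keying is taken from the diff above):

```python
# Hedged usage sketch; `model`, `seq_ctx` and `lm_loss_ctx` are placeholders
# assumed to be built elsewhere.

# Inference: no loss context, so logits come from a plain F.linear over lm_head.
outputs = model(seq_ctx=seq_ctx, loss_ctx=None)
logits = outputs["logits"]

# Training: the "lm" entry drives the fused lm_head + cross-entropy computation.
outputs = model(seq_ctx=seq_ctx, loss_ctx={"lm": lm_loss_ctx})
loss = outputs["loss"]
```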