Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions examples/v1/config/rl_qwen3p5_vl_35B_grpo_mixdata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import json
import os

from transformers import AutoTokenizer

from xtuner.v1.config import (
    AdamWConfig,
    FSDPConfig,
    LRConfig,
)
from xtuner.v1.data_proto.rl_data import SampleParams
from xtuner.v1.datasets import (
    DataloaderConfig,
    DatasetConfig,
    Qwen3VLTokenizeFnConfig,
    RLTokenizeFnConfig,
)
from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config
from xtuner.v1.ray.base import AcceleratorResourcesConfig
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
from xtuner.v1.ray.judger.controller import JudgerConfig
from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig
from xtuner.v1.rl.base import WorkerConfig
from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling
from xtuner.v1.rl.grpo import GRPOLossConfig
from xtuner.v1.train.rl_trainer import RLTrainerConfig
from xtuner.v1.utils.rl_test_utils import get_eos_token

work_dir = os.environ["WORK_DIR"]
model_path = os.environ["MODEL_PATH"]
meta_data_path = os.environ["DATA_PATH"]

# basic settings
experimental_name = "grpo_mix_data"
total_epochs = 15
global_batch_size = 256
prompt_repeat_k = 8
rollout_tp_size = 2
rollout_ep_size = 1
max_prompt_length = 2048
max_response_length = 8192
pack_max_length = 32768
train_optimizer_steps = 8
hf_interval = 15

# 1. resources
resources = AcceleratorResourcesConfig(
accelerator="GPU",
num_workers=8,
num_cpus_per_worker=12,
cpu_memory_per_worker=16 * 1024**3, # 16 GB
)

# 2. rollout
rollout_config = RolloutConfig(
fp32_lm_head=True,
env=experimental_name,
device=resources.accelerator,
model_path=model_path,
dtype="bfloat16",
tensor_parallel_size=rollout_tp_size,
expert_parallel_size=rollout_ep_size,
gpu_memory_utilization=0.8,
Comment on lines +55 to +56
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit: Extraneous spaces around = in keyword arguments (PEP 8):

Suggested change
expert_parallel_size=rollout_ep_size,
gpu_memory_utilization=0.8,
context_length=max_response_length + max_prompt_length,
enable_return_routed_experts=True,

context_length = max_response_length + max_prompt_length,
enable_return_routed_experts=True,
rollout_max_batch_size_per_instance=512,
)

# sampling params
training_sample_params = SampleParams(
max_tokens=max_response_length,
top_k=0,
top_p=1.0,
temperature=1.0,
min_tokens=0,
)


tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit: open() without a context manager — the file handle is never closed. Consider:

Suggested change
ds_collections = json.loads(Path(meta_data_path).read_text())

(Per CLAUDE.md, prefer pathlib.Path over os.path for filesystem operations.)

ds_collections = json.loads(open(meta_data_path).read())
train_dataset_cfg = []
for name, _data in ds_collections.items():
tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(processor_path=model_path,
max_length=max_prompt_length,
system_message=_data.get('system_message', None),
template_name="qwen3.5-vl-rl")
_data_cfg = {"dataset": DatasetConfig(name=name,
anno_path=_data['annotation'],
media_root=_data.get('media_root', ''),
sample_ratio=_data.get('sample_ratio', 1.0),
class_name='VLMJsonlDataset'),
"tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length,
tokenize_fn_cfg=tokenize_fn_cfg),
}
train_dataset_cfg.append(_data_cfg)

dataloader_config = DataloaderConfig(num_workers=8,
collator="fake_collator",
pack_level="none")

# 3. judger
from xtuner.v1.utils.rl_test_utils import get_eos_token
eos_token_id = get_eos_token(model_path)
eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id)
Comment on lines +97 to +98
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit: Mid-file import. Move from xtuner.v1.utils.rl_test_utils import get_eos_token to the top-level imports block for consistency.

dapomath_judger_config = DapoMathJudgerConfig(
judger_name="dapo_math",
eos_token=eos_token_str,
enable_overlong_buffer = True,
max_response_len =max_response_length,
overlong_buffer_len=4096,
overlong_penalty_factor=1.0,
tokenizer=tokenizer)
judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config])

# 4. dataflow and evaluator
dataflow_config = DataFlowConfig(
env=experimental_name,
prompt_repeat_k=prompt_repeat_k,
global_batch_size=global_batch_size,
sample_params=training_sample_params,
# max_concurrent=64, # optional, will be determined automatically if not set
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit: Chinese comment (不需要修改). Use English for consistency across the codebase.

)


# replay buffer config: : 不需要修改
replay_buffer_cfg = ReplayBufferConfig(
dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer
)

# 5. Train worker
# NOTE: modify model_cfg
model_cfg = Qwen3_5_VLMoE35BA3Config(freeze_vision=True, freeze_projector=True)
optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False)
loss_cfg = GRPOLossConfig(
policy_loss_cfg=dict(
cliprange_high=0.28,
cliprange_low=0.2,
loss_type="vanilla",
clip_ratio_c=10.0,
log_prob_diff_min=-20.0,
log_prob_diff_max=20.0,
),
ignore_idx=-100,
use_kl_loss=False,
kl_loss_coef=0.0,
kl_loss_type="low_var_kl",
mode="chunk",
chunk_size=512,
rollout_is=RolloutImportanceSampling(
rollout_is_level="token",
rollout_is_mode="both",
rollout_is_threshold=(5, 0.5),
rollout_is_mask_threshold=(5, 0.5),
rollout_is_veto_threshold=(20, 0),
),
)
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1, fp32_lm_head=True)
train_worker_cfg: WorkerConfig = WorkerConfig(
model_cfg=model_cfg,
load_from=model_path,
optim_cfg=optim_cfg,
loss_cfg=loss_cfg,
lr_cfg=lr_cfg,
fsdp_cfg=fsdp_cfg,
sp_size=1,
optimizer_steps=train_optimizer_steps,
pack_max_length=pack_max_length,
)

# 6. RL Trainer
trainer = RLTrainerConfig(
load_from=model_path,
resources=resources,
rollout_config=rollout_config,
dataflow_config=dataflow_config,
judger_config=judger_cfg,
replay_buffer_config=replay_buffer_cfg,
train_worker_config=train_worker_cfg,
tokenizer_path=model_path,
work_dir=work_dir,
total_epochs=total_epochs,
hf_interval=hf_interval,
)
16 changes: 16 additions & 0 deletions xtuner/v1/data_proto/templates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@
image_context_token="<|image_pad|>",
video_context_token="<|video_pad|>",
),
"qwen3.5-vl-rl": HybridChatTemplate(
system="<|im_start|>system\n{system}<|im_end|>\n",
tool_prompt="\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
""""arguments": <args-json-object>}}\n</tool_call>""", # TODO: fix tool call
tool_extractor="<|im_start|>user\n<tool_response>\n{tool_extractor}\n</tool_response><|im_end|>\n<|im_start|>assistant\n",
user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n<think>\n", # only add <think>\n to the end
stop_words=["<|im_end|>", "<|endoftext|>"],
assistant="{assistant}<|im_end|>",
image_start_token="<|vision_start|>",
image_end_token="<|vision_end|>",
image_context_token="<|image_pad|>",
video_context_token="<|video_pad|>",
),
"llama3": HybridChatTemplate(
system="<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>",
user=(
Expand Down
6 changes: 5 additions & 1 deletion xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_tokenize_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __init__(
hash: str | None = None,
add_eos_token: bool = True, # for mllm pretrain
add_bos_token: bool = False, # for mllm pretrain
template_name: str = "qwen3-vl",
):
self.oss_loader = None
self.debug = debug
Expand Down Expand Up @@ -294,7 +295,7 @@ def __init__(
f"rand_video_max_frames: {self.rand_video_max_frames}"
)

self.chat_template = CHAT_TEMPLATE_MAP["qwen3-vl"]
self.chat_template = copy.deepcopy(CHAT_TEMPLATE_MAP[template_name])
if system_message is not None:
self.chat_template.default_system = system_message

Expand Down Expand Up @@ -904,13 +905,16 @@ class Qwen3VLTokenizeFnConfig(BaseMLLMTokenizeFnConfig):
# it's helpful to add labels to the images and videos for better reference.
add_vision_id: bool = True

template_name: str = "qwen3_vl"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Critical: Default value mismatch — the default here is "qwen3_vl" (underscore) but the CHAT_TEMPLATE_MAP keys use hyphens: "qwen3-vl" and "qwen3.5-vl-rl". This will cause a KeyError at runtime for anyone using the default value.

Suggested change
template_name: str = "qwen3_vl"
template_name: str = "qwen3-vl"


def build(
self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs
) -> Qwen3VLTokenizeFunction:
return Qwen3VLTokenizeFunction(
tokenizer,
self.processor_path,
anno_name,
template_name=self.template_name,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
oss_loader_cfg=self.oss_loader_cfg,
Expand Down
3 changes: 2 additions & 1 deletion xtuner/v1/ray/dataflow/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ def sample(self, env: str, prompt_repeat_k: int) -> List[RLDataFlowItem]:
multimodal_train_info = data.pop("multimodal_train_info", {})
if "pixel_values" in multimodal_train_info:
multimodal_train_info["pixel_values"] = ray.put(multimodal_train_info["pixel_values"])
data["multimodal_train_info"] = multimodal_train_info
# If it is a mixture of pure text and image data, there will be position_id but no pixel_values
data["multimodal_train_info"] = multimodal_train_info

for data_item in group_data_item:
data_item.uid = RLUIDItem(
Expand Down
Loading