-
Notifications
You must be signed in to change notification settings - Fork 413
add qwen35 rl config and fix mix bug #1640
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,178 @@ | ||||||
| import os | ||||||
| from transformers import AutoTokenizer | ||||||
| from xtuner.v1.config import ( | ||||||
| AdamWConfig, | ||||||
| FSDPConfig, | ||||||
| LRConfig, | ||||||
| ) | ||||||
| import json | ||||||
| from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig | ||||||
| from xtuner.v1.data_proto.rl_data import SampleParams | ||||||
| from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config | ||||||
| from xtuner.v1.ray.base import AcceleratorResourcesConfig | ||||||
| from xtuner.v1.ray.config.worker import RolloutConfig | ||||||
| from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig | ||||||
| from xtuner.v1.ray.judger.controller import JudgerConfig | ||||||
| from xtuner.v1.rl.base import WorkerConfig | ||||||
| from xtuner.v1.rl.grpo import GRPOLossConfig | ||||||
| from xtuner.v1.train.rl_trainer import RLTrainerConfig | ||||||
| from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, Qwen3VLTokenizeFnConfig, DataloaderConfig | ||||||
| from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling | ||||||
|
|
||||||
| work_dir = os.environ["WORK_DIR"] | ||||||
| model_path = os.environ["MODEL_PATH"] | ||||||
| meta_data_path = os.environ["DATA_PATH"] | ||||||
|
|
||||||
| # basic settings | ||||||
| experimental_name = "grpo_mix_data" | ||||||
| total_epochs = 15 | ||||||
| global_batch_size = 256 | ||||||
| prompt_repeat_k = 8 | ||||||
| rollout_tp_size = 2 | ||||||
| rollout_ep_size = 1 | ||||||
| max_prompt_length = 2048 | ||||||
| max_response_length = 8192 | ||||||
| pack_max_length = 32768 | ||||||
| train_optimizer_steps = 8 | ||||||
| hf_interval = 15 | ||||||
|
|
||||||
| # 1. resources | ||||||
| resources = AcceleratorResourcesConfig( | ||||||
| accelerator="GPU", | ||||||
| num_workers=8, | ||||||
| num_cpus_per_worker=12, | ||||||
| cpu_memory_per_worker=16 * 1024**3, # 16 GB | ||||||
| ) | ||||||
|
|
||||||
| # 2. rollout | ||||||
| rollout_config = RolloutConfig( | ||||||
| fp32_lm_head=True, | ||||||
| env=experimental_name, | ||||||
| device=resources.accelerator, | ||||||
| model_path=model_path, | ||||||
| dtype="bfloat16", | ||||||
| tensor_parallel_size=rollout_tp_size, | ||||||
| expert_parallel_size=rollout_ep_size, | ||||||
| gpu_memory_utilization=0.8, | ||||||
| context_length = max_response_length + max_prompt_length, | ||||||
| enable_return_routed_experts=True, | ||||||
| rollout_max_batch_size_per_instance=512, | ||||||
| ) | ||||||
|
|
||||||
| # sampling params | ||||||
| training_sample_params = SampleParams( | ||||||
| max_tokens=max_response_length, | ||||||
| top_k=0, | ||||||
| top_p=1.0, | ||||||
| temperature=1.0, | ||||||
| min_tokens=0, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) | ||||||
|
|
||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude: Nit:
Suggested change
(Per CLAUDE.md, prefer |
||||||
| ds_collections = json.loads(open(meta_data_path).read()) | ||||||
| train_dataset_cfg = [] | ||||||
| for name, _data in ds_collections.items(): | ||||||
| tokenize_fn_cfg = Qwen3VLTokenizeFnConfig(processor_path=model_path, | ||||||
| max_length=max_prompt_length, | ||||||
| system_message=_data.get('system_message', None), | ||||||
| template_name="qwen3.5-vl-rl") | ||||||
| _data_cfg = {"dataset": DatasetConfig(name=name, | ||||||
| anno_path=_data['annotation'], | ||||||
| media_root=_data.get('media_root', ''), | ||||||
| sample_ratio=_data.get('sample_ratio', 1.0), | ||||||
| class_name='VLMJsonlDataset'), | ||||||
| "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length, | ||||||
| tokenize_fn_cfg=tokenize_fn_cfg), | ||||||
| } | ||||||
| train_dataset_cfg.append(_data_cfg) | ||||||
|
|
||||||
| dataloader_config = DataloaderConfig(num_workers=8, | ||||||
| collator="fake_collator", | ||||||
| pack_level="none") | ||||||
|
|
||||||
| # 3. judger | ||||||
| from xtuner.v1.utils.rl_test_utils import get_eos_token | ||||||
| eos_token_id = get_eos_token(model_path) | ||||||
| eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id) | ||||||
|
Comment on lines
+97
to
+98
Contributor
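There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Nit: Mid-file import. Move `from xtuner.v1.utils.rl_test_utils import get_eos_token` to the top of the file, alongside the other `xtuner` imports.
|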
||||||
| dapomath_judger_config = DapoMathJudgerConfig( | ||||||
| judger_name="dapo_math", | ||||||
| eos_token=eos_token_str, | ||||||
| enable_overlong_buffer = True, | ||||||
| max_response_len =max_response_length, | ||||||
| overlong_buffer_len=4096, | ||||||
| overlong_penalty_factor=1.0, | ||||||
| tokenizer=tokenizer) | ||||||
| judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config]) | ||||||
|
|
||||||
| # 4. dataflow and evaluator | ||||||
| dataflow_config = DataFlowConfig( | ||||||
| env=experimental_name, | ||||||
| prompt_repeat_k=prompt_repeat_k, | ||||||
| global_batch_size=global_batch_size, | ||||||
| sample_params=training_sample_params, | ||||||
| # max_concurrent=64, # optional, will be determined automatically if not set | ||||||
|
Contributor
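There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Nit: Chinese comment (`不需要修改`, i.e. "no need to modify") on the replay buffer config, which also contains a doubled colon (`config: :`). Please write the comment in English.
|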
||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| # replay buffer config: : 不需要修改 | ||||||
| replay_buffer_cfg = ReplayBufferConfig( | ||||||
| dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer | ||||||
| ) | ||||||
|
|
||||||
| # 5. Train worker | ||||||
| # NOTE: modify model_cfg | ||||||
| model_cfg = Qwen3_5_VLMoE35BA3Config(freeze_vision=True, freeze_projector=True) | ||||||
| optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) | ||||||
| loss_cfg = GRPOLossConfig( | ||||||
| policy_loss_cfg=dict( | ||||||
| cliprange_high=0.28, | ||||||
| cliprange_low=0.2, | ||||||
| loss_type="vanilla", | ||||||
| clip_ratio_c=10.0, | ||||||
| log_prob_diff_min=-20.0, | ||||||
| log_prob_diff_max=20.0, | ||||||
| ), | ||||||
| ignore_idx=-100, | ||||||
| use_kl_loss=False, | ||||||
| kl_loss_coef=0.0, | ||||||
| kl_loss_type="low_var_kl", | ||||||
| mode="chunk", | ||||||
| chunk_size=512, | ||||||
| rollout_is=RolloutImportanceSampling( | ||||||
| rollout_is_level="token", | ||||||
| rollout_is_mode="both", | ||||||
| rollout_is_threshold=(5, 0.5), | ||||||
| rollout_is_mask_threshold=(5, 0.5), | ||||||
| rollout_is_veto_threshold=(20, 0), | ||||||
| ), | ||||||
| ) | ||||||
| lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) | ||||||
| fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1, fp32_lm_head=True) | ||||||
| train_worker_cfg: WorkerConfig = WorkerConfig( | ||||||
| model_cfg=model_cfg, | ||||||
| load_from=model_path, | ||||||
| optim_cfg=optim_cfg, | ||||||
| loss_cfg=loss_cfg, | ||||||
| lr_cfg=lr_cfg, | ||||||
| fsdp_cfg=fsdp_cfg, | ||||||
| sp_size=1, | ||||||
| optimizer_steps=train_optimizer_steps, | ||||||
| pack_max_length=pack_max_length, | ||||||
| ) | ||||||
|
|
||||||
| # 6. RL Trainer | ||||||
| trainer = RLTrainerConfig( | ||||||
| load_from=model_path, | ||||||
| resources=resources, | ||||||
| rollout_config=rollout_config, | ||||||
| dataflow_config=dataflow_config, | ||||||
| judger_config=judger_cfg, | ||||||
| replay_buffer_config=replay_buffer_cfg, | ||||||
| train_worker_config=train_worker_cfg, | ||||||
| tokenizer_path=model_path, | ||||||
| work_dir=work_dir, | ||||||
| total_epochs=total_epochs, | ||||||
| hf_interval=hf_interval, | ||||||
| ) | ||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -233,6 +233,7 @@ def __init__( | |||||
| hash: str | None = None, | ||||||
| add_eos_token: bool = True, # for mllm pretrain | ||||||
| add_bos_token: bool = False, # for mllm pretrain | ||||||
| template_name: str = "qwen3-vl", | ||||||
| ): | ||||||
| self.oss_loader = None | ||||||
| self.debug = debug | ||||||
|
|
@@ -294,7 +295,7 @@ def __init__( | |||||
| f"rand_video_max_frames: {self.rand_video_max_frames}" | ||||||
| ) | ||||||
|
|
||||||
| self.chat_template = CHAT_TEMPLATE_MAP["qwen3-vl"] | ||||||
| self.chat_template = copy.deepcopy(CHAT_TEMPLATE_MAP[template_name]) | ||||||
| if system_message is not None: | ||||||
| self.chat_template.default_system = system_message | ||||||
|
|
||||||
|
|
@@ -904,13 +905,16 @@ class Qwen3VLTokenizeFnConfig(BaseMLLMTokenizeFnConfig): | |||||
| # it's helpful to add labels to the images and videos for better reference. | ||||||
| add_vision_id: bool = True | ||||||
|
|
||||||
| template_name: str = "qwen3_vl" | ||||||
|
Contributor
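There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Critical: Default value mismatch — the default here is `"qwen3_vl"` (underscore), but the `template_name` default in `Qwen3VLTokenizeFunction.__init__` is `"qwen3-vl"` (hyphen), which is also the key that was previously hard-coded as `CHAT_TEMPLATE_MAP["qwen3-vl"]`. Because `build()` always forwards `self.template_name`, a config created without an explicit `template_name` would look up `CHAT_TEMPLATE_MAP["qwen3_vl"]`, which presumably does not exist and would fail with a `KeyError`.
Suggested change: `template_name: str = "qwen3-vl"`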
|
||||||
|
|
||||||
| def build( | ||||||
| self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs | ||||||
| ) -> Qwen3VLTokenizeFunction: | ||||||
| return Qwen3VLTokenizeFunction( | ||||||
| tokenizer, | ||||||
| self.processor_path, | ||||||
| anno_name, | ||||||
| template_name=self.template_name, | ||||||
| min_pixels=self.min_pixels, | ||||||
| max_pixels=self.max_pixels, | ||||||
| oss_loader_cfg=self.oss_loader_cfg, | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
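Claude: Nit: Extraneous spaces around `=` in keyword arguments (PEP 8): for example `context_length = max_response_length + max_prompt_length`, `enable_overlong_buffer = True` and `max_response_len =max_response_length` should be written without spaces around `=`.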