support deterministic RL training with sglang #1644
New file: the RL launch script (examples/v1/scripts/run_rl.sh, per its own usage comments):

@@ -0,0 +1,118 @@
set -ex
Contributor (Claude): Missing …
ray stop --force
# examples of usage:
# qwen3_8B_grpo_gsm8k training:
# bash examples/v1/scripts/run_rl.sh examples/v1/config/rl_qwen3_8B_grpo.py "sglang" $MODEL_PATH $DATA_PATH $EVAL_DATA_PATH
# qwen2.5_7B_dapo_math training:
# bash examples/v1/scripts/run_rl.sh examples/v1/config/rl_qwen25_7B_dapo.py "sglang" $MODEL_PATH $DATA_PATH $EVAL_DATA_PATH
Comment on lines +3 to +7
Contributor (Claude): The usage comments reference …
CONFIG_PATH=$1
MODEL_PATH=$2
DATA_PATH=$3
EVAL_DATA_PATH=${4:-""}

export PYTHONPATH=$(pwd):$PYTHONPATH
# deterministic environment variables
# NOTE: you should use sglang==0.5.5 to reproduce our deterministic results.
export XTUNER_USE_SGLANG=1
export XTUNER_USE_LMDEPLOY=0
export XTUNER_USE_VLLM=0
export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0
export XTUNER_DETERMINISTIC=true
export XTUNER_USE_FA3=0
# sglang environment variables
unset PYTORCH_CUDA_ALLOC_CONF
export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
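For reference, boolean env flags like `XTUNER_DETERMINISTIC=true` above are typically parsed on the Python side with a small helper. A minimal sketch (the `env_flag` helper is illustrative, not xtuner's actual API):

```python
import os

def env_flag(name: str, default: str = "0") -> bool:
    """Interpret common truthy spellings of an environment variable."""
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on")

os.environ["XTUNER_DETERMINISTIC"] = "true"  # mirrors the export above
print(env_flag("XTUNER_DETERMINISTIC"))      # True
```

Parsing once at import time (as `xtuner.v1.utils.XTUNER_DETERMINISTIC` appears to do later in this diff) keeps the flag consistent across the whole process.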
# ray environment variables
export MASTER_PORT=6000
export WORLD_SIZE=${NODE_COUNT:-"1"}
export RANK=${NODE_RANK:-"0"}
export RAY_MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
export RAY_RANK=${RANK:-0} # 0 means head node, >0 means worker node
export RAY_HEAD_PORT=${RAY_HEAD_PORT:-"6379"}
export RAY_CLIENT_PORT=${RAY_CLIENT_PORT:-"10001"}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-"8265"}
# TODO: provide a way to configure ray_max_concurrency other than via environment variables
export RAY_MAX_CONCURRENCY=${RAY_MAX_CONCURRENCY:-1024} # dataflow_max_concurrency * prompt_repeat_k
# xtuner environment variables
export MODEL_PATH=$MODEL_PATH
export DATA_PATH=$DATA_PATH
export EVAL_DATA_PATH=$EVAL_DATA_PATH
export XTUNER_LOG_LEVEL=${XTUNER_LOG_LEVEL:-"INFO"}
export PYTHONUNBUFFERED=1

current_time=$(date "+%m%d%H")
# use the last component of the model path as model_name and the second-to-last component of the data path as data_name
model_dir_name=$(basename "$MODEL_PATH")
data_dir_name=$(basename "$(dirname "$DATA_PATH")")
if [ "x$WORK_DIR" = "x" ]; then
    DIR=$(pwd)
    export WORK_DIR="${DIR}/work_dirs/${model_dir_name}_${data_dir_name}_${infer_backend_lower}"
else
    export WORK_DIR=$WORK_DIR
fi
echo "WORK_DIR: $WORK_DIR"
if [ ! -d "$WORK_DIR" ]; then
    mkdir -p "$WORK_DIR"
fi

export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt"
export XTUNER_RL_MEM_DIR="${WORK_DIR}/mem_${current_time}"
# 2. Launch Ray cluster
# allocate num_cpus according to NODE_COUNT to prevent OOM
node_count=${NODE_COUNT:-1}
total_cpus=$((node_count * 128))
WORK_DIR=$(realpath "$WORK_DIR")
if [ "$RAY_RANK" -eq 0 ]; then
    rm -rf /tmp/ray_log
    export RAY_LOG_DIR="${WORK_DIR}/ray_${current_time}/"
    mkdir -p ${RAY_LOG_DIR}
    ln -sfn "${RAY_LOG_DIR}" /tmp/ray_log
    ray start --head \
        --node-ip-address="$RAY_MASTER_ADDR" \
        --port="$RAY_HEAD_PORT" \
        --dashboard-host=0.0.0.0 \
        --dashboard-port=$RAY_DASHBOARD_PORT \
        --include-dashboard=true \
        --disable-usage-stats \
        --num-cpus=$total_cpus \
        --temp-dir="/tmp/ray_log/"
else
    while true; do
        if curl --connect-timeout 2 "http://${RAY_MASTER_ADDR}:${RAY_DASHBOARD_PORT}" >/dev/null 2>&1; then
            echo "Successfully connected to Ray master at ${RAY_MASTER_ADDR}:${RAY_DASHBOARD_PORT}"
            break
        else
            echo "Waiting for Ray master at ${RAY_MASTER_ADDR}:${RAY_DASHBOARD_PORT} to be available..."
            sleep 2
        fi
    done
    ray start --address="$RAY_MASTER_ADDR:$RAY_HEAD_PORT" --block --disable-usage-stats
fi
while true; do
    result=$(ray status | grep ${ACCELERATOR} | cut -d ' ' -f2 | cut -d '/' -f2)
    expected_accelerator_count=$((node_count * 8))
    if [ "$result" = "$expected_accelerator_count.0" ]; then
        break
    else
        echo "Waiting for ${ACCELERATOR} count to be $expected_accelerator_count, current: $result"
        sleep 2
    fi
done
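The `ray status | grep | cut` pipeline above extracts the total accelerator count from a resource line such as `0.0/8.0 GPU` (used/total). A Python equivalent of that field surgery, assuming this line format, which is what recent `ray status` output prints in its Resources section:

```python
def total_accelerators(status_line: str) -> str:
    """Mimic `cut -d ' ' -f2 | cut -d '/' -f2` on a ray status resource line."""
    # cut -f2 keeps the second space-delimited field, e.g. "0.0/8.0" from " 0.0/8.0 GPU"
    used_over_total = status_line.split(" ")[1]
    # cut -d '/' -f2 keeps everything after the slash: the total
    return used_over_total.split("/")[1]

print(total_accelerators(" 0.0/8.0 GPU"))  # 8.0
```

This also explains the `"$expected_accelerator_count.0"` comparison in the script: `ray status` reports resource totals as floats, so an 8-GPU node shows up as `8.0`.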
SCRIPT_NAME=$(basename "$0")
cp "$0" "${WORK_DIR}/${SCRIPT_NAME}"
cp "$CONFIG_PATH" "${WORK_DIR}/config.py"
LOG_FILE="${WORK_DIR}/training_log_${current_time}.txt"

python xtuner/v1/train/cli/rl.py \
    --config $CONFIG_PATH \
    2>&1 | tee -a "${WORK_DIR}/training_log_${current_time}.txt"
Next file: sampling parameter config.

@@ -429,6 +429,7 @@ class SampleParams(BaseModel):
     stops: Annotated[list[str], Parameter(help="List of stop sequences.")] = []
     stop_token_ids: Annotated[list[int], Parameter(help="List of stop token IDs.")] = []
     skip_special_tokens: Annotated[bool, Parameter(help="Whether to skip special tokens.")] = True
+    sampling_seed: Annotated[int, Parameter(help="The seed for random number generator in sampling.")] = 0
Contributor (Claude): A default …
 class RolloutExtraParams(TypedDict):
Next file: single-turn environment.

@@ -16,7 +16,7 @@
     update_rollout_item,
 )
 from xtuner.v1.ray.environment.base_env import BaseEnvironment
-from xtuner.v1.utils import get_logger, ray_method
+from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, ray_method
 class RawSingleTurnEnvironment(BaseEnvironment):

@@ -65,6 +65,7 @@ def __init__(
         # This should be longer than the controller's internal timeout (`rollout_timeout`)
         # to account for potential queuing delays and other overheads.
         self.timeout_multiplier = 2.0
+        self.rollout_cfg = rollout_cfg

     async def generate(  # type: ignore[override]
         self, group_data_items: List[RLDataFlowItem], sample_params=None, extra_params=None

@@ -93,10 +94,12 @@ async def generate(  # type: ignore[override]
         extra_params = {}
         if self.rollout_controller:
             response_future = []
-            for sample in group_data_items:
+            for i, sample in enumerate(group_data_items):
Contributor (Claude): Consider guarding the deepcopy. Suggested change: …
             sample.data.extra_info["root_id"] = sample.uid.root_id
             sample.data.extra_info["action_id"] = sample.uid.action_id
-            update_sample_params = sample_params
+            update_sample_params = copy.deepcopy(sample_params)
+            if XTUNER_DETERMINISTIC:
+                update_sample_params.sampling_seed = self.rollout_cfg.random_seed + i
             if "partial_rollout_input_ids" in sample.env.rollout.extra_info:
                 input_ids_length = len(sample.data.input_ids) if sample.data.input_ids is not None else 0
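The seeding logic in the hunk above amounts to: deep-copy the shared `sample_params` once per rollout and offset `sampling_seed` by the sample's index, so each repeated rollout in a group draws from a distinct but reproducible random stream. A standalone sketch (the dataclass is a stand-in for xtuner's `SampleParams`, and `assign_seeds` is an illustrative helper, not PR code):

```python
import copy
from dataclasses import dataclass

@dataclass
class SampleParams:
    temperature: float = 1.0
    sampling_seed: int = 0

def assign_seeds(base_params: SampleParams, group_size: int, base_seed: int) -> list:
    """One independent params copy per sample, each with its own seed."""
    out = []
    for i in range(group_size):
        p = copy.deepcopy(base_params)  # without the copy, every sample would share one object
        p.sampling_seed = base_seed + i
        out.append(p)
    return out

group = assign_seeds(SampleParams(), group_size=3, base_seed=1234)
print([p.sampling_seed for p in group])  # [1234, 1235, 1236]
```

The deepcopy is the load-bearing part: assigning `update_sample_params = sample_params` directly (the removed line) would mutate the seed on the single shared object, so every sample in the group would end up with the last seed written.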
Next file: SGLang rollout worker.

@@ -6,6 +6,7 @@
 from transformers import AutoTokenizer
 from xtuner.v1.ray.config import RolloutConfig
+from xtuner.v1.utils import XTUNER_DETERMINISTIC

 from .worker import RolloutWorker
@@ -24,7 +25,7 @@ def __init__( | |
| from sglang.srt.entrypoints.http_server import launch_server | ||
|
|
||
| self.server_func = launch_server | ||
| self.endpoints["health_generate"] = "health_generate" | ||
| self.endpoints["health_generate"] = "health" | ||
|
Contributor (Claude): Changing the health endpoint from … Was this change intentional? If so, please document the reason (e.g., SGLang version compatibility).
         self.endpoints["generate"] = "generate"
         self.endpoints["v1/chat/completions"] = "v1/chat/completions"
         self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path, trust_remote_code=True)

@@ -86,6 +87,14 @@ def _make_request(self, endpoint: str, payload=None):
         response.raise_for_status()
         return response.json()
+    def check_health(self) -> bool:
+        try:
+            response = requests.get(f"{self.server_url}/{self.endpoints['health_generate']}", timeout=5.0)
+            return response.status_code == 200
+        except requests.RequestException as e:
+            self.logger.error(f"Health check failed for server {self.server_url}: {e}")
+            return False
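A caller would typically poll `check_health` until the server comes up, much as the launch script polls the Ray dashboard with curl. A sketch of such a wait loop (this `wait_until_healthy` helper is hypothetical, not part of the PR; the `_sleep`/`_clock` hooks exist only to make it testable):

```python
import time

def wait_until_healthy(check_health, timeout: float = 60.0, interval: float = 2.0,
                       _sleep=time.sleep, _clock=time.monotonic) -> bool:
    """Poll a health-check callable until it returns True or the timeout elapses."""
    deadline = _clock() + timeout
    while _clock() < deadline:
        if check_health():
            return True   # server answered 200 on the health endpoint
        _sleep(interval)  # back off before the next probe
    return False          # never became healthy within the budget
```

Usage would be `wait_until_healthy(worker.check_health)` after launching the SGLang server process.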
     def flush_cache(self):
         """Flush the cache of the server."""
         # TODO: support tp
@@ -144,10 +153,8 @@ def _transform_rollout_config_to_server_configs(self): | |
| grammar_backend = sglang_config_kwargs.get( | ||
| "grammar_backend", None | ||
| ) # for intern-s1 series models, have to set the grammar_backend to "none" | ||
| log_level = sglang_config_kwargs.get("log_level", "critical") | ||
| log_level_http = sglang_config_kwargs.get("log_level_http", "critical") | ||
| enable_deterministic_inference = sglang_config_kwargs.get("enable_deterministic_inference", False) | ||
|
|
||
| log_level = sglang_config_kwargs.get("log_level", "error") | ||
| log_level_http = sglang_config_kwargs.get("log_level_http", "error") | ||
| sglang_server_args = ServerArgs(model_path=self.config.model_path, trust_remote_code=True) | ||
| num_gpus_per_engine = ( | ||
| self.config.expert_parallel_size | ||
|
|
@@ -162,7 +169,6 @@ def _transform_rollout_config_to_server_configs(self): | |
| sglang_server_args.gpu_id_step = 1 | ||
| sglang_server_args.nnodes = max(1, num_gpus_per_engine // self.config.gpus_per_node) | ||
| sglang_server_args.skip_server_warmup = True | ||
|
|
||
| sglang_server_args.mem_fraction_static = self.config.gpu_memory_utilization | ||
| # note: 非共卡模式下无需设置,共卡模式下需要offload必须设置,否则显存释放不了 | ||
| sglang_server_args.enable_memory_saver = True | ||
|
|
@@ -173,8 +179,9 @@ def _transform_rollout_config_to_server_configs(self): | |
| sglang_server_args.max_running_requests = self.config.rollout_max_batch_size_per_instance | ||
| sglang_server_args.log_level = log_level | ||
| sglang_server_args.log_level_http = log_level_http | ||
| sglang_server_args.enable_deterministic_inference = enable_deterministic_inference | ||
|
|
||
| if XTUNER_DETERMINISTIC: | ||
| sglang_server_args.enable_deterministic_inference = True | ||
|
Contributor (Claude): Is this targeting a custom/internal fork of SGLang? If so, please add a comment noting the required SGLang version or fork, so users don't hit runtime errors with standard SGLang.
         sglang_server_args.rl_on_policy_target = True
         if self.config.expert_parallel_size > 1:
             sglang_server_args.tp_size = num_gpus_per_engine
             sglang_server_args.ep_size = num_gpus_per_engine
@@ -212,6 +219,8 @@ def _transform_sample_params(self, sample_params: Dict): | |
| "stop_token_ids": sample_params["stop_token_ids"], | ||
| "skip_special_tokens": sample_params["skip_special_tokens"], | ||
| } | ||
| if XTUNER_DETERMINISTIC: | ||
| sglang_sample_params["sampling_seed"] = sample_params["sampling_seed"] | ||
| return sglang_sample_params | ||
|
|
||
| def _transform_extra_params(self, extra_params: Dict): | ||
|
|
||
Contributor (Claude): Typo: `enbale_partial_rollout` should be `enable_partial_rollout`. The env var name `ENBALE_PARTIAL_ROLLOUT` has the same typo. I see this typo is inherited from other config files (`rl_qwen3_8B_grpo.py`, `rl_qwen25_7B_dapo.py`), so it's a pre-existing issue. However, since this is a rename/cleanup of this config file, it would be a good opportunity to fix it here and in the env var name.