@@ -28,6 +28,7 @@
data_path = os.environ["DATA_PATH"]
eval_data_path = os.environ["EVAL_DATA_PATH"]
enable_evaluate = True if eval_data_path != "" else False
enbale_partial_rollout = int(os.environ.get("ENBALE_PARTIAL_ROLLOUT", "0"))
Claude: Typo: enbale_partial_rollout should be enable_partial_rollout. The env var name ENBALE_PARTIAL_ROLLOUT has the same typo.

I see this typo is inherited from other config files (rl_qwen3_8B_grpo.py, rl_qwen25_7B_dapo.py), so it's a pre-existing issue. However, since this is a rename/cleanup of this config file, it would be a good opportunity to fix it here and in the env var name.


# basic settings
experimental_name = "grpo_gsm8k_tiny"
@@ -40,6 +41,8 @@
max_response_length = 1024
pack_max_length = 32768
train_optimizer_steps = 1
hf_interval = 100
enable_initial_evaluate = True
evaluate_step = 15

# 1. resources
@@ -93,13 +96,13 @@
prompt_repeat_k=prompt_repeat_k,
global_batch_size=global_batch_size,
sample_params=training_sample_params,
max_concurrent=512,
enable_partial_rollout=enbale_partial_rollout,
)

evaluator_cfg = (
EvaluatorConfig(
enable_evaluate=enable_evaluate,
enable_initial_evaluate=True,
enable_initial_evaluate=enable_initial_evaluate,
dataset_cfg=eval_dataset_cfg,
tokenizer=tokenizer,
evaluate_step=evaluate_step,
@@ -158,5 +161,6 @@
tokenizer_path=model_path,
work_dir=work_dir,
total_epochs=total_epochs,
hf_interval=hf_interval,
exp_tracker="jsonl",
)
118 changes: 118 additions & 0 deletions examples/v1/scripts/run_rl_deterministic.sh
@@ -0,0 +1,118 @@
set -ex
Claude: Missing #!/bin/bash shebang on line 1. The script starts with set -ex directly. Also, missing a newline at the end of the file (line 117).

ray stop --force
# examples of usage:
# qwen3_8B_grpo_gsm8k training:
# bash examples/v1/scripts/run_rl.sh examples/v1/config/rl_qwen3_8B_grpo.py "sglang" $MODEL_PATH $DATA_PATH $EVAL_DATA_PATH
# qwen2.5_7B_dapo_math training:
# bash examples/v1/scripts/run_rl.sh examples/v1/config/rl_qwen25_7B_dapo.py "sglang" $MODEL_PATH $DATA_PATH $EVAL_DATA_PATH
Comment on lines +3 to +7
Claude: The usage comments reference run_rl.sh but this script is run_rl_deterministic.sh. Also, the usage shows 5 positional args (CONFIG_PATH, "sglang", MODEL_PATH, DATA_PATH, EVAL_DATA_PATH) but the script only accepts 4 (CONFIG_PATH, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH) — it doesn't take an inference backend argument. These comments appear to be copy-pasted from run_rl.sh without updating.


CONFIG_PATH=$1
MODEL_PATH=$2
DATA_PATH=$3
EVAL_DATA_PATH=${4:-""}

export PYTHONPATH=$(pwd):$PYTHONPATH

# deterministic environment variables
# NOTE: you should use sglang==0.5.5 to reproduce our deterministic results.
export XTUNER_USE_SGLANG=1
export XTUNER_USE_LMDEPLOY=0
export XTUNER_USE_VLLM=0
export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0
export XTUNER_DETERMINISTIC=true
export XTUNER_USE_FA3=0
# sglang environment variables
unset PYTORCH_CUDA_ALLOC_CONF
export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1

# ray environment variables
export MASTER_PORT=6000
export WORLD_SIZE=${NODE_COUNT:-"1"}
export RANK=${NODE_RANK:-"0"}
export RAY_MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
export RAY_RANK=${RANK:-0} # 0 means head node, >0 means worker node
export RAY_HEAD_PORT=${RAY_HEAD_PORT:-"6379"}
export RAY_CLIENT_PORT=${RAY_CLIENT_PORT:-"10001"}
export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-"8265"}
# TODO: provide a way to configure ray_max_concurrency other than via environment variables
export RAY_MAX_CONCURRENCY=${RAY_MAX_CONCURRENCY:-1024} # dataflow_max_concurrency * prompt_repeat_k

# xtuner environment variables
export MODEL_PATH=$MODEL_PATH
export DATA_PATH=$DATA_PATH
export EVAL_DATA_PATH=$EVAL_DATA_PATH
export XTUNER_LOG_LEVEL=${XTUNER_LOG_LEVEL:-"INFO"}
export PYTHONUNBUFFERED=1

current_time=$(date "+%m%d%H")
# use the last component of the model path as model_name and the second-to-last component of the data path as data_name
model_dir_name=$(basename "$MODEL_PATH")
data_dir_name=$(basename "$(dirname "$DATA_PATH")")

if [ "x$WORK_DIR" = "x" ]; then
DIR=$(pwd)
export WORK_DIR="${DIR}/work_dirs/${model_dir_name}_${data_dir_name}_${infer_backend_lower}"
else
export WORK_DIR=$WORK_DIR
fi
Claude: infer_backend_lower is used here but never defined in this script. This will produce a work directory path ending with an empty string (e.g., work_dirs/model_data_). Looks like it was copied from run_rl.sh where the inference backend variable exists.

echo "WORK_DIR: $WORK_DIR"
if [ ! -d "$WORK_DIR" ]; then
mkdir -p "$WORK_DIR"
fi

export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt"
export XTUNER_RL_MEM_DIR="${WORK_DIR}/mem_${current_time}"


# 2. Launch Ray cluster
# allocate num_cpus according to NODE_COUNT to avoid memory OOM
node_count=${NODE_COUNT:-1}
total_cpus=$((node_count * 128))

WORK_DIR=$(realpath "$WORK_DIR")
if [ "$RAY_RANK" -eq 0 ]; then
rm -rf /tmp/ray_log
export RAY_LOG_DIR="${WORK_DIR}/ray_${current_time}/"
mkdir -p ${RAY_LOG_DIR}
ln -sfn "${RAY_LOG_DIR}" /tmp/ray_log
ray start --head \
--node-ip-address="$RAY_MASTER_ADDR" \
--port="$RAY_HEAD_PORT" \
--dashboard-host=0.0.0.0 \
--dashboard-port=$RAY_DASHBOARD_PORT \
--include-dashboard=true \
--disable-usage-stats \
--num-cpus=$total_cpus \
--temp-dir="/tmp/ray_log/"
else
while true; do
if curl --connect-timeout 2 "http://${RAY_MASTER_ADDR}:${RAY_DASHBOARD_PORT}" >/dev/null 2>&1; then
echo "Successfully connected to Ray master at ${RAY_MASTER_ADDR}:${RAY_DASHBOARD_PORT}"
break
else
echo "Waiting for Ray master at ${RAY_MASTER_ADDR}:${RAY_DASHBOARD_PORT} to be available..."
sleep 2
fi
done
ray start --address="$RAY_MASTER_ADDR:$RAY_HEAD_PORT" --block --disable-usage-stats
fi

while true; do
result=$(ray status | grep ${ACCELERATOR} | cut -d ' ' -f2 | cut -d '/' -f2)
expected_accelerator_count=$((node_count * 8))
if [ "$result" = "$expected_accelerator_count.0" ]; then
break
else
echo "Waiting for ${ACCELERATOR} count to be $expected_accelerator_count, current: $result"
sleep 2
fi
done

SCRIPT_NAME=$(basename "$0")
cp "$0" "${WORK_DIR}/${SCRIPT_NAME}"
cp "$CONFIG_PATH" "${WORK_DIR}/config.py"
LOG_FILE="${WORK_DIR}/training_log_${current_time}.txt"

python xtuner/v1/train/cli/rl.py \
--config $CONFIG_PATH \
2>&1 | tee -a "${WORK_DIR}/training_log_${current_time}.txt"
1 change: 1 addition & 0 deletions xtuner/v1/data_proto/rl_data.py
@@ -429,6 +429,7 @@ class SampleParams(BaseModel):
stops: Annotated[list[str], Parameter(help="List of stop sequences.")] = []
stop_token_ids: Annotated[list[int], Parameter(help="List of stop token IDs.")] = []
skip_special_tokens: Annotated[bool, Parameter(help="Whether to skip special tokens.")] = True
sampling_seed: Annotated[int, Parameter(help="The seed for random number generator in sampling.")] = 0
Claude: A default sampling_seed of 0 can be ambiguous — some backends treat 0 as "use random seed" while others treat it as an actual seed value. Consider using None as the default with int | None type, and only pass it to the backend when XTUNER_DETERMINISTIC is enabled. This would make the non-deterministic path clearer and avoid accidentally seeding when determinism isn't intended.
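A minimal sketch of the suggested alternative, using a plain dataclass as a stand-in for the real Pydantic `SampleParams`, and a hypothetical `to_backend_params` helper to show the gating:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SampleParams:
    # None means "no explicit seed"; an int is an explicit seed value.
    sampling_seed: Optional[int] = None

def to_backend_params(params: SampleParams, deterministic: bool) -> dict:
    """Forward the seed only when determinism is requested and a seed is set,
    so the non-deterministic path never seeds the sampler by accident."""
    backend: dict = {}
    if deterministic and params.sampling_seed is not None:
        backend["sampling_seed"] = params.sampling_seed
    return backend
```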



class RolloutExtraParams(TypedDict):
22 changes: 20 additions & 2 deletions xtuner/v1/ray/base/accelerator.py
@@ -196,6 +196,25 @@ def device_visible_env_name(self):
else:
raise ValueError(f"Unsupported accelerator type: {self.accelerator}")

def get_logical_local_rank(self) -> int:
"""Resolve the assigned accelerator id to the logical local rank.

Ray reports accelerator ids in the physical numbering space, while Torch selects devices from the
current visible-device list, which is indexed logically from zero after visibility masks are applied.
"""
accelerator_id = str(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0])
visible_devices = os.environ.get(self.device_visible_env_name)
if visible_devices is None:
return int(accelerator_id)

visible_device_ids = [device_id.strip() for device_id in visible_devices.split(",") if device_id.strip()]
if accelerator_id not in visible_device_ids:
raise ValueError(
f"Assigned accelerator id {accelerator_id} is not present in "
f"{self.device_visible_env_name}={visible_devices}."
)
return visible_device_ids.index(accelerator_id)
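Standalone, the mapping can be exercised like this (a simplified sketch of the method above, with the environment lookup passed in explicitly):

```python
from typing import Optional

def logical_local_rank(accelerator_id: str, visible_devices: Optional[str]) -> int:
    """Map a physical accelerator id to its logical index inside the
    visibility mask, e.g. id "5" under CUDA_VISIBLE_DEVICES=4,5,6,7 -> 1."""
    if visible_devices is None:
        # No mask: physical and logical numbering coincide.
        return int(accelerator_id)
    ids = [d.strip() for d in visible_devices.split(",") if d.strip()]
    if accelerator_id not in ids:
        raise ValueError(
            f"accelerator id {accelerator_id} not in visible devices {visible_devices}"
        )
    return ids.index(accelerator_id)
```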

def setup_distributed(self, rank: int, master_addr: str, master_port: int, world_size: int):
"""Set up the distributed environment for the worker.

@@ -215,8 +234,7 @@ def setup_distributed(self, rank: int, master_addr: str, master_port: int, world
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_RANK"] = str(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0])

os.environ["LOCAL_RANK"] = str(self.get_logical_local_rank())
# the backend argument specifies the communication backend; it is not read from environment variables
# - 'nccl': communication between NVIDIA GPUs (recommended for GPU)
# - 'gloo': CPU communication or cross-platform
16 changes: 11 additions & 5 deletions xtuner/v1/ray/dataflow/replay_buffer.py
@@ -26,7 +26,7 @@
is_valid_for_replaybuffer,
)
from xtuner.v1.datasets.config import DataloaderConfig
from xtuner.v1.utils import get_logger
from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger
from xtuner.v1.utils.device import get_device


@@ -246,6 +246,7 @@ def __init__(self, dataset_cfg, dataloader_cfg, tokenizer):
self.dataloader_iter = iter(self.dataloader)
self.cur_epoch = 0
self.reduced_consumed_samples = 0
self._next_root_id = 0
self.logger = get_logger()

def sample(self, env: str, prompt_repeat_k: int) -> List[RLDataFlowItem]:
@@ -263,8 +264,13 @@
Returns:
List[RLDataFlowItem]: A list of newly created data items for a rollout.
"""
root_id = uuid4().int
action_id = uuid4().int
if XTUNER_DETERMINISTIC:
root_id = max(self._next_root_id, self.reduced_consumed_samples * prompt_repeat_k)
action_id = root_id
self._next_root_id = root_id + prompt_repeat_k
else:
root_id = uuid4().int
action_id = uuid4().int
group_data_item: List[RLDataFlowItem] = [RLDataFlowItem() for _ in range(prompt_repeat_k)]
try:
data = next(self.dataloader_iter)[0]
@@ -280,12 +286,12 @@
multimodal_train_info["pixel_values"] = ray.put(multimodal_train_info["pixel_values"])
data["multimodal_train_info"] = multimodal_train_info

for data_item in group_data_item:
for item_idx, data_item in enumerate(group_data_item):
data_item.uid = RLUIDItem(
env=env,
root_id=root_id,
action_id=action_id,
observation_id=uuid4().int,
observation_id=root_id + item_idx if XTUNER_DETERMINISTIC else uuid4().int,
)
data_item.data = RLDatasetItem(**data)
data_item.extra_info = RLExtraDataItem(retry_times=0)
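In isolation, the deterministic id scheme above behaves roughly like this sketch (simplified; the real items carry full UID and data fields):

```python
from uuid import uuid4

def assign_group_ids(next_root_id: int, consumed_samples: int,
                     prompt_repeat_k: int, deterministic: bool):
    """Return (root_id, observation_ids, new_next_root_id) for one sampled group."""
    if deterministic:
        # Monotonic ids derived from consumption progress, so a rerun
        # reproduces the same id sequence.
        root_id = max(next_root_id, consumed_samples * prompt_repeat_k)
        observation_ids = [root_id + i for i in range(prompt_repeat_k)]
        return root_id, observation_ids, root_id + prompt_repeat_k
    # Non-deterministic path keeps random UUIDs and leaves the counter alone.
    return uuid4().int, [uuid4().int for _ in range(prompt_repeat_k)], next_root_id
```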
9 changes: 6 additions & 3 deletions xtuner/v1/ray/environment/single_turn_env.py
@@ -16,7 +16,7 @@
update_rollout_item,
)
from xtuner.v1.ray.environment.base_env import BaseEnvironment
from xtuner.v1.utils import get_logger, ray_method
from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, ray_method


class RawSingleTurnEnvironment(BaseEnvironment):
@@ -65,6 +65,7 @@ def __init__(
# This should be longer than the controller's internal timeout (`rollout_timeout`)
# to account for potential queuing delays and other overheads.
self.timeout_multiplier = 2.0
self.rollout_cfg = rollout_cfg

async def generate( # type: ignore[override]
self, group_data_items: List[RLDataFlowItem], sample_params=None, extra_params=None
@@ -93,10 +94,12 @@ async def generate( # type: ignore[override]
extra_params = {}
if self.rollout_controller:
response_future = []
for sample in group_data_items:
for i, sample in enumerate(group_data_items):
Claude: copy.deepcopy(sample_params) is now called unconditionally for every sample in the loop, whereas before it was only called in the partial rollout branch. This is a performance regression for the non-deterministic path — deep-copying a Pydantic BaseModel on every iteration is unnecessary overhead when the params aren't being mutated.

Consider guarding the deepcopy:

Suggested change
for i, sample in enumerate(group_data_items):
update_sample_params = copy.deepcopy(sample_params) if XTUNER_DETERMINISTIC else sample_params

sample.data.extra_info["root_id"] = sample.uid.root_id
sample.data.extra_info["action_id"] = sample.uid.action_id
Claude: self.rollout_cfg is accessed here to get random_seed, but rollout_cfg can be None (see the __init__ signature at line 44: rollout_cfg=None). If XTUNER_DETERMINISTIC is True and rollout_cfg is None, this will raise AttributeError. Add a guard or assert that rollout_cfg is not None when deterministic mode is enabled.

update_sample_params = sample_params
update_sample_params = copy.deepcopy(sample_params)
if XTUNER_DETERMINISTIC:
update_sample_params.sampling_seed = self.rollout_cfg.random_seed + i

if "partial_rollout_input_ids" in sample.env.rollout.extra_info:
input_ids_length = len(sample.data.input_ids) if sample.data.input_ids is not None else 0
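The per-sample seeding discussed in the comments above can be sketched as follows (a dataclass stands in for the real `SampleParams`; the base seed would come from the rollout config, and the deepcopy is guarded as the review suggests):

```python
import copy
from dataclasses import dataclass

@dataclass
class SampleParams:
    temperature: float = 1.0
    sampling_seed: int = 0

def per_sample_params(base: SampleParams, base_seed: int, group_size: int,
                      deterministic: bool) -> list:
    """Give each sample in a group its own seed in deterministic mode;
    otherwise reuse the shared params object without copying."""
    out = []
    for i in range(group_size):
        if deterministic:
            p = copy.deepcopy(base)          # avoid mutating the shared object
            p.sampling_seed = base_seed + i  # unique, reproducible per sample
        else:
            p = base
        out.append(p)
    return out
```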
25 changes: 17 additions & 8 deletions xtuner/v1/ray/rollout/sglang.py
@@ -6,6 +6,7 @@

from transformers import AutoTokenizer
from xtuner.v1.ray.config import RolloutConfig
from xtuner.v1.utils import XTUNER_DETERMINISTIC

from .worker import RolloutWorker

@@ -24,7 +25,7 @@ def __init__(
from sglang.srt.entrypoints.http_server import launch_server

self.server_func = launch_server
self.endpoints["health_generate"] = "health_generate"
self.endpoints["health_generate"] = "health"
Claude: Changing the health endpoint from "health_generate" to "health" is a behavioral change that isn't related to deterministic inference. The "health_generate" endpoint checks if the server can actually handle generate requests, while "health" may only confirm the HTTP server is up. This could mask readiness issues where the server is alive but not ready to serve inference.

Was this change intentional? If so, please document the reason (e.g., SGLang version compatibility).

self.endpoints["generate"] = "generate"
self.endpoints["v1/chat/completions"] = "v1/chat/completions"
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path, trust_remote_code=True)
@@ -86,6 +87,14 @@ def _make_request(self, endpoint: str, payload=None):
response.raise_for_status()
return response.json()

def check_health(self) -> bool:
try:
response = requests.get(f"{self.server_url}/{self.endpoints['health_generate']}", timeout=5.0)
return response.status_code == 200
except requests.RequestException as e:
self.logger.error(f"Health check failed for server {self.server_url}: {e}")
return False

def flush_cache(self):
"""Flush the cache of the server."""
# TODO: support tp
@@ -144,10 +153,8 @@ def _transform_rollout_config_to_server_configs(self):
grammar_backend = sglang_config_kwargs.get(
"grammar_backend", None
) # for intern-s1 series models, have to set the grammar_backend to "none"
log_level = sglang_config_kwargs.get("log_level", "critical")
log_level_http = sglang_config_kwargs.get("log_level_http", "critical")
enable_deterministic_inference = sglang_config_kwargs.get("enable_deterministic_inference", False)

log_level = sglang_config_kwargs.get("log_level", "error")
log_level_http = sglang_config_kwargs.get("log_level_http", "error")
sglang_server_args = ServerArgs(model_path=self.config.model_path, trust_remote_code=True)
num_gpus_per_engine = (
self.config.expert_parallel_size
@@ -162,7 +169,6 @@
sglang_server_args.gpu_id_step = 1
sglang_server_args.nnodes = max(1, num_gpus_per_engine // self.config.gpus_per_node)
sglang_server_args.skip_server_warmup = True

sglang_server_args.mem_fraction_static = self.config.gpu_memory_utilization
# note: not needed in non-colocated mode; in colocated mode it must be set so offload can release GPU memory
sglang_server_args.enable_memory_saver = True
@@ -173,8 +179,9 @@
sglang_server_args.max_running_requests = self.config.rollout_max_batch_size_per_instance
sglang_server_args.log_level = log_level
sglang_server_args.log_level_http = log_level_http
sglang_server_args.enable_deterministic_inference = enable_deterministic_inference

if XTUNER_DETERMINISTIC:
sglang_server_args.enable_deterministic_inference = True
Claude: rl_on_policy_target is not a recognized attribute of SGLang's ServerArgs as of the latest public SGLang release. This will raise an AttributeError (or be silently ignored if SGLang uses a permissive config model) on standard SGLang installations.

Is this targeting a custom/internal fork of SGLang? If so, please add a comment noting the required SGLang version or fork, so users don't hit runtime errors with standard SGLang.

sglang_server_args.rl_on_policy_target = True
if self.config.expert_parallel_size > 1:
sglang_server_args.tp_size = num_gpus_per_engine
sglang_server_args.ep_size = num_gpus_per_engine
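One defensive pattern for the fork-only flag noted in the comment above (a hypothetical helper; `DummyArgs` merely stands in for SGLang's `ServerArgs`):

```python
def set_if_supported(server_args, name: str, value) -> bool:
    """Set an attribute only when the installed backend defines it, so
    fork-only flags degrade gracefully on a stock installation."""
    if hasattr(server_args, name):
        setattr(server_args, name, value)
        return True
    return False

class DummyArgs:
    # Flag that exists in stock SGLang releases.
    enable_deterministic_inference = False
```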
@@ -212,6 +219,8 @@ def _transform_sample_params(self, sample_params: Dict):
"stop_token_ids": sample_params["stop_token_ids"],
"skip_special_tokens": sample_params["skip_special_tokens"],
}
if XTUNER_DETERMINISTIC:
sglang_sample_params["sampling_seed"] = sample_params["sampling_seed"]
return sglang_sample_params

def _transform_extra_params(self, extra_params: Dict):