Conversation
- Introduced new environment variables for RL memory monitoring: `XTUNER_RL_MEM_INTERVAL`, `XTUNER_RL_OBJECT_LIMIT`, and `XTUNER_RL_OBJECT_TOP_K`.
- Updated `run_rl.sh` and `run_rl_submit.sh` to use these new variables when configuring memory monitoring.
- Enhanced the `rl_monitor_actor_memory` function to accept additional parameters for the object limit and the top-K objects to monitor.
- Added a new `summarize_group_payload` function in `replay_buffer.py` to provide detailed statistics on grouped data items.
- Improved memory reference management in `controller.py` and `replay_buffer.py` to optimize memory usage during training.

These changes aim to improve the flexibility and efficiency of memory monitoring in RL training workflows.
@claude review
```python
    return libc.malloc_trim(0)
except Exception as e:
```
Claude: Nit: _TRIM_MEMORY_WARNED is a module-level global mutated without any locking. Since trim_memory is called from BaseMLLMTokenizeFn.__call__ which runs in data-loading workers (potentially multi-threaded), there's a small race window where multiple threads see False and all log the warning. Not a correctness issue, but if you want to be tidy a threading.Lock or just removing the "warn once" logic (always warn or never warn) would be cleaner.
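The lock-based fix the comment suggests could look like the sketch below. The names `_TRIM_MEMORY_WARNED` and `trim_memory` mirror the identifiers discussed in the review; the `threading.Lock` guard is the suggested addition, not code from the PR.

```python
import ctypes
import logging
import threading

logger = logging.getLogger(__name__)

# Module-level warn-once state, now guarded by a lock so concurrent
# data-loading workers cannot both observe False and double-log.
_TRIM_MEMORY_WARNED = False
_TRIM_MEMORY_LOCK = threading.Lock()


def trim_memory() -> int:
    """Ask glibc to return freed heap pages to the OS; warn at most once on failure."""
    global _TRIM_MEMORY_WARNED
    try:
        libc = ctypes.CDLL("libc.so.6")
        return libc.malloc_trim(0)
    except Exception as e:
        with _TRIM_MEMORY_LOCK:
            if not _TRIM_MEMORY_WARNED:
                _TRIM_MEMORY_WARNED = True
                logger.warning("malloc_trim unavailable: %s", e)
        return 0
```

The check-and-set of the flag happens entirely inside the critical section, so only one thread ever logs; the happy path takes no lock at all.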
```python
    "judged_observations": 0,
    "has_multimodal_prompt": False,
}
if not grouped_dataitem:
```
Claude: Nit: summarize_group_payload is a public (module-level) function. Per project standards it should have a docstring with Google Style format, parameter types, and return type.
```python
def __init__(self, workers: list[TrainingWorker]) -> None:
    self.workers = workers

def _collect_object_refs(self, obj, refs: list[ObjectRef]):
    if isinstance(obj, ObjectRef):
        refs.append(obj)
        return
```
Claude: Warning: _collect_object_refs recurses into lists/tuples but doesn't handle dict values. If pixel_values or rollout_routed_experts ever hold refs inside a dict (e.g., via extra_info), those would be missed.
Not necessarily a problem today since pixel_values is torch.FloatTensor | None and rollout_routed_experts is torch.Tensor | None, but the recursive traversal pattern suggests it's intended to be generic. If so, consider also handling dicts. If not, the method name/docstring should clarify it only handles lists/tuples of ObjectRef.
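A sketch of the generic traversal the comment suggests, written with a stand-in `ObjectRef` class so it runs without Ray; the real method would check against `ray.ObjectRef` and live on the controller class.

```python
class ObjectRef:
    """Stand-in for ray.ObjectRef, for illustration only."""


def collect_object_refs(obj, refs: list) -> None:
    """Recursively gather ObjectRefs from nested lists, tuples, and dict values."""
    if isinstance(obj, ObjectRef):
        refs.append(obj)
        return
    if isinstance(obj, (list, tuple)):
        for item in obj:
            collect_object_refs(item, refs)
    elif isinstance(obj, dict):
        # The added case: refs nested under dict values (e.g. an extra_info
        # field) are now collected instead of silently skipped.
        for value in obj.values():
            collect_object_refs(value, refs)
```

With this, a payload like `{"a": [ObjectRef()], "b": (ObjectRef(), 5)}` yields both refs; the original list/tuple-only version would miss everything under the dict.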
```python
def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: RolloutState):
    """Keep prompt refs only and drop rollout outputs that will not be reused."""
    old_obs_refs = [ref for ref in replay_meta.observation_refs if ref is not None]
    if old_obs_refs:
        ray.internal.free(old_obs_refs, local_only=False)
    replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids]
```
Claude: Warning: Each call to _strip_rollout_payload_for_rerun creates N new Ray objects via ray.put(RLEnvDataItem()), one per observation ID. For actions with many observations, or when many actions are stripped (expired/aborted batches), this could accumulate significant overhead in the object store.
Consider caching a single empty ref and reusing it:

```python
empty_ref = ray.put(RLEnvDataItem())
replay_meta.observation_refs = [empty_ref for _ in replay_meta.observation_ids]
```

Since RLEnvDataItem() is always identical, a single shared object ref avoids N redundant copies in the object store.
```python
for action_id in self._root2actions[root_id]:
    if action_id == current_action_id:
        continue
    new_replay_meta = self._actions.pop(action_id, None)
    if new_replay_meta:
        self._clear_meta_for_actions(new_replay_meta)
```
Claude: Nit: _clear_meta_for_actions now pops from self._actions internally (line 866 — good, makes it a proper "single source of truth" for cleanup). However, this loop still does self._actions.pop(action_id, None) before calling _clear_meta_for_actions(new_replay_meta), resulting in a redundant double-pop on self._actions for the same key.
Not a bug (the second pop returns None harmlessly), but it's confusing for readers since _clear_meta_for_actions is documented as "the single source of truth for deleting an action." Consider letting _clear_meta_for_actions own the pop:
```python
for action_id in self._root2actions[root_id]:
    if action_id == current_action_id:
        continue
    new_replay_meta = self._actions.get(action_id)
    if new_replay_meta:
        self._clear_meta_for_actions(new_replay_meta)
```
Claude:

## Summary

This PR fixes Ray object reference memory leaks in RL training by:

## Issues

**Warning**

**Nit**

## Verdict

COMMENT. The core memory leak fixes (ref freeing in replay buffer cleanup paths, controller ref freeing, PIL context managers) look correct and well-targeted. The monitoring enhancements are a nice addition. The items above are minor improvements, nothing blocking.