
fix rl mem leak #1646

Open

RangiLyu wants to merge 2 commits into InternLM:main from RangiLyu:lcq/rl-mem

Conversation

RangiLyu (Collaborator) commented Apr 1, 2026

No description provided.

- Introduced new environment variables for RL memory monitoring: XTUNER_RL_MEM_INTERVAL, XTUNER_RL_OBJECT_LIMIT, and XTUNER_RL_OBJECT_TOP_K.
- Updated run_rl.sh and run_rl_submit.sh to utilize these new variables for configuring memory monitoring.
- Enhanced rl_monitor_actor_memory function to accept additional parameters for object limit and top K objects to monitor.
- Added a new summarize_group_payload function in replay_buffer.py to provide detailed statistics on grouped data items.
- Implemented memory reference management improvements in controller.py and replay_buffer.py to optimize memory usage during training.

These changes aim to improve the flexibility and efficiency of memory monitoring in RL training workflows.
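For reference, the new variables could be read on the Python side along these lines; the defaults and integer parsing below are assumptions for illustration, not taken from the diff:

```python
import os

# Variable names come from this PR; the default values here are illustrative only.
mem_interval = int(os.environ.get("XTUNER_RL_MEM_INTERVAL", "30"))    # seconds between snapshots
object_limit = int(os.environ.get("XTUNER_RL_OBJECT_LIMIT", "1000"))  # max objects to track
object_top_k = int(os.environ.get("XTUNER_RL_OBJECT_TOP_K", "20"))    # largest objects to report
```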
RangiLyu (Collaborator, Author) commented Apr 1, 2026

@claude review

Comment on lines +232 to +233
return libc.malloc_trim(0)
except Exception as e:

Claude: Nit: _TRIM_MEMORY_WARNED is a module-level global mutated without any locking. Since trim_memory is called from BaseMLLMTokenizeFn.__call__ which runs in data-loading workers (potentially multi-threaded), there's a small race window where multiple threads see False and all log the warning. Not a correctness issue, but if you want to be tidy a threading.Lock or just removing the "warn once" logic (always warn or never warn) would be cleaner.
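A tidier version of the warn-once logic could take a lock around the check-and-set. A minimal sketch, assuming the surrounding function looks roughly like this (only the two quoted lines come from the diff; the logger and module globals are assumed):

```python
import ctypes
import ctypes.util
import logging
import threading

logger = logging.getLogger(__name__)

_TRIM_MEMORY_WARNED = False
_TRIM_MEMORY_LOCK = threading.Lock()


def trim_memory() -> int:
    """Ask glibc to return freed heap pages to the OS via malloc_trim(0)."""
    global _TRIM_MEMORY_WARNED
    try:
        libc = ctypes.CDLL(ctypes.util.find_library("c"))
        return libc.malloc_trim(0)
    except Exception as e:
        # Check-and-set under the lock so at most one thread logs the warning.
        with _TRIM_MEMORY_LOCK:
            if not _TRIM_MEMORY_WARNED:
                _TRIM_MEMORY_WARNED = True
                logger.warning("malloc_trim unavailable: %s", e)
        return 0
```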

"judged_observations": 0,
"has_multimodal_prompt": False,
}
if not grouped_dataitem:

Claude: Nit: summarize_group_payload is a public (module-level) function. Per project standards it should have a docstring with Google Style format, parameter types, and return type.

Comment on lines 29 to +35
def __init__(self, workers: list[TrainingWorker]) -> None:
self.workers = workers

def _collect_object_refs(self, obj, refs: list[ObjectRef]):
if isinstance(obj, ObjectRef):
refs.append(obj)
return

Claude: Warning: _collect_object_refs recurses into lists/tuples but doesn't handle dict values. If pixel_values or rollout_routed_experts ever hold refs inside a dict (e.g., via extra_info), those would be missed.

Not necessarily a problem today since pixel_values is torch.FloatTensor | None and rollout_routed_experts is torch.Tensor | None, but the recursive traversal pattern suggests it's intended to be generic. If so, consider also handling dicts. If not, the method name/docstring should clarify it only handles lists/tuples of ObjectRef.
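If the traversal is meant to be generic, extending it to dict values is a one-branch change. A standalone sketch, with an `is_ref` predicate standing in for `isinstance(obj, ObjectRef)` so the snippet runs without Ray:

```python
from typing import Any, Callable


def collect_object_refs(obj: Any, is_ref: Callable[[Any], bool], refs: list) -> None:
    """Recursively collect ref-like objects from nested lists, tuples, and dicts."""
    if is_ref(obj):
        refs.append(obj)
        return
    if isinstance(obj, (list, tuple)):
        for item in obj:
            collect_object_refs(item, is_ref, refs)
    elif isinstance(obj, dict):
        # Keys are assumed to be plain strings; only values can hold refs.
        for value in obj.values():
            collect_object_refs(value, is_ref, refs)
```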

Comment on lines +387 to +392
def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: RolloutState):
"""Keep prompt refs only and drop rollout outputs that will not be reused."""
old_obs_refs = [ref for ref in replay_meta.observation_refs if ref is not None]
if old_obs_refs:
ray.internal.free(old_obs_refs, local_only=False)
replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids]

Claude: Warning: Each call to _strip_rollout_payload_for_rerun creates N new Ray objects via ray.put(RLEnvDataItem()) — one per observation ID. For actions with many observations, or when many actions are stripped (expired/aborted batches), this could accumulate significant overhead in the object store.

Consider caching a single empty ref and reusing it:

empty_ref = ray.put(RLEnvDataItem())
replay_meta.observation_refs = [empty_ref for _ in replay_meta.observation_ids]

Since RLEnvDataItem() is always identical, a single shared object ref avoids N redundant copies in the object store.

Comment on lines 889 to 894
for action_id in self._root2actions[root_id]:
if action_id == current_action_id:
continue
new_replay_meta = self._actions.pop(action_id, None)
if new_replay_meta:
self._clear_meta_for_actions(new_replay_meta)

Claude: Nit: _clear_meta_for_actions now pops from self._actions internally (line 866 — good, makes it a proper "single source of truth" for cleanup). However, this loop still does self._actions.pop(action_id, None) before calling _clear_meta_for_actions(new_replay_meta), resulting in a redundant double-pop on self._actions for the same key.

Not a bug (the second pop returns None harmlessly), but it's confusing for readers since _clear_meta_for_actions is documented as "the single source of truth for deleting an action." Consider letting _clear_meta_for_actions own the pop:

for action_id in self._root2actions[root_id]:
    if action_id == current_action_id:
        continue
    new_replay_meta = self._actions.get(action_id)
    if new_replay_meta:
        self._clear_meta_for_actions(new_replay_meta)


claude bot commented Apr 1, 2026

Claude: ## Summary

This PR fixes Ray object reference memory leaks in RL training by:

  1. Freeing Ray object refs (action_ref, observation_refs) when actions are cleared or stripped
  2. Stripping rollout payloads (keeping only prompts) for expired/aborted actions via _strip_rollout_payload_for_rerun
  3. Freeing pixel_values and rollout_routed_experts refs after ray.get in controller.fit
  4. Adding trim_memory (glibc malloc_trim) to reclaim heap pages during data loading
  5. Fixing PIL Image.open resource leaks by using context managers
  6. Expanding memory monitoring with Ray object store stats

Issues

Warning

  • replay_buffer.py:387-392: _strip_rollout_payload_for_rerun creates N new ray.put(RLEnvDataItem()) objects per call. A single cached empty ref would avoid redundant object store entries.
  • controller.py:29-35: _collect_object_refs recurses into lists/tuples but not dicts. If this is meant to be generic, consider handling dicts too; if not, the scope should be documented.
  • misc.py:232-233: the _TRIM_MEMORY_WARNED global is not thread-safe; a minor race is possible in multi-threaded data loading.

Nit

  • replay_buffer.py:78: summarize_group_payload is a public function missing a docstring.
  • replay_buffer.py:889-894: the _clear_meta_for_root loop still pops from self._actions before calling _clear_meta_for_actions, which now also pops; the result is a redundant double-pop.

Verdict

COMMENT — The core memory leak fixes (ref freeing in replay buffer cleanup paths, controller ref freeing, PIL context managers) look correct and well-targeted. The monitoring enhancements are a nice addition. The items above are minor improvements, nothing blocking.
