Skip to content

Switch loaders to two-phase init and remove bridge methods#579

Draft
kmontemayor2-sc wants to merge 24 commits intomainfrom
kmonte/shared-backend-decomp-4
Draft

Switch loaders to two-phase init and remove bridge methods#579
kmontemayor2-sc wants to merge 24 commits intomainfrom
kmonte/shared-backend-decomp-4

Conversation

@kmontemayor2-sc
Copy link
Copy Markdown
Collaborator

@kmontemayor2-sc kmontemayor2-sc commented Apr 6, 2026

Switch loaders to two-phase init and remove bridge methods
Refactor BaseDistLoader to use the two-phase sampling API directly:

  • init_sampling_backend (shared across all ranks per loader instance)
  • register_sampling_input (unique per compute rank)

Key changes:

  • Add GroupLeaderInfo, _compute_group_leader, _dispatch_grouped_graph_store_phase
    for generic leader-elected grouped RPC dispatch
  • Add _init_graph_store_sampling_backends and _register_graph_store_sampling_inputs
  • Replace _producer_id_list with _backend_id_list + _channel_id_list
  • Remove create_sampling_producer/destroy_sampling_producer bridge methods
  • Keep per-class _counter in each loader (not a global counter) since
    type-prefixed _backend_key already prevents cross-type collisions
  • Fix test_multiple_loaders_in_graph_store to use num_compute_nodes=2
    so backend-sharing assertions are exercised across ranks

kmonte and others added 4 commits April 6, 2026 20:59
…r.py

Move create_dist_sampler(), SamplerInput, and SamplerRuntime out of
dist_sampling_producer.py into a shared utils module so they can be
reused by the upcoming SharedDistSamplingBackend.

Also rename `w` -> `worker` in DistSamplingProducer.init() for clarity.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Introduce SharedDistSamplingBackend which manages a pool of worker
processes servicing multiple compute-rank channels through a fair-queued
round-robin scheduler. This replaces the per-channel producer model in
graph-store mode with a shared backend + lightweight per-channel state.

Includes tests for pure business logic helpers (_compute_num_batches,
_epoch_batch_indices, _compute_worker_seeds_ranges), shuffle behavior,
and completion reporting.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Replace the single-step create_sampling_producer with a two-phase API:
- init_sampling_backend: creates/reuses a SharedDistSamplingBackend
- register_sampling_input: registers a lightweight per-channel input

The existing create_sampling_producer/destroy_sampling_producer methods
are preserved as bridge methods that delegate to the new API, keeping
existing loaders working without changes.

Also adds InitSamplingBackendRequest and RegisterBackendRequest message
dataclasses, and per-channel fetch stats logging.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Refactor BaseDistLoader to use the two-phase sampling API directly:
- init_sampling_backend (shared across all ranks per loader instance)
- register_sampling_input (unique per compute rank)

Key changes:
- Add GroupLeaderInfo, _compute_group_leader, _dispatch_grouped_graph_store_phase
  for generic leader-elected grouped RPC dispatch
- Add _init_graph_store_sampling_backends and _register_graph_store_sampling_inputs
- Replace _producer_id_list with _backend_id_list + _channel_id_list
- Remove create_sampling_producer/destroy_sampling_producer bridge methods
- Keep per-class _counter in each loader (not a global counter) since
  type-prefixed _backend_key already prevents cross-type collisions
- Fix test_multiple_loaders_in_graph_store to use num_compute_nodes=2
  so backend-sharing assertions are exercised across ranks

Co-Authored-By: Claude Opus 4.6 <[email protected]>
@kmontemayor2-sc
Copy link
Copy Markdown
Collaborator Author

/all_test

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 6, 2026

GiGL Automation

@ 21:22:55UTC : 🔄 Integration Test started.

@ 22:40:25UTC : ✅ Workflow completed successfully.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 6, 2026

GiGL Automation

@ 21:22:57UTC : 🔄 Python Unit Test started.

@ 22:32:01UTC : ✅ Workflow completed successfully.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 6, 2026

GiGL Automation

@ 21:22:59UTC : 🔄 Scala Unit Test started.

@ 21:32:06UTC : ✅ Workflow completed successfully.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 6, 2026

GiGL Automation

@ 21:22:59UTC : 🔄 E2E Test started.

@ 22:47:31UTC : ✅ Workflow completed successfully.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 6, 2026

GiGL Automation

@ 21:23:00UTC : 🔄 Lint Test started.

@ 21:31:11UTC : ✅ Workflow completed successfully.

time.sleep(group_info.stagger_sleep)
results = issue_phase_rpcs() if group_info.is_leader else []
all_results: list[list[T]] = [[] for _ in range(runtime.world_size)]
torch.distributed.all_gather_object(all_results, results)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Semgrep identified an issue in your code:

The torch.distributed.all_gather_object() call deserializes untrusted data from remote ranks using pickle, allowing arbitrary code execution if an attacker controls any rank in the distributed system.

More details about this

The torch.distributed.all_gather_object() call is using pickle deserialization under the hood to share producer_id_list across all ranks in the distributed system. This means untrusted data from other processes gets automatically deserialized and executed.

An attacker with access to any rank in the distributed training job could craft a malicious pickled object and send it during the all-gather phase. When producer_id_list (or other ranks' data) gets deserialized on your rank, the attacker's code runs immediately with the same privileges as your training process.

For example:

  1. Attacker gains write access to rank 1's memory or intercepts its state
  2. They insert a pickled Python object that executes shell commands when unpickled (e.g., os.system('steal_data.sh'))
  3. Your rank calls torch.distributed.all_gather_object(all_producer_ids, producer_id_list)
  4. PyTorch's pickle unpickles all ranks' data, triggering the attacker's code during deserialization on your process
  5. The shell commands run with your process's credentials, potentially exfiltrating model weights or training data

To resolve this comment:

✨ Commit Assistant Fix Suggestion
  1. Avoid using torch.distributed.all_gather_object, as this relies on Python pickling, which may lead to arbitrary code execution if untrusted data is ever deserialized.
  2. Replace the use of all_gather_object with a tensor-based collective, such as torch.distributed.all_gather, by converting your data to a tensor (for example, use torch.tensor(producer_id_list)).
  3. Predefine the size and type of the tensor holding the gathered data, for example: all_producer_ids = torch.empty(runtime.world_size * num_producers, dtype=torch.long), where num_producers is the expected length of producer_id_list for each rank.
  4. Call torch.distributed.all_gather([all_producer_ids], producer_id_list_tensor), where producer_id_list_tensor = torch.tensor(producer_id_list, dtype=torch.long) for each rank.
  5. After gathering, reconstruct the original data structure as needed from all_producer_ids. For example, split the combined tensor into per-rank lists using slicing.
  6. If the number of producers may vary, agree on a fixed length and pad with a sentinel value such as -1 so tensors can be safely communicated.

This change ensures only primitive tensor data is shared between ranks, eliminating pickle-related risks.

💬 Ignore this finding

Reply with Semgrep commands to ignore this finding.

  • /fp <comment> for false positive
  • /ar <comment> for acceptable risk
  • /other <comment> for all other reasons

Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.

You can view more details about this finding in the Semgrep AppSec Platform.

The leader's RPC results, broadcast to all ranks in the group.
"""
all_keys: list[Optional[str]] = [None] * runtime.world_size
torch.distributed.all_gather_object(all_keys, my_key)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Semgrep identified an issue in your code:

Using torch.distributed.all_gather_object() to share my_worker_key creates an arbitrary code execution risk, since pickle deserialization can execute attacker-controlled code if a compromised worker sends malicious data.

More details about this

The torch.distributed.all_gather_object() call on this line uses pickle to serialize and deserialize the my_worker_key string across distributed processes. An attacker who can control the data sent from any worker process could craft a malicious pickle payload that executes arbitrary code when deserialized.

Exploit scenario:

  1. An attacker compromises or spoofs one of the worker processes in the distributed training cluster
  2. They set my_worker_key to a malicious pickle-serialized object instead of a normal string
  3. When this line executes, PyTorch deserializes the pickle object on all receiving ranks
  4. The malicious pickle payload executes arbitrary code on those processes with the same privileges as the training job, potentially stealing model weights, injecting backdoors, or exfiltrating data

To resolve this comment:

✨ Commit Assistant Fix Suggestion
  1. Avoid using torch.distributed.all_gather_object, as it uses pickle internally and can allow arbitrary code execution if untrusted data is deserialized.
  2. If exchanging strings across ranks, switch to using torch.distributed.all_gather, which is safe for tensors. You can do this by encoding your strings to byte tensors before gathering and decoding after.
  3. Replace the vulnerable line with logic similar to:
    • Convert the local string to bytes: my_worker_key_bytes = my_worker_key.encode('utf-8')
    • Find the maximum length of all keys to ensure tensors are the same size across ranks. This usually requires an all_reduce to get the max. For example:
      key_len_tensor = torch.tensor([len(my_worker_key_bytes)], device='cpu')
      max_len_tensor = key_len_tensor.clone()
      torch.distributed.all_reduce(max_len_tensor, op=torch.distributed.ReduceOp.MAX)
    • Pad your byte string to max_len_tensor.item(): padded_bytes = my_worker_key_bytes.ljust(max_len_tensor.item(), b'\x00')
    • Create a tensor: my_worker_key_tensor = torch.ByteTensor(list(padded_bytes))
    • Prepare a gather tensor: all_worker_key_tensors = [torch.empty_like(my_worker_key_tensor) for _ in range(runtime.world_size)]
    • Call torch.distributed.all_gather(all_worker_key_tensors, my_worker_key_tensor)
    • After gathering, decode each tensor: [bytes(t.tolist()).rstrip(b'\x00').decode('utf-8') for t in all_worker_key_tensors]
  4. Replace all uses of all_worker_keys with the decoded string list.

Alternatively, if all ranks already know or can deterministically construct the set of worker keys, you can avoid broadcasting entirely by constructing the list locally.

Using tensors for communication prevents vulnerabilities from deserialization attacks, as tensor operations do not use pickle.

💬 Ignore this finding

Reply with Semgrep commands to ignore this finding.

  • /fp <comment> for false positive
  • /ar <comment> for acceptable risk
  • /other <comment> for all other reasons

Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.

You can view more details about this finding in the Semgrep AppSec Platform.

kmonte and others added 6 commits April 15, 2026 00:18
- Fix typo "patition" → "partition" in docstrings.
- Clear _fetch_stats_by_channel_id on shutdown.
- Fix init_sampling_backend race via lock-overlap on
  backend_state.lock; second callers block on that lock and
  observe init_complete / init_error instead of racing to a
  rolled-back map entry.
- Snapshot active_channels length under self._lock in
  register_sampling_input so the log reflects state at registration.
- Serialize destroy_sampling_input and start_new_epoch_sampling
  under channel_state.lock, restoring the pre-two-phase
  destroy/start_epoch invariant. Move _fetch_stats_by_channel_id.pop
  inside self._lock.
- Document idempotency of start_new_epoch_sampling and the
  expected QueueTimeoutError case in fetch_one_sampled_message.

Adds three concurrency regression tests: two for the
init_sampling_backend race (success and failure-propagation via
__cause__), and one verifying monotonic dispatch under the
channel_state.lock wrap on start_new_epoch_sampling.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@kmontemayor2-sc kmontemayor2-sc changed the title Kmonte/shared backend decomp 4 Switch loaders to two-phase init and remove bridge methods Apr 21, 2026
@kmontemayor2-sc
Copy link
Copy Markdown
Collaborator Author

/all_test

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 21, 2026

GiGL Automation

@ 15:58:09UTC : 🔄 E2E Test started.

@ 17:27:53UTC : ✅ Workflow completed successfully.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 21, 2026

GiGL Automation

@ 15:58:10UTC : 🔄 Lint Test started.

@ 16:05:37UTC : ✅ Workflow completed successfully.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 21, 2026

GiGL Automation

@ 15:58:11UTC : 🔄 Python Unit Test started.

@ 17:02:30UTC : ✅ Workflow completed successfully.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 21, 2026

GiGL Automation

@ 15:58:12UTC : 🔄 Scala Unit Test started.

@ 16:07:48UTC : ✅ Workflow completed successfully.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 21, 2026

GiGL Automation

@ 15:58:13UTC : 🔄 Integration Test started.

@ 17:14:12UTC : ✅ Workflow completed successfully.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants