Switch loaders to two-phase init and remove bridge methods #579
kmontemayor2-sc wants to merge 24 commits into main from kmonte/shared-backend-decomp-3
Conversation
…r.py Move create_dist_sampler(), SamplerInput, and SamplerRuntime out of dist_sampling_producer.py into a shared utils module so they can be reused by the upcoming SharedDistSamplingBackend. Also rename `w` -> `worker` in DistSamplingProducer.init() for clarity. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Introduce SharedDistSamplingBackend which manages a pool of worker processes servicing multiple compute-rank channels through a fair-queued round-robin scheduler. This replaces the per-channel producer model in graph-store mode with a shared backend + lightweight per-channel state. Includes tests for pure business logic helpers (_compute_num_batches, _epoch_batch_indices, _compute_worker_seeds_ranges), shuffle behavior, and completion reporting. Co-Authored-By: Claude Opus 4.6 <[email protected]>
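The fair-queued round-robin scheduling described above can be sketched as follows. This is a minimal illustration of the scheduling idea, not the actual SharedDistSamplingBackend code; the class and method names are hypothetical.

```python
from collections import OrderedDict, deque

class RoundRobinScheduler:
    """Toy sketch of fair-queued round-robin dispatch: each channel gets its
    own FIFO queue, and the scheduler cycles through channels so one busy
    channel cannot starve the others. Names are hypothetical, not GiGL's."""

    def __init__(self):
        self._queues: "OrderedDict[int, deque]" = OrderedDict()

    def submit(self, channel_id: int, task) -> None:
        # Lazily create a per-channel queue on first submission.
        self._queues.setdefault(channel_id, deque()).append(task)

    def next_task(self):
        # Visit channels in round-robin order, skipping empty queues.
        for channel_id in list(self._queues):
            queue = self._queues[channel_id]
            # Rotate this channel to the back so the next call starts
            # with a different channel (fairness across channels).
            self._queues.move_to_end(channel_id)
            if queue:
                return channel_id, queue.popleft()
        return None  # all queues empty
```

Even with two tasks queued on channel 1, channel 2 gets served between them, which is the starvation-avoidance property the commit describes.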
Replace the single-step create_sampling_producer with a two-phase API:
- init_sampling_backend: creates/reuses a SharedDistSamplingBackend
- register_sampling_input: registers a lightweight per-channel input

The existing create_sampling_producer/destroy_sampling_producer methods are preserved as bridge methods that delegate to the new API, keeping existing loaders working without changes. Also adds InitSamplingBackendRequest and RegisterBackendRequest message dataclasses, and per-channel fetch stats logging. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Refactor BaseDistLoader to use the two-phase sampling API directly:
- init_sampling_backend (shared across all ranks per loader instance)
- register_sampling_input (unique per compute rank)

Key changes:
- Add GroupLeaderInfo, _compute_group_leader, _dispatch_grouped_graph_store_phase for generic leader-elected grouped RPC dispatch
- Add _init_graph_store_sampling_backends and _register_graph_store_sampling_inputs
- Replace _producer_id_list with _backend_id_list + _channel_id_list
- Remove create_sampling_producer/destroy_sampling_producer bridge methods
- Keep per-class _counter in each loader (not a global counter) since type-prefixed _backend_key already prevents cross-type collisions
- Fix test_multiple_loaders_in_graph_store to use num_compute_nodes=2 so backend-sharing assertions are exercised across ranks

Co-Authored-By: Claude Opus 4.6 <[email protected]>
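The leader-elected grouped dispatch mentioned above can be illustrated with a small helper. This is a sketch under the assumption that the lowest rank in each group acts as leader; the function name mirrors the commit's `_compute_group_leader` but is hypothetical, not the real implementation.

```python
def compute_group_leader(rank: int, group_ranks: list[int]) -> tuple[int, bool]:
    """Elect a deterministic leader for a group of ranks.

    Assumes (hypothetically) that the lowest rank is the leader, so every
    rank in the group computes the same answer without communication.
    Returns (leader_rank, is_leader_for_this_rank).
    """
    leader = min(group_ranks)
    return leader, rank == leader
```

In the dispatch pattern shown later in the diff, only the leader issues the phase RPCs (`results = issue_phase_rpcs() if group_info.is_leader else []`) and non-leaders contribute empty lists; the collective gather then broadcasts the leader's results to every rank in the group.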
/all_test

GiGL Automation @ 21:22:55 UTC: 🔄 @ 22:40:25 UTC: ✅ Workflow completed successfully.
GiGL Automation @ 21:22:57 UTC: 🔄 @ 22:32:01 UTC: ✅ Workflow completed successfully.
GiGL Automation @ 21:22:59 UTC: 🔄 @ 21:32:06 UTC: ✅ Workflow completed successfully.
GiGL Automation @ 21:22:59 UTC: 🔄 @ 22:47:31 UTC: ✅ Workflow completed successfully.
GiGL Automation @ 21:23:00 UTC: 🔄 @ 21:31:11 UTC: ✅ Workflow completed successfully.
```python
time.sleep(group_info.stagger_sleep)
results = issue_phase_rpcs() if group_info.is_leader else []
all_results: list[list[T]] = [[] for _ in range(runtime.world_size)]
torch.distributed.all_gather_object(all_results, results)
```
Semgrep identified an issue in your code:
The torch.distributed.all_gather_object() call deserializes untrusted data from remote ranks using pickle, allowing arbitrary code execution if an attacker controls any rank in the distributed system.
More details about this
The torch.distributed.all_gather_object() call is using pickle deserialization under the hood to share producer_id_list across all ranks in the distributed system. This means untrusted data from other processes gets automatically deserialized and executed.
An attacker with access to any rank in the distributed training job could craft a malicious pickled object and send it during the all-gather phase. When producer_id_list (or other ranks' data) gets deserialized on your rank, the attacker's code runs immediately with the same privileges as your training process.
For example:
- Attacker gains write access to rank 1's memory or intercepts its state
- They insert a pickled Python object that executes shell commands when unpickled (e.g., `os.system('steal_data.sh')`)
- Your rank calls `torch.distributed.all_gather_object(all_producer_ids, producer_id_list)`
- PyTorch's pickle unpickles all ranks' data, triggering the attacker's code during deserialization on your process
- The shell commands run with your process's credentials, potentially exfiltrating model weights or training data
To resolve this comment:
✨ Commit Assistant Fix Suggestion
- Avoid using `torch.distributed.all_gather_object`, as this relies on Python pickling, which may lead to arbitrary code execution if untrusted data is ever deserialized.
- Replace the use of `all_gather_object` with a tensor-based collective, such as `torch.distributed.all_gather`, by converting your data to a tensor (for example, use `torch.tensor(producer_id_list)`).
- Predefine the size and type of the tensor holding the gathered data, for example: `all_producer_ids = torch.empty(runtime.world_size * num_producers, dtype=torch.long)`, where `num_producers` is the expected length of `producer_id_list` for each rank.
- Call `torch.distributed.all_gather([all_producer_ids], producer_id_list_tensor)`, where `producer_id_list_tensor = torch.tensor(producer_id_list, dtype=torch.long)` for each rank.
- After gathering, reconstruct the original data structure as needed from `all_producer_ids`. For example, split the combined tensor into per-rank lists using slicing.
- If the number of producers may vary, agree on a fixed length and pad with a sentinel value such as `-1` so tensors can be safely communicated.

This change ensures only primitive tensor data is shared between ranks, eliminating pickle-related risks.
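The sentinel-padding scheme in the suggestion can be sketched with plain helpers. These functions (hypothetical names, not part of GiGL or the Semgrep suggestion verbatim) show only the pad/unpad logic; in real code the padded list would be wrapped in a `torch.long` tensor and exchanged via `torch.distributed.all_gather`.

```python
SENTINEL = -1  # agreed-upon value that can never be a real producer id

def pad_ids(ids: list[int], max_len: int, sentinel: int = SENTINEL) -> list[int]:
    """Pad a per-rank id list to a fixed length so every rank sends an
    equal-sized tensor through the collective."""
    if len(ids) > max_len:
        raise ValueError(f"{len(ids)} ids exceed agreed max_len={max_len}")
    return ids + [sentinel] * (max_len - len(ids))

def unpad_ids(padded: list[int], sentinel: int = SENTINEL) -> list[int]:
    """Strip the sentinel padding on the receiving side to recover the
    original id list for each rank."""
    return [i for i in padded if i != sentinel]
```

Because every rank pads to the same `max_len`, the gathered tensor can be sliced into `world_size` equal chunks and each chunk unpadded independently.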
💬 Ignore this finding
Reply with Semgrep commands to ignore this finding:
- `/fp <comment>` for false positive
- `/ar <comment>` for acceptable risk
- `/other <comment>` for all other reasons
Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.
You can view more details about this finding in the Semgrep AppSec Platform.
```python
    The leader's RPC results, broadcast to all ranks in the group.
    """
all_keys: list[Optional[str]] = [None] * runtime.world_size
torch.distributed.all_gather_object(all_keys, my_key)
```
Semgrep identified an issue in your code:
Using torch.distributed.all_gather_object() to share my_worker_key creates an arbitrary code execution risk, since pickle deserialization can execute attacker-controlled code if a compromised worker sends malicious data.
More details about this
The torch.distributed.all_gather_object() call on this line uses pickle to serialize and deserialize the my_worker_key string across distributed processes. An attacker who can control the data sent from any worker process could craft a malicious pickle payload that executes arbitrary code when deserialized.
Exploit scenario:
- An attacker compromises or spoofs one of the worker processes in the distributed training cluster
- They set `my_worker_key` to a malicious pickle-serialized object instead of a normal string
- When this line executes, PyTorch deserializes the pickle object on all receiving ranks
- The malicious pickle payload executes arbitrary code on those processes with the same privileges as the training job, potentially stealing model weights, injecting backdoors, or exfiltrating data
To resolve this comment:
✨ Commit Assistant Fix Suggestion
- Avoid using `torch.distributed.all_gather_object`, as it uses pickle internally and can allow arbitrary code execution if untrusted data is deserialized.
- If exchanging strings across ranks, switch to using `torch.distributed.all_gather`, which is safe for tensors. You can do this by encoding your strings to byte tensors before gathering and decoding after.
- Replace the vulnerable line with logic similar to:
  - Convert the local string to bytes: `my_worker_key_bytes = my_worker_key.encode('utf-8')`
  - Find the maximum length of all keys to ensure tensors are the same size across ranks. This usually requires an `all_reduce` to get the max. For example:
    ```python
    key_len_tensor = torch.tensor([len(my_worker_key_bytes)], device='cpu')
    max_len_tensor = key_len_tensor.clone()
    torch.distributed.all_reduce(max_len_tensor, op=torch.distributed.ReduceOp.MAX)
    ```
  - Pad your byte string to `max_len_tensor.item()`: `padded_bytes = my_worker_key_bytes.ljust(max_len_tensor.item(), b'\x00')`
  - Create a tensor: `my_worker_key_tensor = torch.ByteTensor(list(padded_bytes))`
  - Prepare a gather tensor: `all_worker_key_tensors = [torch.empty_like(my_worker_key_tensor) for _ in range(runtime.world_size)]`
  - Call `torch.distributed.all_gather(all_worker_key_tensors, my_worker_key_tensor)`
  - After gathering, decode each tensor: `[bytes(t.tolist()).rstrip(b'\x00').decode('utf-8') for t in all_worker_key_tensors]`
- Replace all uses of `all_worker_keys` with the decoded string list.
Alternatively, if all ranks already know or can deterministically construct the set of worker keys, you can avoid broadcasting entirely by constructing the list locally.
Using tensors for communication prevents vulnerabilities from deserialization attacks, as tensor operations do not use pickle.
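The encode/pad/decode round trip in the suggestion can be verified in isolation with plain helpers. These are hypothetical functions illustrating the byte-level logic only; the `all_reduce`/`all_gather` collectives from the suggestion are omitted so the sketch runs without a process group.

```python
def encode_key(key: str, max_len: int) -> list[int]:
    """Encode a worker key to a fixed-length list of byte values,
    NUL-padded so every rank would send an equal-sized tensor."""
    data = key.encode("utf-8")
    if len(data) > max_len:
        raise ValueError(f"key of {len(data)} bytes exceeds max_len={max_len}")
    return list(data.ljust(max_len, b"\x00"))

def decode_key(values: list[int]) -> str:
    """Strip NUL padding and decode back to the original string,
    mirroring the post-gather decode step in the suggestion."""
    return bytes(values).rstrip(b"\x00").decode("utf-8")
```

In the distributed setting, `max_len` is whatever the `all_reduce` with `ReduceOp.MAX` agreed on, so every rank's encoded list has the same length and the byte tensors are shape-compatible for `all_gather`.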
…r module docstring Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
…sses Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
Co-Authored-By: Claude Opus 4.6 <[email protected]>
…into kmonte/shared-backend-decomp-3
- Fix typo "patition" → "partition" in docstrings.
- Clear _fetch_stats_by_channel_id on shutdown.
- Fix init_sampling_backend race via lock-overlap on backend_state.lock; second callers block on that lock and observe init_complete / init_error instead of racing to a rolled-back map entry.
- Snapshot active_channels length under self._lock in register_sampling_input so the log reflects state at registration.
- Serialize destroy_sampling_input and start_new_epoch_sampling under channel_state.lock, restoring the pre-two-phase destroy/start_epoch invariant. Move _fetch_stats_by_channel_id.pop inside self._lock.
- Document idempotency of start_new_epoch_sampling and the expected QueueTimeoutError case in fetch_one_sampled_message.

Adds three concurrency regression tests: two for the init_sampling_backend race (success and failure-propagation via __cause__), and one verifying monotonic dispatch under the channel_state.lock wrap on start_new_epoch_sampling.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
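The lock-overlap init fix described in this commit follows a well-known pattern: later callers block on the same lock the initializer holds, then observe the recorded outcome instead of re-running (or racing against a rolled-back) initialization. The sketch below uses hypothetical names modeled on the commit message, not the actual GiGL code.

```python
import threading

class BackendState:
    """Per-backend state for the lock-overlap init pattern (hypothetical)."""
    def __init__(self):
        self.lock = threading.Lock()
        self.init_complete = False
        self.init_error = None  # original exception, if init failed

def init_sampling_backend(state: BackendState, do_init) -> None:
    """First caller runs do_init() while holding state.lock; callers that
    arrive during init block on the lock, then see init_complete or
    init_error rather than racing to a rolled-back map entry."""
    with state.lock:
        if state.init_complete:
            return  # idempotent: backend already initialized
        if state.init_error is not None:
            # Propagate the original failure to late callers via __cause__,
            # matching the failure-propagation regression test the commit adds.
            raise RuntimeError("backend init failed earlier") from state.init_error
        try:
            do_init()
            state.init_complete = True
        except Exception as exc:
            state.init_error = exc
            raise
```

The key property is that success makes every later call a no-op, while failure is remembered and re-raised with the original exception chained, so concurrent callers never observe a half-initialized backend.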
…pchat/GiGL into kmonte/shared-backend-decomp-3
/all_test

GiGL Automation @ 15:58:09 UTC: 🔄 @ 17:27:53 UTC: ✅ Workflow completed successfully.
GiGL Automation @ 15:58:10 UTC: 🔄 @ 16:05:37 UTC: ✅ Workflow completed successfully.
GiGL Automation @ 15:58:11 UTC: 🔄 @ 17:02:30 UTC: ✅ Workflow completed successfully.
GiGL Automation @ 15:58:12 UTC: 🔄 @ 16:07:48 UTC: ✅ Workflow completed successfully.
GiGL Automation @ 15:58:13 UTC: 🔄 @ 17:14:12 UTC: ✅ Workflow completed successfully.