diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 15f13fb16..f6c7653c9 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -8,12 +8,11 @@ - Graph Store mode: barrier loop + async RPC dispatch + channel creation """ -import math import sys import time -from collections import Counter, defaultdict +from collections import Counter from dataclasses import dataclass -from typing import Callable, Optional, Union +from typing import Optional, Union import torch from graphlearn_torch.channel import SampleMessage, ShmChannel @@ -23,7 +22,6 @@ RemoteDistSamplingWorkerOptions, get_context, ) -from graphlearn_torch.distributed.dist_client import async_request_server from graphlearn_torch.distributed.rpc import rpc_is_initialized from graphlearn_torch.sampler import ( EdgeSamplerInput, @@ -42,7 +40,12 @@ from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_sampling_producer import DistSamplingProducer +from gigl.distributed.graph_store.compute import async_request_server from gigl.distributed.graph_store.dist_server import DistServer +from gigl.distributed.graph_store.messages import ( + InitSamplingBackendRequest, + RegisterBackendRequest, +) from gigl.distributed.graph_store.remote_channel import RemoteReceivingChannel from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions @@ -91,8 +94,8 @@ class BaseDistLoader(DistLoader): 3. Call ``create_sampling_config()`` to build the SamplingConfig. 4. For colocated: call ``create_colocated_channel()`` and construct the ``DistSamplingProducer`` (or subclass), then pass the producer as ``producer``. - 5. For graph store: pass the RPC function (e.g. ``DistServer.create_sampling_producer``) - as ``producer``. + 5. 
For graph store: prepare remote inputs and worker options; the base class + handles the two-phase backend/channel RPC initialization. 6. Call ``super().__init__()`` with the prepared data. Args: @@ -105,18 +108,13 @@ class BaseDistLoader(DistLoader): sampling_config: Configuration for sampling (created via ``create_sampling_config``). device: Target device for sampled results. runtime: Resolved distributed runtime information. - producer: Either a pre-constructed ``DistSamplingProducer`` (colocated mode) - or a callable to dispatch on the ``DistServer`` (graph store mode). + producer: Optional pre-constructed ``DistSamplingProducer`` for colocated mode. sampler_options: Controls which sampler class is instantiated. + backend_key: Unique key identifying the shared sampling backend for this + loader instance. Required for graph store mode; must be ``None`` for + colocated mode. process_start_gap_seconds: Delay between each process for staggered colocated init. - In graph store mode, this is the delay between each batch of concurrent - producer initializations. - max_concurrent_producer_inits: Maximum number of leader ranks that may - dispatch ``create_producer_fn`` RPCs concurrently in graph store mode. - Leaders are grouped into batches of this size; each batch sleeps - ``batch_index * process_start_gap_seconds`` before dispatching. - Only applies to graph store mode. Defaults to ``None`` - (no staggering). + Only applies to colocated mode. 
""" @staticmethod @@ -220,15 +218,12 @@ def __init__( sampling_config: SamplingConfig, device: torch.device, runtime: DistributedRuntimeInfo, - producer: Union[DistSamplingProducer, Callable[..., int]], + producer: Optional[DistSamplingProducer], sampler_options: SamplerOptions, + backend_key: Optional[str] = None, process_start_gap_seconds: float = 60.0, - max_concurrent_producer_inits: Optional[int] = None, non_blocking_transfers: bool = True, ): - if max_concurrent_producer_inits is None: - max_concurrent_producer_inits = sys.maxsize - # Set right away so __del__ can clean up if we throw during init. # Will be set to False once connections are initialized. self._shutdowned = True @@ -242,6 +237,7 @@ def __init__( self._sampler_options = sampler_options self._non_blocking_transfers = non_blocking_transfers + self._backend_key = backend_key # --- Attributes shared by both modes (mirrors GLT DistLoader.__init__) --- self.input_data = sampler_input @@ -264,9 +260,11 @@ def __init__( self._epoch = 0 # --- Mode-specific attributes and connection initialization --- - if isinstance(producer, DistSamplingProducer): - assert isinstance(dataset, DistDataset) - assert isinstance(worker_options, MpDistSamplingWorkerOptions) + if ( + isinstance(dataset, DistDataset) + and isinstance(worker_options, MpDistSamplingWorkerOptions) + and isinstance(producer, DistSamplingProducer) + ): assert isinstance(sampler_input, NodeSamplerInput) self.data: Optional[DistDataset] = dataset @@ -285,6 +283,7 @@ def __init__( if not self.drop_last and self._input_len % self.batch_size != 0: self._num_expected += 1 + self._remote_input_has_batches: list[bool] = [] self._shutdowned = False self._init_colocated_connections( dataset=dataset, @@ -292,11 +291,10 @@ def __init__( runtime=runtime, process_start_gap_seconds=process_start_gap_seconds, ) - else: - assert isinstance(dataset, RemoteDistDataset) - assert isinstance(worker_options, RemoteDistSamplingWorkerOptions) + elif isinstance(dataset, 
RemoteDistDataset) and isinstance( + worker_options, RemoteDistSamplingWorkerOptions + ): assert isinstance(sampler_input, list) - assert callable(producer) self.data = None self._is_mp_worker = False @@ -322,13 +320,17 @@ def __init__( node_types = [DEFAULT_HOMOGENEOUS_NODE_TYPE] self._set_ntypes_and_etypes(node_types, edge_types) + self._remote_input_has_batches = [ + self._sampler_input_has_batches(inp) for inp in self._input_data_list + ] self._shutdowned = False - self._init_graph_store_connections( - dataset=dataset, - create_producer_fn=producer, - runtime=runtime, - process_start_gap_seconds=process_start_gap_seconds, - max_concurrent_producer_inits=max_concurrent_producer_inits, + self._init_graph_store_connections(dataset=dataset) + else: + raise TypeError( + "Invalid loader construction. Expected either " + "(DistDataset, MpDistSamplingWorkerOptions, DistSamplingProducer) " + "for colocated mode or (RemoteDistDataset, RemoteDistSamplingWorkerOptions) " + "for graph-store mode." ) @staticmethod @@ -549,7 +551,6 @@ def create_colocated_worker_options( def create_graph_store_worker_options( *, dataset: RemoteDistDataset, - compute_rank: int, worker_key: str, num_workers: int, worker_concurrency: int, @@ -560,7 +561,6 @@ def create_graph_store_worker_options( Args: dataset: Remote dataset proxy used to discover storage-cluster topology. - compute_rank: Global compute-process rank for the current process. worker_key: Unique key used by the storage cluster to deduplicate producers. num_workers: Number of sampling worker processes. worker_concurrency: Max sampling concurrency per worker. @@ -570,10 +570,8 @@ def create_graph_store_worker_options( Returns: Fully configured worker options for graph-store sampling. 
""" - sampling_ports = dataset.fetch_free_ports_on_storage_cluster( - num_ports=dataset.cluster_info.compute_cluster_world_size - ) - sampling_port = sampling_ports[compute_rank] + sampling_ports = dataset.fetch_free_ports_on_storage_cluster(num_ports=1) + sampling_port = sampling_ports[0] return RemoteDistSamplingWorkerOptions( server_rank=list(range(dataset.cluster_info.num_storage_nodes)), num_workers=num_workers, @@ -633,88 +631,140 @@ def _init_colocated_connections( time.sleep(process_start_gap_seconds * runtime.local_rank) self._mp_producer.init() - def _init_graph_store_connections( - self, - dataset: RemoteDistDataset, - create_producer_fn: Callable[..., int], - runtime: DistributedRuntimeInfo, - process_start_gap_seconds: float = 60.0, - max_concurrent_producer_inits: int = sys.maxsize, # Already resolved from None by __init__ - ) -> None: - """Initialize Graph Store mode connections. - - Validates the GLT distributed context, elects a leader per ``worker_key`` - group to dispatch RPCs that create sampling producers on storage nodes, - then distributes the resulting producer IDs to all ranks in the group via - ``all_gather_object``. - The `worker_key` is set-upstream (and are passed in as part of `RemoteDistSamplingWorkerOptions`) - to group producers together, so that different processes in the compute cluster (e.g. different GPUs) - can share the same producer. - - Only the leader rank (minimum rank sharing a given ``worker_key``) sends - the ``create_producer_fn`` RPCs. This avoids redundant RPCs and - server-side lock contention, since the server deduplicates producers by - ``worker_key`` anyway. - - Leaders are further staggered into batches of size - ``max_concurrent_producer_inits``. Each batch sleeps - ``batch_index * process_start_gap_seconds`` before dispatching RPCs, - limiting the number of leaders that hit storage nodes concurrently. 
+ def _init_graph_store_sampling_backends(self) -> list[int]: + """Initialize or reuse one shared backend per storage server. - All DistLoader attributes are already set by ``__init__`` before this is called. + Every compute rank issues one RPC per storage server. + ``DistServer.init_sampling_backend`` deduplicates concurrent calls + with the same ``backend_key`` internally (the first caller creates + the backend; subsequent callers block on ``backend_state.lock`` + until init completes), so all ranks observe the same ``backend_id`` + on each server without explicit rank-level coordination. - Uses ``async_request_server`` instead of ``ThreadPoolExecutor`` to avoid - TensorPipe rendezvous deadlock with many servers. - - For Graph Store mode it's important to distinguish "compute node" (e.g. physical compute machine) from "compute process" (e.g. process running on the compute node). - Since in practice we have multiple compute processes per compute node, and each compute process needs to initialize the connection to the storage nodes. - E.g. if there are 4 gpus per compute node, then there will be 4 connections from each compute node to each storage node. - - See below for a connection setup. 
- ╔═══════════════════════════════════════════════════════════════════════════════════════╗ - ║ COMPUTE TO STORAGE NODE CONNECTIONS ║ - ╚═══════════════════════════════════════════════════════════════════════════════════════╝ - - COMPUTE NODES STORAGE NODES - ═════════════ ═════════════ - - ┌──────────────────────┐ (1) ┌───────────────┐ - │ COMPUTE NODE 0 │ │ │ - │ ┌────┬────┬────┬────┤ ══════════════════════════════════│ STORAGE 0 │ - │ │GPU │GPU │GPU │GPU │ ╱ │ │ - │ │ 0 │ 1 │ 2 │ 3 │ ════════════════════╲ ╱ └───────────────┘ - │ └────┴────┴────┴────┤ (2) ╲ ╱ - └──────────────────────┘ ╲ ╱ - ╳ - (3) ╱ ╲ (4) - ┌──────────────────────┐ ╱ ╲ ┌───────────────┐ - │ COMPUTE NODE 1 │ ╱ ╲ │ │ - │ ┌────┬────┬────┬────┤ ═════════════════╱ ═│ STORAGE 1 │ - │ │GPU │GPU │GPU │GPU │ │ │ - │ │ 0 │ 1 │ 2 │ 3 │ ══════════════════════════════════│ │ - │ └────┴────┴────┴────┤ └───────────────┘ - └──────────────────────┘ - - ┌─────────────────────────────────────────────────────────────────────────────┐ - │ (1) Compute Node 0 → Storage 0 (4 connections, one per GPU) │ - │ (2) Compute Node 0 → Storage 1 (4 connections, one per GPU) │ - │ (3) Compute Node 1 → Storage 0 (4 connections, one per GPU) │ - │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ - └─────────────────────────────────────────────────────────────────────────────┘ + Returns: + List of backend IDs, one per storage server. + + Raises: + RuntimeError: If ``_backend_key`` was not set. + """ + if self._backend_key is None: + raise RuntimeError( + f"{type(self).__name__} was constructed without a backend_key. " + "Graph-store mode requires a non-None backend_key." 
+ ) + futures: list[torch.futures.Future[int]] = [ + async_request_server( + server_rank, + DistServer.init_sampling_backend, + InitSamplingBackendRequest( + backend_key=self._backend_key, + worker_options=self.worker_options, + sampler_options=self._sampler_options, + sampling_config=self.sampling_config, + ), + ) + for server_rank in self._server_rank_list + ] + return torch.futures.wait_all(futures) + + def _register_graph_store_sampling_inputs( + self, backend_id_list: list[int] + ) -> list[int]: + """Register this compute rank's inputs on existing shared backends. + + Each compute rank has a unique ``worker_key``, so registrations are + naturally per-rank and do not need cross-rank coordination. Args: - dataset: The remote dataset proxy for graph store mode. - create_producer_fn: RPC callable to create a sampling producer on a - storage server (e.g. ``DistServer.create_sampling_producer``). - runtime: Resolved distributed runtime information (provides rank and - world_size for the all_gather collectives). - process_start_gap_seconds: Delay in seconds between each batch of - concurrent producer initializations. - max_concurrent_producer_inits: Maximum number of leader ranks that - may dispatch RPCs concurrently. Leaders are grouped into batches - of this size. + backend_id_list: Backend IDs from ``_init_graph_store_sampling_backends``. + + Returns: + List of channel IDs, one per storage server. 
+ """ + assert ( + len(self._server_rank_list) + == len(backend_id_list) + == len(self._input_data_list) + ), ( + f"Mismatched lengths: server_rank_list={len(self._server_rank_list)}, " + f"backend_id_list={len(backend_id_list)}, " + f"input_data_list={len(self._input_data_list)}" + ) + worker_key = self.worker_options.worker_key + futures: list[torch.futures.Future[int]] = [ + async_request_server( + server_rank, + DistServer.register_sampling_input, + RegisterBackendRequest( + backend_id=backend_id, + worker_key=worker_key, + sampler_input=input_data, + sampling_config=self.sampling_config, + buffer_capacity=self.worker_options.buffer_capacity, + buffer_size=self.worker_options.buffer_size, + ), + ) + for server_rank, backend_id, input_data in zip( + self._server_rank_list, + backend_id_list, + self._input_data_list, + ) + ] + return torch.futures.wait_all(futures) + + def _sampler_input_has_batches(self, sampler_input: NodeSamplerInput) -> bool: + """Return whether this sampler input can produce at least one batch. + + Args: + sampler_input: The sampler input to check. + + Returns: + True if the input has enough elements for at least one batch. + """ + input_len = len(sampler_input) + return input_len > 0 and not (self.drop_last and input_len < self.batch_size) + + def _init_graph_store_connections(self, dataset: RemoteDistDataset) -> None: + """Initialize graph-store mode with shared backends and per-rank channels. + + Populates two parallel lists indexed by storage server. A compute rank + corresponds to one loader process; with N compute machines and P + processes per machine there are N*P compute ranks. + + ``_backend_id_list`` + One backend per storage server, shared across all compute ranks + using the same loader instance (keyed by ``_backend_key``). + Used when the operation targets the backend itself (e.g. + ``init_sampling_backend``). + Note that when there are multiple loader instances + (e.g. 
 an ABLP and DistLoader for training), each instance + will have its own backend id, per server. + + ``_channel_id_list`` + One channel per storage server, unique to this compute rank and + loader instance. Used for per-rank operations (e.g. + ``start_new_epoch_sampling``, ``fetch_one_sampled_message``, + ``destroy_sampling_input``). + + Invariants: + + * ``len(backend_id_list) == len(channel_id_list) == num_storage_servers``. + * All compute ranks sharing a loader instance see the same + ``backend_id_list``. Server-side dedup on ``backend_key`` (via + ``DistServer._backend_id_by_backend_key``) guarantees that every + rank's concurrent ``init_sampling_backend`` RPC returns the same + ID on a given server. + * Each storage server maintains its own ``_next_backend_id`` and + ``_next_channel_id`` counters. Values on different servers advance + independently; cross-server numeric equality is not guaranteed + (partial init failures or other operations on one server can + desynchronize counters). + * Within a single server, channel IDs are unique across all + registrations (no dedup; each ``register_sampling_input`` call + allocates a fresh monotonic ID). A single compute rank may hold + multiple channel IDs on the same server if it owns multiple + concurrent loader instances. """ - # Validate distributed context ctx = get_context() if ctx is None: raise RuntimeError( @@ -726,135 +776,26 @@ def _init_graph_store_connections( f"'{self.__class__.__name__}': must be used on a client worker process." ) - # Move input to CPU before sending to server for inp in self._input_data_list: if not isinstance(inp, RemoteSamplerInput): inp.to(torch.device("cpu")) - node_rank = dataset.cluster_info.compute_node_rank - - _flush() start_time = time.time() - - # --- Leader election via worker_key all_gather --- - # All ranks exchange their worker_key so we can group ranks that share - # the same key and elect the minimum rank as the leader. 
- my_worker_key: str = self.worker_options.worker_key - all_worker_keys: list[Optional[str]] = [None] * runtime.world_size - torch.distributed.all_gather_object(all_worker_keys, my_worker_key) - - key_to_ranks: dict[str, list[int]] = defaultdict(list) - for r, key in enumerate(all_worker_keys): - assert key is not None, f"Rank {r} did not provide a worker_key" - key_to_ranks[key].append(r) - - leader_rank = min(key_to_ranks[my_worker_key]) - is_leader = runtime.rank == leader_rank - - # --- Stagger leaders into batches --- - # Deterministically assign each unique worker_key an index, then group - # leaders into batches of max_concurrent_producer_inits. Each batch - # sleeps batch_index * process_start_gap_seconds before dispatching. - unique_keys = sorted(key_to_ranks.keys()) - my_key_index = unique_keys.index(my_worker_key) - num_unique_keys = len(unique_keys) - num_batches = math.ceil(num_unique_keys / max_concurrent_producer_inits) - my_batch = my_key_index // max_concurrent_producer_inits - stagger_sleep_seconds = my_batch * process_start_gap_seconds - - logger.info( - f"rank={runtime.rank} worker_key={my_worker_key} " - f"is_leader={is_leader} leader_rank={leader_rank} " - f"group_size={len(key_to_ranks[my_worker_key])} " - f"key_index={my_key_index}/{num_unique_keys} " - f"batch={my_batch}/{num_batches} " - f"stagger_sleep={stagger_sleep_seconds:.1f}s" - ) - _flush() - - # --- Leader dispatches RPCs, followers skip --- - producer_id_list: list[int] = [] - if is_leader: - if stagger_sleep_seconds > 0: - logger.info( - f"rank={runtime.rank} sleeping {stagger_sleep_seconds:.1f}s " - f"(batch {my_batch}/{num_batches}) before RPC dispatch" - ) - _flush() - time.sleep(stagger_sleep_seconds) - - rpc_futures: list[tuple[int, torch.futures.Future[int]]] = [] - logger.info( - f"node_rank={node_rank} rank={runtime.rank} dispatching " - f"create_sampling_producer to " - f"{len(self._server_rank_list)} servers" - ) - _flush() - t_dispatch = time.time() - for server_rank, 
inp_data in zip( - self._server_rank_list, self._input_data_list - ): - fut = async_request_server( - server_rank, - create_producer_fn, - inp_data, - self.sampling_config, - self.worker_options, - self._sampler_options, - ) - rpc_futures.append((server_rank, fut)) - logger.info( - f"node_rank={node_rank} rank={runtime.rank} all " - f"{len(rpc_futures)} RPCs dispatched in " - f"{time.time() - t_dispatch:.3f}s, waiting for responses" - ) - _flush() - - for server_rank, fut in rpc_futures: - t_wait = time.time() - producer_id = fut.wait() - logger.info( - f"node_rank={node_rank} rank={runtime.rank} " - f"create_sampling_producer" - f"(server_rank={server_rank}) returned " - f"producer_id={producer_id} in {time.time() - t_wait:.2f}s" - ) - _flush() - producer_id_list.append(producer_id) - logger.info( - f"node_rank={node_rank} rank={runtime.rank} all " - f"{len(producer_id_list)} producers " - f"created in {time.time() - t_dispatch:.2f}s total" - ) - _flush() - else: # if not leader - logger.info( - f"Since rank {runtime.rank} is not the leader for worker key {my_worker_key}, we will wait for the leader (rank {leader_rank}) to dispatch RPCs" - ) - - # --- Distribute producer IDs to all ranks --- - all_producer_ids: list[list[int]] = [[] for _ in range(runtime.world_size)] - torch.distributed.all_gather_object(all_producer_ids, producer_id_list) - self._producer_id_list = all_producer_ids[leader_rank] - - logger.info( - f"rank={runtime.rank} received producer_id_list=" - f"{self._producer_id_list} from leader_rank={leader_rank}" + self._backend_id_list = self._init_graph_store_sampling_backends() + self._channel_id_list = self._register_graph_store_sampling_inputs( + backend_id_list=self._backend_id_list, ) - _flush() - - # Create remote receiving channel for cross-machine message passing self._channel = RemoteReceivingChannel( server_rank=self._server_rank_list, - channel_id=self._producer_id_list, + channel_id=self._channel_id_list, 
prefetch_size=self.worker_options.prefetch_size, - active_mask=[len(inp) > 0 for inp in self._input_data_list], + active_mask=self._remote_input_has_batches, pin_memory=self.to_device is not None and self.to_device.type == "cuda", ) - logger.info( - f"node_rank {node_rank} rank={runtime.rank} initialized " - f"the dist loader in {time.time() - start_time:.2f}s" + f"node_rank {dataset.cluster_info.compute_node_rank} " + f"rank={torch.distributed.get_rank()} " + f"initialized shared graph-store loader in {time.time() - start_time:.2f}s" ) _flush() @@ -868,11 +809,11 @@ def shutdown(self) -> None: self._mp_producer.shutdown() elif rpc_is_initialized() is True: rpc_futures: list[torch.futures.Future[None]] = [] - for server_rank, producer_id in zip( - self._server_rank_list, self._producer_id_list + for server_rank, channel_id in zip( + self._server_rank_list, self._channel_id_list ): fut = async_request_server( - server_rank, DistServer.destroy_sampling_producer, producer_id + server_rank, DistServer.destroy_sampling_input, channel_id ) rpc_futures.append(fut) torch.futures.wait_all(rpc_futures) @@ -916,16 +857,24 @@ def __iter__(self) -> Self: self._mp_producer.produce_all() else: rpc_futures: list[torch.futures.Future[None]] = [] - for server_rank, producer_id in zip( - self._server_rank_list, self._producer_id_list + for server_rank, channel_id, has_batches in zip( + self._server_rank_list, + self._channel_id_list, + self._remote_input_has_batches, ): + if not has_batches: + continue fut = async_request_server( server_rank, DistServer.start_new_epoch_sampling, - producer_id, + channel_id, self._epoch, ) rpc_futures.append(fut) + # Match GLT's remote-loader ordering: do not begin fetching until + # every storage server has acknowledged the epoch start for this + # channel. Otherwise a fetch RPC can race ahead and block the + # corresponding start-epoch RPC on the server-side channel lock. 
torch.futures.wait_all(rpc_futures) self._channel.reset() self._epoch += 1 diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 215a92a51..2c72c9ceb 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,6 +1,6 @@ from collections import abc, defaultdict from itertools import count -from typing import Callable, Optional, Union +from typing import Optional, Union import torch from graphlearn_torch.channel import SampleMessage @@ -21,7 +21,6 @@ PPR_WEIGHT_METADATA_KEY, ) from gigl.distributed.dist_sampling_producer import DistSamplingProducer -from gigl.distributed.graph_store.dist_server import DistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.sampler import ( NEGATIVE_LABEL_METADATA_KEY, @@ -89,7 +88,6 @@ def __init__( prefetch_size: Optional[int] = None, channel_size: str = "4GB", process_start_gap_seconds: float = 60.0, - max_concurrent_producer_inits: Optional[int] = None, num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, @@ -195,16 +193,9 @@ def __init__( channel_size (int or str): The shared-memory buffer size (bytes) allocated for the channel. Can be modified for performance tuning; a good starting point is: ``num_workers * 64MB`` (default: "4GB"). - process_start_gap_seconds (float): Delay between each process for initializing neighbor loader. - In colocated mode, each process sleeps ``local_rank * process_start_gap_seconds`` - before initializing. In graph store mode, leader ranks are grouped into batches - of ``max_concurrent_producer_inits`` and each batch sleeps - ``batch_index * process_start_gap_seconds`` before dispatching RPCs. - max_concurrent_producer_inits (int): Maximum number of leader ranks that may - dispatch create-producer RPCs concurrently in graph store mode. 
Leaders are - grouped into batches of this size; each batch is staggered by - ``process_start_gap_seconds``. Only applies to graph store mode. - Defaults to ``None`` (no staggering). + process_start_gap_seconds (float): Delay between each process for initializing neighbor loader + in colocated mode. Each process sleeps ``local_rank * process_start_gap_seconds`` + before initializing. Only applies to colocated mode. num_cpu_threads (Optional[int]): Number of cpu threads PyTorch should use for CPU training/inference neighbor loading; on top of the per process parallelism. Defaults to `2` if set to `None` when using cpu training/inference. @@ -259,10 +250,6 @@ def __init__( raise ValueError( f"prefetch_size must be None when using Colocated mode, received {prefetch_size}" ) - if max_concurrent_producer_inits is not None: - raise ValueError( - f"max_concurrent_producer_inits must be None when using Colocated mode, received {max_concurrent_producer_inits}" - ) logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}") del supervision_edge_type @@ -316,6 +303,7 @@ def __init__( MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions ] = setup_info[1] dataset_schema: DatasetSchema = setup_info[2] + backend_key: Optional[str] = None else: # Graph Store mode assert isinstance(dataset, RemoteDistDataset), ( "When using Graph Store mode, dataset must be a RemoteDistDataset." @@ -334,6 +322,7 @@ def __init__( sampler_input, worker_options, dataset_schema, + backend_key, ) = self._setup_for_graph_store( input_nodes=input_nodes, dataset=dataset, @@ -362,22 +351,17 @@ def __init__( drop_last=drop_last, ) - # Build the producer: a pre-constructed producer for colocated mode, - # or an RPC callable for graph store mode. 
+ producer: Optional[DistSamplingProducer] = None if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) - producer: Union[DistSamplingProducer, Callable[..., int]] = ( - BaseDistLoader.create_mp_producer( - dataset=dataset, - sampler_input=sampler_input, - sampling_config=sampling_config, - worker_options=worker_options, - sampler_options=sampler_options, - ) + producer = BaseDistLoader.create_mp_producer( + dataset=dataset, + sampler_input=sampler_input, + sampling_config=sampling_config, + worker_options=worker_options, + sampler_options=sampler_options, ) - else: - producer = DistServer.create_sampling_producer # Call base class — handles metadata storage and connection initialization # (including staggered init for colocated mode). @@ -391,8 +375,8 @@ def __init__( runtime=runtime, producer=producer, sampler_options=sampler_options, + backend_key=backend_key, process_start_gap_seconds=process_start_gap_seconds, - max_concurrent_producer_inits=max_concurrent_producer_inits, non_blocking_transfers=non_blocking_transfers, ) @@ -602,7 +586,10 @@ def _setup_for_graph_store( channel_size: str, prefetch_size: int, ) -> tuple[ - list[ABLPNodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema + list[ABLPNodeSamplerInput], + RemoteDistSamplingWorkerOptions, + DatasetSchema, + str, ]: """ Setup method for Graph Store mode. @@ -618,19 +605,18 @@ def _setup_for_graph_store( prefetch_size: Max prefetched sampled messages per server on client side. Returns: - Tuple of (list[ABLPNodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema). + Tuple of (list[ABLPNodeSamplerInput], RemoteDistSamplingWorkerOptions, + DatasetSchema, backend_key). 
""" node_feature_info = dataset.fetch_node_feature_info() edge_feature_info = dataset.fetch_edge_feature_info() edge_types = dataset.fetch_edge_types() compute_rank = torch.distributed.get_rank() - worker_key = ( - f"compute_ablp_loader_rank_{compute_rank}_worker_{self._instance_count}" - ) + backend_key = f"dist_ablp_loader_{self._instance_count}" + worker_key = f"{backend_key}_compute_rank_{compute_rank}" logger.info(f"rank: {compute_rank}, worker_key: {worker_key}") worker_options = BaseDistLoader.create_graph_store_worker_options( dataset=dataset, - compute_rank=compute_rank, worker_key=worker_key, num_workers=num_workers, worker_concurrency=worker_concurrency, @@ -753,6 +739,7 @@ def _setup_for_graph_store( edge_feature_info=edge_feature_info, edge_dir=dataset.fetch_edge_dir(), ), + backend_key, ) def _set_labels( diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index a6089f930..3d6d5a34b 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -1,7 +1,7 @@ import sys from collections import abc from itertools import count -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from graphlearn_torch.channel import SampleMessage @@ -23,7 +23,6 @@ PPR_WEIGHT_METADATA_KEY, ) from gigl.distributed.dist_sampling_producer import DistSamplingProducer -from gigl.distributed.graph_store.dist_server import DistServer as GiglDistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.sampler_options import ( PPRSamplerOptions, @@ -88,7 +87,6 @@ def __init__( channel_size: str = "4GB", prefetch_size: Optional[int] = None, process_start_gap_seconds: float = 60.0, - max_concurrent_producer_inits: Optional[int] = None, num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, @@ -152,16 +150,9 @@ def __init__( are active 
concurrently. (default: ``None``). Only applicable in Graph Store mode. If supplied and not it Graph Store mode, an error will be raised. - process_start_gap_seconds (float): Delay between each process for initializing neighbor loader. - In colocated mode, each process sleeps ``local_rank * process_start_gap_seconds`` - before initializing. In graph store mode, leader ranks are grouped into batches - of ``max_concurrent_producer_inits`` and each batch sleeps - ``batch_index * process_start_gap_seconds`` before dispatching RPCs. - max_concurrent_producer_inits (int): Maximum number of leader ranks that may - dispatch create-producer RPCs concurrently in graph store mode. Leaders are - grouped into batches of this size; each batch is staggered by - ``process_start_gap_seconds``. Only applies to graph store mode. - Defaults to ``None`` (no staggering). + process_start_gap_seconds (float): Delay between each process for initializing neighbor loader + in colocated mode. Each process sleeps ``local_rank * process_start_gap_seconds`` + before initializing. Only applies to colocated mode. num_cpu_threads (Optional[int]): Number of cpu threads PyTorch should use for CPU training/inference neighbor loading; on top of the per process parallelism. Defaults to `2` if set to `None` when using cpu training/inference. 
@@ -202,10 +193,6 @@ def __init__( raise ValueError( f"prefetch_size must be None when using Colocated mode, received {prefetch_size}" ) - if max_concurrent_producer_inits is not None: - raise ValueError( - f"max_concurrent_producer_inits must be None when using Colocated mode, received {max_concurrent_producer_inits}" - ) logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}") self._instance_count = next(self._counter) @@ -217,6 +204,8 @@ def __init__( ) ) + backend_key: Optional[str] = None + # Mode-specific setup if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance(dataset, DistDataset), ( @@ -243,7 +232,12 @@ def __init__( if prefetch_size is None: logger.info(f"prefetch_size is not provided, using default of 4") prefetch_size = 4 - input_data, worker_options, dataset_schema = self._setup_for_graph_store( + ( + input_data, + worker_options, + dataset_schema, + backend_key, + ) = self._setup_for_graph_store( input_nodes=input_nodes, dataset=dataset, num_workers=num_workers, @@ -271,22 +265,17 @@ def __init__( drop_last=drop_last, ) - # Build the producer: a pre-constructed producer for colocated mode, - # or an RPC callable for graph store mode. 
+ producer: Optional[DistSamplingProducer] = None if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) - producer: Union[DistSamplingProducer, Callable[..., int]] = ( - BaseDistLoader.create_mp_producer( - dataset=dataset, - sampler_input=input_data, - sampling_config=sampling_config, - worker_options=worker_options, - sampler_options=sampler_options, - ) + producer = BaseDistLoader.create_mp_producer( + dataset=dataset, + sampler_input=input_data, + sampling_config=sampling_config, + worker_options=worker_options, + sampler_options=sampler_options, ) - else: - producer = GiglDistServer.create_sampling_producer # Call base class — handles metadata storage and connection initialization # (including staggered init for colocated mode). @@ -300,8 +289,8 @@ def __init__( runtime=runtime, producer=producer, sampler_options=sampler_options, + backend_key=backend_key, process_start_gap_seconds=process_start_gap_seconds, - max_concurrent_producer_inits=max_concurrent_producer_inits, non_blocking_transfers=non_blocking_transfers, ) @@ -320,7 +309,9 @@ def _setup_for_graph_store( worker_concurrency: int, prefetch_size: int, channel_size: str, - ) -> tuple[list[NodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema]: + ) -> tuple[ + list[NodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema, str + ]: if input_nodes is None: raise ValueError( f"When using Graph Store mode, input nodes must be provided, received {input_nodes}" @@ -341,11 +332,11 @@ def _setup_for_graph_store( edge_types = dataset.fetch_edge_types() compute_rank = torch.distributed.get_rank() - worker_key = f"compute_rank_{compute_rank}_worker_{self._instance_count}" + backend_key = f"dist_neighbor_loader_{self._instance_count}" + worker_key = f"{backend_key}_compute_rank_{compute_rank}" logger.info(f"Rank {compute_rank} worker key: {worker_key}") worker_options = 
BaseDistLoader.create_graph_store_worker_options( dataset=dataset, - compute_rank=compute_rank, worker_key=worker_key, num_workers=num_workers, worker_concurrency=worker_concurrency, @@ -418,6 +409,7 @@ def _setup_for_graph_store( edge_feature_info=edge_feature_info, edge_dir=dataset.fetch_edge_dir(), ), + backend_key, ) def _setup_for_colocated( diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 0147d6929..cf499782d 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -18,19 +18,9 @@ import graphlearn_torch.distributed.dist_server as glt_dist_server import torch from graphlearn_torch.channel import QueueTimeoutError, SampleMessage, ShmChannel -from graphlearn_torch.distributed import ( - RemoteDistSamplingWorkerOptions, - barrier, - init_rpc, - shutdown_rpc, -) +from graphlearn_torch.distributed import barrier, init_rpc, shutdown_rpc from graphlearn_torch.partition import PartitionBook -from graphlearn_torch.sampler import ( - EdgeSamplerInput, - NodeSamplerInput, - RemoteSamplerInput, - SamplingConfig, -) +from graphlearn_torch.sampler import RemoteSamplerInput from gigl.common.logger import Logger from gigl.distributed.dist_dataset import DistDataset @@ -44,8 +34,7 @@ from gigl.distributed.graph_store.shared_dist_sampling_producer import ( SharedDistSamplingBackend, ) -from gigl.distributed.sampler import ABLPNodeSamplerInput -from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions +from gigl.distributed.sampler_options import PPRSamplerOptions from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.types.graph import FeatureInfo, select_label_edge_types from gigl.utils.data_splitters import get_labels_for_anchor_nodes @@ -680,60 +669,6 @@ def destroy_sampling_input(self, channel_id: int) -> None: if should_shutdown_backend: backend_state.runtime.shutdown() - def create_sampling_producer( - self, - 
sampler_input: Union[ - NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput, ABLPNodeSamplerInput - ], - sampling_config: SamplingConfig, - worker_options: RemoteDistSamplingWorkerOptions, - sampler_options: SamplerOptions, - ) -> int: - """Create a sampling producer by delegating to the two-phase API. - - Bridge method that keeps existing loaders working. Internally calls - :meth:`init_sampling_backend` and :meth:`register_sampling_input`, - returning the ``channel_id`` as the ``producer_id``. - - Args: - sampler_input: The input data for sampling. - sampling_config: Configuration of sampling meta info. - worker_options: Options for launching remote sampling workers. - sampler_options: Controls which sampler class is instantiated. - - Returns: - A unique ID (channel_id) usable as a producer_id. - """ - backend_id = self.init_sampling_backend( - InitSamplingBackendRequest( - backend_key=worker_options.worker_key, - worker_options=worker_options, - sampler_options=sampler_options, - sampling_config=sampling_config, - ) - ) - channel_id = self.register_sampling_input( - RegisterBackendRequest( - backend_id=backend_id, - worker_key=worker_options.worker_key, - sampler_input=sampler_input, - sampling_config=sampling_config, - buffer_capacity=worker_options.buffer_capacity, - buffer_size=worker_options.buffer_size, - ) - ) - return channel_id - - def destroy_sampling_producer(self, producer_id: int) -> None: - """Destroy a sampling producer by delegating to :meth:`destroy_sampling_input`. - - Bridge method that keeps existing loaders working. - - Args: - producer_id: The producer ID (channel_id) to destroy. - """ - self.destroy_sampling_input(producer_id) - def start_new_epoch_sampling(self, channel_id: int, epoch: int) -> None: """Start one new epoch on one registered channel. 
diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index 20b3da6e7..722d7a1cb 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -376,25 +376,54 @@ def _run_compute_multiple_loaders_test( remote_dist_dataset, ablp_result, prefetch_size=2 ) logger.info( - f"Rank {rank} / {world_size} ablp_loader_1 producers: ({ablp_loader_1._producer_id_list})" + f"Rank {rank} / {world_size} ablp_loader_1 backends/channels: " + f"({ablp_loader_1._backend_id_list}, {ablp_loader_1._channel_id_list})" ) ablp_loader_2 = _build_ablp_loader( remote_dist_dataset, ablp_result, prefetch_size=2 ) logger.info( - f"Rank {rank} / {world_size} ablp_loader_2 producers: ({ablp_loader_2._producer_id_list})" + f"Rank {rank} / {world_size} ablp_loader_2 backends/channels: " + f"({ablp_loader_2._backend_id_list}, {ablp_loader_2._channel_id_list})" ) neighbor_loader_1 = _build_neighbor_loader( remote_dist_dataset, random_negative_input ) logger.info( - f"Rank {rank} / {world_size} neighbor_loader_1 producers: ({neighbor_loader_1._producer_id_list})" + f"Rank {rank} / {world_size} neighbor_loader_1 backends/channels: " + f"({neighbor_loader_1._backend_id_list}, {neighbor_loader_1._channel_id_list})" ) neighbor_loader_2 = _build_neighbor_loader( remote_dist_dataset, random_negative_input ) logger.info( - f"Rank {rank} / {world_size} neighbor_loader_2 producers: ({neighbor_loader_2._producer_id_list})" + f"Rank {rank} / {world_size} neighbor_loader_2 backends/channels: " + f"({neighbor_loader_2._backend_id_list}, {neighbor_loader_2._channel_id_list})" + ) + gathered_ablp_loader_1_backends = [None] * world_size + torch.distributed.all_gather_object( + gathered_ablp_loader_1_backends, tuple(ablp_loader_1._backend_id_list) + ) + assert all( + backend_ids == 
gathered_ablp_loader_1_backends[0] + for backend_ids in gathered_ablp_loader_1_backends + ), "All ranks should share the same backend ids for one logical loader." + gathered_neighbor_loader_1_backends = [None] * world_size + torch.distributed.all_gather_object( + gathered_neighbor_loader_1_backends, tuple(neighbor_loader_1._backend_id_list) + ) + assert all( + backend_ids == gathered_neighbor_loader_1_backends[0] + for backend_ids in gathered_neighbor_loader_1_backends + ), "All ranks should share the same backend ids for one logical loader." + assert ablp_loader_1._backend_id_list != ablp_loader_2._backend_id_list, ( + "Concurrent ABLP loaders must use distinct backends." + ) + assert neighbor_loader_1._backend_id_list != neighbor_loader_2._backend_id_list, ( + "Concurrent neighbor loaders must use distinct backends." + ) + assert ablp_loader_1._backend_id_list != neighbor_loader_1._backend_id_list, ( + "ABLP and neighbor loaders must not share a backend." ) logger.info( f"Rank {rank} / {world_size} phase 1: loading batches from 4 parallel loaders" @@ -454,9 +483,8 @@ def _run_compute_multiple_loaders_test( local_expected=local_expected_negative_seeds, ) - # Shut down phase 1 loaders to free server-side producers and RPC resources - # before creating new loaders. This mirrors GLT's DistLoader.shutdown() which - # calls DistServer.destroy_sampling_producer for each remote producer. + # Shut down phase 1 loaders to free server-side channels and backend resources + # before creating new loaders. 
ablp_loader_1.shutdown() ablp_loader_2.shutdown() neighbor_loader_1.shutdown() @@ -469,13 +497,15 @@ def _run_compute_multiple_loaders_test( # ------------------------------------------------------------------ ablp_loader_3 = _build_ablp_loader(remote_dist_dataset, ablp_result) logger.info( - f"Rank {rank} / {world_size} ablp_loader_3 producers: ({ablp_loader_3._producer_id_list})" + f"Rank {rank} / {world_size} ablp_loader_3 backends/channels: " + f"({ablp_loader_3._backend_id_list}, {ablp_loader_3._channel_id_list})" ) neighbor_loader_3 = _build_neighbor_loader( remote_dist_dataset, random_negative_input ) logger.info( - f"Rank {rank} / {world_size} neighbor_loader_3 producers: ({neighbor_loader_3._producer_id_list})" + f"Rank {rank} / {world_size} neighbor_loader_3 backends/channels: " + f"({neighbor_loader_3._backend_id_list}, {neighbor_loader_3._channel_id_list})" ) logger.info( f"Rank {rank} / {world_size} phase 2: loading batches from 2 sequential loaders" @@ -976,7 +1006,7 @@ def test_multiple_loaders_in_graph_store(self): CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name ] cluster_info = self._create_cluster_info( - num_storage_nodes=1, num_compute_nodes=1, num_processes_per_compute=1 + num_storage_nodes=1, num_compute_nodes=2, num_processes_per_compute=1 ) self._launch_graph_store_test( cluster_info=cluster_info, diff --git a/tests/unit/distributed/distributed_neighborloader_test.py b/tests/unit/distributed/distributed_neighborloader_test.py index 215429a5e..a1a9cca55 100644 --- a/tests/unit/distributed/distributed_neighborloader_test.py +++ b/tests/unit/distributed/distributed_neighborloader_test.py @@ -676,14 +676,6 @@ def test_isolated_homogeneous_neighbor_loader( num_neighbors=[2, 2], input_nodes={-1: torch.tensor([10]), 0: torch.tensor([20])}, ), - param( - "max_concurrent_producer_inits is not None (colocated mode)", - expected_error=ValueError, - dataset=DistDataset(rank=0, world_size=1, edge_dir="out"), - num_neighbors=[2, 2], - 
input_nodes=torch.tensor([10]), - max_concurrent_producer_inits=1, - ), ] ) def test_distributed_neighbor_loader_invalid_inputs_colocated(