diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index d0d6ab7e8..c95542946 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -13,7 +13,8 @@ import time from collections import Counter, defaultdict from dataclasses import dataclass -from typing import Callable, Optional, Union +from itertools import count +from typing import Callable, Optional, TypeVar, Union import torch from graphlearn_torch.channel import SampleMessage, ShmChannel @@ -23,7 +24,6 @@ RemoteDistSamplingWorkerOptions, get_context, ) -from graphlearn_torch.distributed.dist_client import async_request_server from graphlearn_torch.distributed.rpc import rpc_is_initialized from graphlearn_torch.sampler import ( EdgeSamplerInput, @@ -42,7 +42,12 @@ from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_sampling_producer import DistSamplingProducer +from gigl.distributed.graph_store.compute import async_request_server from gigl.distributed.graph_store.dist_server import DistServer +from gigl.distributed.graph_store.messages import ( + InitSamplingBackendRequest, + RegisterBackendRequest, +) from gigl.distributed.graph_store.remote_channel import RemoteReceivingChannel from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions @@ -55,6 +60,7 @@ logger = Logger() DEFAULT_NUM_CPU_THREADS = 2 +T = TypeVar("T") # We don't see logs for graph store mode for whatever reason. 
@@ -78,6 +84,81 @@ class DistributedRuntimeInfo: should_cleanup_distributed_context: bool +@dataclass(frozen=True) +class GroupLeaderInfo: + """Leader election result for one graph-store dispatch group.""" + + leader_rank: int + is_leader: bool + my_batch: int + num_batches: int + stagger_sleep: float + group_size: int + + +def _compute_group_leader( + my_key: str, + all_keys: list[Optional[str]], + rank: int, + process_start_gap_seconds: float, + max_concurrent_producer_inits: int, +) -> GroupLeaderInfo: + """Compute leader election and stagger information for one key group.""" + key_to_ranks: dict[str, list[int]] = defaultdict(list) + for other_rank, key in enumerate(all_keys): + assert key is not None, f"Rank {other_rank} did not provide a key." + key_to_ranks[key].append(other_rank) + + leader_rank = min(key_to_ranks[my_key]) + unique_keys = sorted(key_to_ranks.keys()) + my_key_index = unique_keys.index(my_key) + num_batches = math.ceil(len(unique_keys) / max_concurrent_producer_inits) + my_batch = my_key_index // max_concurrent_producer_inits + stagger_sleep = my_batch * process_start_gap_seconds + return GroupLeaderInfo( + leader_rank=leader_rank, + is_leader=rank == leader_rank, + my_batch=my_batch, + num_batches=num_batches, + stagger_sleep=stagger_sleep, + group_size=len(key_to_ranks[my_key]), + ) + + +def _dispatch_grouped_graph_store_phase( + *, + my_key: str, + runtime: DistributedRuntimeInfo, + process_start_gap_seconds: float, + max_concurrent_producer_inits: int, + issue_phase_rpcs: Callable[[], list[T]], +) -> list[T]: + """Run one grouped graph-store RPC phase under leader election.""" + all_keys: list[Optional[str]] = [None] * runtime.world_size + torch.distributed.all_gather_object(all_keys, my_key) + group_info = _compute_group_leader( + my_key=my_key, + all_keys=all_keys, + rank=runtime.rank, + process_start_gap_seconds=process_start_gap_seconds, + max_concurrent_producer_inits=max_concurrent_producer_inits, + ) + logger.info( + 
f"rank={runtime.rank} phase_key={my_key} " + f"is_leader={group_info.is_leader} leader_rank={group_info.leader_rank} " + f"group_size={group_info.group_size} " + f"batch={group_info.my_batch}/{group_info.num_batches} " + f"stagger_sleep={group_info.stagger_sleep:.1f}s" + ) + _flush() + if group_info.is_leader and group_info.stagger_sleep > 0: + time.sleep(group_info.stagger_sleep) + results = issue_phase_rpcs() if group_info.is_leader else [] + all_results: list[list[T]] = [[] for _ in range(runtime.world_size)] + torch.distributed.all_gather_object(all_results, results) + return all_results[group_info.leader_rank] + + class BaseDistLoader(DistLoader): """Base class for GiGL distributed loaders. @@ -91,8 +172,8 @@ class BaseDistLoader(DistLoader): 3. Call ``create_sampling_config()`` to build the SamplingConfig. 4. For colocated: call ``create_colocated_channel()`` and construct the ``DistSamplingProducer`` (or subclass), then pass the producer as ``producer``. - 5. For graph store: pass the RPC function (e.g. ``DistServer.create_sampling_producer``) - as ``producer``. + 5. For graph store: prepare remote inputs and worker options; the base class + handles the two-phase backend/channel RPC initialization. 6. Call ``super().__init__()`` with the prepared data. Args: @@ -105,20 +186,21 @@ class BaseDistLoader(DistLoader): sampling_config: Configuration for sampling (created via ``create_sampling_config``). device: Target device for sampled results. runtime: Resolved distributed runtime information. - producer: Either a pre-constructed ``DistSamplingProducer`` (colocated mode) - or a callable to dispatch on the ``DistServer`` (graph store mode). + producer: Optional pre-constructed ``DistSamplingProducer`` for colocated mode. sampler_options: Controls which sampler class is instantiated. process_start_gap_seconds: Delay between each process for staggered colocated init. In graph store mode, this is the delay between each batch of concurrent - producer initializations. 
+ backend or registration initializations. max_concurrent_producer_inits: Maximum number of leader ranks that may - dispatch ``create_producer_fn`` RPCs concurrently in graph store mode. + dispatch graph-store RPC batches concurrently. Leaders are grouped into batches of this size; each batch sleeps ``batch_index * process_start_gap_seconds`` before dispatching. Only applies to graph store mode. Defaults to ``None`` (no staggering). """ + _global_loader_counter = count(0) + @staticmethod def resolve_runtime( context: Optional[DistributedContext] = None, @@ -220,7 +302,7 @@ def __init__( sampling_config: SamplingConfig, device: torch.device, runtime: DistributedRuntimeInfo, - producer: Union[DistSamplingProducer, Callable[..., int]], + producer: Optional[DistSamplingProducer], sampler_options: SamplerOptions, process_start_gap_seconds: float = 60.0, max_concurrent_producer_inits: Optional[int] = None, @@ -242,6 +324,8 @@ def __init__( self._sampler_options = sampler_options self._non_blocking_transfers = non_blocking_transfers + if not hasattr(self, "_backend_key"): + self._backend_key: Optional[str] = None # --- Attributes shared by both modes (mirrors GLT DistLoader.__init__) --- self.input_data = sampler_input @@ -264,7 +348,11 @@ def __init__( self._epoch = 0 # --- Mode-specific attributes and connection initialization --- - if isinstance(producer, DistSamplingProducer): + if ( + isinstance(dataset, DistDataset) + and isinstance(worker_options, MpDistSamplingWorkerOptions) + and isinstance(producer, DistSamplingProducer) + ): assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) assert isinstance(sampler_input, NodeSamplerInput) @@ -285,6 +373,7 @@ def __init__( if not self.drop_last and self._input_len % self.batch_size != 0: self._num_expected += 1 + self._remote_input_has_batches: list[bool] = [] self._shutdowned = False self._init_colocated_connections( dataset=dataset, @@ -292,11 +381,10 @@ def __init__( 
runtime=runtime, process_start_gap_seconds=process_start_gap_seconds, ) - else: - assert isinstance(dataset, RemoteDistDataset) - assert isinstance(worker_options, RemoteDistSamplingWorkerOptions) + elif isinstance(dataset, RemoteDistDataset) and isinstance( + worker_options, RemoteDistSamplingWorkerOptions + ): assert isinstance(sampler_input, list) - assert callable(producer) self.data = None self._is_mp_worker = False @@ -322,14 +410,23 @@ def __init__( node_types = [DEFAULT_HOMOGENEOUS_NODE_TYPE] self._set_ntypes_and_etypes(node_types, edge_types) + self._remote_input_has_batches = [ + self._sampler_input_has_batches(inp) for inp in self._input_data_list + ] self._shutdowned = False self._init_graph_store_connections( dataset=dataset, - create_producer_fn=producer, runtime=runtime, process_start_gap_seconds=process_start_gap_seconds, max_concurrent_producer_inits=max_concurrent_producer_inits, ) + else: + raise TypeError( + "Invalid loader construction. Expected either " + "(DistDataset, MpDistSamplingWorkerOptions, DistSamplingProducer) " + "for colocated mode or (RemoteDistDataset, RemoteDistSamplingWorkerOptions) " + "for graph-store mode." + ) @staticmethod def create_sampling_config( @@ -549,7 +646,6 @@ def create_colocated_worker_options( def create_graph_store_worker_options( *, dataset: RemoteDistDataset, - compute_rank: int, worker_key: str, num_workers: int, worker_concurrency: int, @@ -560,7 +656,6 @@ def create_graph_store_worker_options( Args: dataset: Remote dataset proxy used to discover storage-cluster topology. - compute_rank: Global compute-process rank for the current process. worker_key: Unique key used by the storage cluster to deduplicate producers. num_workers: Number of sampling worker processes. worker_concurrency: Max sampling concurrency per worker. @@ -570,10 +665,8 @@ def create_graph_store_worker_options( Returns: Fully configured worker options for graph-store sampling. 
""" - sampling_ports = dataset.fetch_free_ports_on_storage_cluster( - num_ports=dataset.cluster_info.compute_cluster_world_size - ) - sampling_port = sampling_ports[compute_rank] + sampling_ports = dataset.fetch_free_ports_on_storage_cluster(num_ports=1) + sampling_port = sampling_ports[0] return RemoteDistSamplingWorkerOptions( server_rank=list(range(dataset.cluster_info.num_storage_nodes)), num_workers=num_workers, @@ -633,88 +726,129 @@ def _init_colocated_connections( time.sleep(process_start_gap_seconds * runtime.local_rank) self._mp_producer.init() + def _init_graph_store_sampling_backends( + self, + runtime: DistributedRuntimeInfo, + process_start_gap_seconds: float, + max_concurrent_producer_inits: int, + ) -> list[int]: + """Initialize or reuse one shared backend per storage server.""" + if self._backend_key is None: + raise RuntimeError( + f"{type(self).__name__} did not set _backend_key. " + "Subclasses must set self._backend_key in _setup_for_graph_store()." + ) + backend_key = self._backend_key + + def issue_rpcs() -> list[int]: + futures: list[torch.futures.Future[int]] = [] + for server_rank in self._server_rank_list: + futures.append( + async_request_server( + server_rank, + DistServer.init_sampling_backend, + InitSamplingBackendRequest( + backend_key=backend_key, + worker_options=self.worker_options, + sampler_options=self._sampler_options, + sampling_config=self.sampling_config, + ), + ) + ) + return [future.wait() for future in futures] + + return _dispatch_grouped_graph_store_phase( + my_key=backend_key, + runtime=runtime, + process_start_gap_seconds=process_start_gap_seconds, + max_concurrent_producer_inits=max_concurrent_producer_inits, + issue_phase_rpcs=issue_rpcs, + ) + + def _register_graph_store_sampling_inputs( + self, + runtime: DistributedRuntimeInfo, + backend_id_list: list[int], + process_start_gap_seconds: float, + max_concurrent_producer_inits: int, + ) -> list[int]: + """Register this compute rank's inputs on existing shared 
backends.""" + assert ( + len(self._server_rank_list) + == len(backend_id_list) + == len(self._input_data_list) + ), ( + f"Mismatched lengths: server_rank_list={len(self._server_rank_list)}, " + f"backend_id_list={len(backend_id_list)}, " + f"input_data_list={len(self._input_data_list)}" + ) + worker_key = self.worker_options.worker_key + + def issue_rpcs() -> list[int]: + futures: list[torch.futures.Future[int]] = [] + for server_rank, backend_id, input_data in zip( + self._server_rank_list, + backend_id_list, + self._input_data_list, + ): + futures.append( + async_request_server( + server_rank, + DistServer.register_sampling_input, + RegisterBackendRequest( + backend_id=backend_id, + worker_key=worker_key, + sampler_input=input_data, + sampling_config=self.sampling_config, + buffer_capacity=self.worker_options.buffer_capacity, + buffer_size=self.worker_options.buffer_size, + ), + ) + ) + return [future.wait() for future in futures] + + return _dispatch_grouped_graph_store_phase( + my_key=worker_key, + runtime=runtime, + process_start_gap_seconds=process_start_gap_seconds, + max_concurrent_producer_inits=max_concurrent_producer_inits, + issue_phase_rpcs=issue_rpcs, + ) + + def _sampler_input_has_batches(self, sampler_input: NodeSamplerInput) -> bool: + """Return whether this sampler input can produce at least one batch.""" + input_len = len(sampler_input) + return input_len > 0 and not (self.drop_last and input_len < self.batch_size) + def _init_graph_store_connections( self, dataset: RemoteDistDataset, - create_producer_fn: Callable[..., int], runtime: DistributedRuntimeInfo, process_start_gap_seconds: float = 60.0, - max_concurrent_producer_inits: int = sys.maxsize, # Already resolved from None by __init__ + max_concurrent_producer_inits: int = sys.maxsize, ) -> None: - """Initialize Graph Store mode connections. 
- - Validates the GLT distributed context, elects a leader per ``worker_key`` - group to dispatch RPCs that create sampling producers on storage nodes, - then distributes the resulting producer IDs to all ranks in the group via - ``all_gather_object``. - The `worker_key` is set-upstream (and are passed in as part of `RemoteDistSamplingWorkerOptions`) - to group producers together, so that different processes in the compute cluster (e.g. different GPUs) - can share the same producer. - - Only the leader rank (minimum rank sharing a given ``worker_key``) sends - the ``create_producer_fn`` RPCs. This avoids redundant RPCs and - server-side lock contention, since the server deduplicates producers by - ``worker_key`` anyway. - - Leaders are further staggered into batches of size - ``max_concurrent_producer_inits``. Each batch sleeps - ``batch_index * process_start_gap_seconds`` before dispatching RPCs, - limiting the number of leaders that hit storage nodes concurrently. - - All DistLoader attributes are already set by ``__init__`` before this is called. - - Uses ``async_request_server`` instead of ``ThreadPoolExecutor`` to avoid - TensorPipe rendezvous deadlock with many servers. - - For Graph Store mode it's important to distinguish "compute node" (e.g. physical compute machine) from "compute process" (e.g. process running on the compute node). - Since in practice we have multiple compute processes per compute node, and each compute process needs to initialize the connection to the storage nodes. - E.g. if there are 4 gpus per compute node, then there will be 4 connections from each compute node to each storage node. - - See below for a connection setup. 
- ╔═══════════════════════════════════════════════════════════════════════════════════════╗ - ║ COMPUTE TO STORAGE NODE CONNECTIONS ║ - ╚═══════════════════════════════════════════════════════════════════════════════════════╝ - - COMPUTE NODES STORAGE NODES - ═════════════ ═════════════ - - ┌──────────────────────┐ (1) ┌───────────────┐ - │ COMPUTE NODE 0 │ │ │ - │ ┌────┬────┬────┬────┤ ══════════════════════════════════│ STORAGE 0 │ - │ │GPU │GPU │GPU │GPU │ ╱ │ │ - │ │ 0 │ 1 │ 2 │ 3 │ ════════════════════╲ ╱ └───────────────┘ - │ └────┴────┴────┴────┤ (2) ╲ ╱ - └──────────────────────┘ ╲ ╱ - ╳ - (3) ╱ ╲ (4) - ┌──────────────────────┐ ╱ ╲ ┌───────────────┐ - │ COMPUTE NODE 1 │ ╱ ╲ │ │ - │ ┌────┬────┬────┬────┤ ═════════════════╱ ═│ STORAGE 1 │ - │ │GPU │GPU │GPU │GPU │ │ │ - │ │ 0 │ 1 │ 2 │ 3 │ ══════════════════════════════════│ │ - │ └────┴────┴────┴────┤ └───────────────┘ - └──────────────────────┘ - - ┌─────────────────────────────────────────────────────────────────────────────┐ - │ (1) Compute Node 0 → Storage 0 (4 connections, one per GPU) │ - │ (2) Compute Node 0 → Storage 1 (4 connections, one per GPU) │ - │ (3) Compute Node 1 → Storage 0 (4 connections, one per GPU) │ - │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ - └─────────────────────────────────────────────────────────────────────────────┘ - - Args: - dataset: The remote dataset proxy for graph store mode. - create_producer_fn: RPC callable to create a sampling producer on a - storage server (e.g. ``DistServer.create_sampling_producer``). - runtime: Resolved distributed runtime information (provides rank and - world_size for the all_gather collectives). - process_start_gap_seconds: Delay in seconds between each batch of - concurrent producer initializations. - max_concurrent_producer_inits: Maximum number of leader ranks that - may dispatch RPCs concurrently. Leaders are grouped into batches - of this size. 
+ """Initialize graph-store mode with shared backends and per-rank channels. + + Sets two parallel lists indexed by storage server: + + - ``_backend_id_list``: one backend per storage server, shared across all + compute ranks using the same loader instance (keyed by ``_backend_key``). + Use this when the operation targets the backend itself (e.g. + ``init_sampling_backend``). + + - ``_channel_id_list``: one channel per storage server, unique to this + compute rank. Use this for per-rank operations (e.g. + ``start_new_epoch_sampling``, ``fetch_one_sampled_message``, + ``destroy_sampling_input``). + + Example with 2 storage servers and 4 compute ranks sharing one + DistNeighborLoader: + backend_id_list = [0, 1] # same for all 4 ranks + channel_id_list = [0, 1] # rank 0 + channel_id_list = [2, 3] # rank 1 + channel_id_list = [4, 5] # rank 2 + channel_id_list = [6, 7] # rank 3 """ - # Validate distributed context ctx = get_context() if ctx is None: raise RuntimeError( @@ -723,139 +857,35 @@ def _init_graph_store_connections( ) if not ctx.is_client(): raise RuntimeError( - f"'{self.__class__.__name__}': must be used on a client " - f"worker process." + f"'{self.__class__.__name__}': must be used on a client worker process." ) - # Move input to CPU before sending to server for inp in self._input_data_list: if not isinstance(inp, RemoteSamplerInput): inp.to(torch.device("cpu")) - node_rank = dataset.cluster_info.compute_node_rank - - _flush() start_time = time.time() - - # --- Leader election via worker_key all_gather --- - # All ranks exchange their worker_key so we can group ranks that share - # the same key and elect the minimum rank as the leader. 
- my_worker_key: str = self.worker_options.worker_key - all_worker_keys: list[Optional[str]] = [None] * runtime.world_size - torch.distributed.all_gather_object(all_worker_keys, my_worker_key) - - key_to_ranks: dict[str, list[int]] = defaultdict(list) - for r, key in enumerate(all_worker_keys): - assert key is not None, f"Rank {r} did not provide a worker_key" - key_to_ranks[key].append(r) - - leader_rank = min(key_to_ranks[my_worker_key]) - is_leader = runtime.rank == leader_rank - - # --- Stagger leaders into batches --- - # Deterministically assign each unique worker_key an index, then group - # leaders into batches of max_concurrent_producer_inits. Each batch - # sleeps batch_index * process_start_gap_seconds before dispatching. - unique_keys = sorted(key_to_ranks.keys()) - my_key_index = unique_keys.index(my_worker_key) - num_unique_keys = len(unique_keys) - num_batches = math.ceil(num_unique_keys / max_concurrent_producer_inits) - my_batch = my_key_index // max_concurrent_producer_inits - stagger_sleep_seconds = my_batch * process_start_gap_seconds - - logger.info( - f"rank={runtime.rank} worker_key={my_worker_key} " - f"is_leader={is_leader} leader_rank={leader_rank} " - f"group_size={len(key_to_ranks[my_worker_key])} " - f"key_index={my_key_index}/{num_unique_keys} " - f"batch={my_batch}/{num_batches} " - f"stagger_sleep={stagger_sleep_seconds:.1f}s" + self._backend_id_list = self._init_graph_store_sampling_backends( + runtime=runtime, + process_start_gap_seconds=process_start_gap_seconds, + max_concurrent_producer_inits=max_concurrent_producer_inits, ) - _flush() - - # --- Leader dispatches RPCs, followers skip --- - producer_id_list: list[int] = [] - if is_leader: - if stagger_sleep_seconds > 0: - logger.info( - f"rank={runtime.rank} sleeping {stagger_sleep_seconds:.1f}s " - f"(batch {my_batch}/{num_batches}) before RPC dispatch" - ) - _flush() - time.sleep(stagger_sleep_seconds) - - rpc_futures: list[tuple[int, torch.futures.Future[int]]] = [] - 
logger.info( - f"node_rank={node_rank} rank={runtime.rank} dispatching " - f"create_sampling_producer to " - f"{len(self._server_rank_list)} servers" - ) - _flush() - t_dispatch = time.time() - for server_rank, inp_data in zip( - self._server_rank_list, self._input_data_list - ): - fut = async_request_server( - server_rank, - create_producer_fn, - inp_data, - self.sampling_config, - self.worker_options, - self._sampler_options, - ) - rpc_futures.append((server_rank, fut)) - logger.info( - f"node_rank={node_rank} rank={runtime.rank} all " - f"{len(rpc_futures)} RPCs dispatched in " - f"{time.time() - t_dispatch:.3f}s, waiting for responses" - ) - _flush() - - for server_rank, fut in rpc_futures: - t_wait = time.time() - producer_id = fut.wait() - logger.info( - f"node_rank={node_rank} rank={runtime.rank} " - f"create_sampling_producer" - f"(server_rank={server_rank}) returned " - f"producer_id={producer_id} in {time.time() - t_wait:.2f}s" - ) - _flush() - producer_id_list.append(producer_id) - logger.info( - f"node_rank={node_rank} rank={runtime.rank} all " - f"{len(producer_id_list)} producers " - f"created in {time.time() - t_dispatch:.2f}s total" - ) - _flush() - else: # if not leader - logger.info( - f"Since rank {runtime.rank} is not the leader for worker key {my_worker_key}, we will wait for the leader (rank {leader_rank}) to dispatch RPCs" - ) - - # --- Distribute producer IDs to all ranks --- - all_producer_ids: list[list[int]] = [[] for _ in range(runtime.world_size)] - torch.distributed.all_gather_object(all_producer_ids, producer_id_list) - self._producer_id_list = all_producer_ids[leader_rank] - - logger.info( - f"rank={runtime.rank} received producer_id_list=" - f"{self._producer_id_list} from leader_rank={leader_rank}" + self._channel_id_list = self._register_graph_store_sampling_inputs( + runtime=runtime, + backend_id_list=self._backend_id_list, + process_start_gap_seconds=process_start_gap_seconds, + 
max_concurrent_producer_inits=max_concurrent_producer_inits, ) - _flush() - - # Create remote receiving channel for cross-machine message passing self._channel = RemoteReceivingChannel( server_rank=self._server_rank_list, - channel_id=self._producer_id_list, + channel_id=self._channel_id_list, prefetch_size=self.worker_options.prefetch_size, - active_mask=[len(inp) > 0 for inp in self._input_data_list], + active_mask=self._remote_input_has_batches, pin_memory=self.to_device is not None and self.to_device.type == "cuda", ) - logger.info( - f"node_rank {node_rank} rank={runtime.rank} initialized " - f"the dist loader in {time.time() - start_time:.2f}s" + f"node_rank {dataset.cluster_info.compute_node_rank} rank={runtime.rank} " + f"initialized shared graph-store loader in {time.time() - start_time:.2f}s" ) _flush() @@ -869,11 +899,11 @@ def shutdown(self) -> None: self._mp_producer.shutdown() elif rpc_is_initialized() is True: rpc_futures: list[torch.futures.Future[None]] = [] - for server_rank, producer_id in zip( - self._server_rank_list, self._producer_id_list + for server_rank, channel_id in zip( + self._server_rank_list, self._channel_id_list ): fut = async_request_server( - server_rank, DistServer.destroy_sampling_producer, producer_id + server_rank, DistServer.destroy_sampling_input, channel_id ) rpc_futures.append(fut) torch.futures.wait_all(rpc_futures) @@ -917,13 +947,17 @@ def __iter__(self) -> Self: self._mp_producer.produce_all() else: rpc_futures: list[torch.futures.Future[None]] = [] - for server_rank, producer_id in zip( - self._server_rank_list, self._producer_id_list + for server_rank, channel_id, has_batches in zip( + self._server_rank_list, + self._channel_id_list, + self._remote_input_has_batches, ): + if not has_batches: + continue fut = async_request_server( server_rank, DistServer.start_new_epoch_sampling, - producer_id, + channel_id, self._epoch, ) rpc_futures.append(fut) diff --git a/gigl/distributed/dist_ablp_neighborloader.py 
b/gigl/distributed/dist_ablp_neighborloader.py index 276525c3f..6c3d35bbd 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,6 +1,5 @@ from collections import abc, defaultdict -from itertools import count -from typing import Callable, Optional, Union +from typing import Optional, Union import torch from graphlearn_torch.channel import SampleMessage @@ -21,7 +20,6 @@ PPR_WEIGHT_METADATA_KEY, ) from gigl.distributed.dist_sampling_producer import DistSamplingProducer -from gigl.distributed.graph_store.dist_server import DistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.sampler import ( NEGATIVE_LABEL_METADATA_KEY, @@ -64,11 +62,6 @@ class DistABLPLoader(BaseDistLoader): - # Counts instantiations of this class, per process. - # This is needed so we can generate unique worker key for each instance, for graph store mode. - # NOTE: This is per-class, not per-instance. - _counter = count(0) - def __init__( self, dataset: Union[DistDataset, RemoteDistDataset], @@ -266,7 +259,7 @@ def __init__( logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}") del supervision_edge_type - self._instance_count = next(self._counter) + self._instance_count = next(BaseDistLoader._global_loader_counter) # Resolve distributed context runtime = BaseDistLoader.resolve_runtime( @@ -362,22 +355,17 @@ def __init__( drop_last=drop_last, ) - # Build the producer: a pre-constructed producer for colocated mode, - # or an RPC callable for graph store mode. 
+ producer: Optional[DistSamplingProducer] = None if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) - producer: Union[ - DistSamplingProducer, Callable[..., int] - ] = BaseDistLoader.create_mp_producer( + producer = BaseDistLoader.create_mp_producer( dataset=dataset, sampler_input=sampler_input, sampling_config=sampling_config, worker_options=worker_options, sampler_options=sampler_options, ) - else: - producer = DistServer.create_sampling_producer # Call base class — handles metadata storage and connection initialization # (including staggered init for colocated mode). @@ -624,13 +612,11 @@ def _setup_for_graph_store( edge_feature_info = dataset.fetch_edge_feature_info() edge_types = dataset.fetch_edge_types() compute_rank = torch.distributed.get_rank() - worker_key = ( - f"compute_ablp_loader_rank_{compute_rank}_worker_{self._instance_count}" - ) + self._backend_key = f"dist_ablp_loader_{self._instance_count}" + worker_key = f"{self._backend_key}_compute_rank_{compute_rank}" logger.info(f"rank: {compute_rank}, worker_key: {worker_key}") worker_options = BaseDistLoader.create_graph_store_worker_options( dataset=dataset, - compute_rank=compute_rank, worker_key=worker_key, num_workers=num_workers, worker_concurrency=worker_concurrency, diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index f155bd929..31fa6574b 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -16,6 +16,7 @@ DistDataset, DistMpSamplingProducer, MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, init_rpc, init_worker_group, shutdown_rpc, @@ -39,6 +40,7 @@ from gigl.common.logger import Logger from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler +from gigl.distributed.sampler 
import ABLPNodeSamplerInput from gigl.distributed.sampler_options import ( KHopNeighborSamplerOptions, PPRSamplerOptions, @@ -47,6 +49,78 @@ logger = Logger() +SamplerInput = Union[NodeSamplerInput, EdgeSamplerInput, ABLPNodeSamplerInput] +SamplerRuntime = Union[DistNeighborSampler, DistPPRNeighborSampler] + + +def _create_dist_sampler( + *, + data: DistDataset, + sampling_config: SamplingConfig, + worker_options: Union[MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions], + channel: ChannelBase, + sampler_options: SamplerOptions, + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + current_device: torch.device, +) -> SamplerRuntime: + """Create a GiGL sampler runtime for one channel on one worker.""" + shared_sampler_args = ( + data, + sampling_config.num_neighbors, + sampling_config.with_edge, + sampling_config.with_neg, + sampling_config.with_weight, + sampling_config.edge_dir, + sampling_config.collect_features, + channel, + worker_options.use_all2all, + worker_options.worker_concurrency, + current_device, + ) + if isinstance(sampler_options, KHopNeighborSamplerOptions): + sampler: SamplerRuntime = DistNeighborSampler( + *shared_sampler_args, + seed=sampling_config.seed, + ) + elif isinstance(sampler_options, PPRSamplerOptions): + assert degree_tensors is not None + sampler = DistPPRNeighborSampler( + *shared_sampler_args, + seed=sampling_config.seed, + alpha=sampler_options.alpha, + eps=sampler_options.eps, + max_ppr_nodes=sampler_options.max_ppr_nodes, + num_neighbors_per_hop=sampler_options.num_neighbors_per_hop, + total_degree_dtype=sampler_options.total_degree_dtype, + degree_tensors=degree_tensors, + ) + else: + raise NotImplementedError( + f"Unsupported sampler options type: {type(sampler_options)}" + ) + return sampler + + +def _prepare_degree_tensors( + data: DistDataset, + sampler_options: SamplerOptions, +) -> Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: + """Materialize PPR degree tensors before 
worker spawn when required.""" + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] = None + if isinstance(sampler_options, PPRSamplerOptions): + degree_tensors = data.degree_tensor + if isinstance(degree_tensors, dict): + logger.info( + "Pre-computed degree tensors for PPR sampling across " + f"{len(degree_tensors)} edge types." + ) + elif degree_tensors is not None: + logger.info( + "Pre-computed degree tensor for PPR sampling with " + f"{degree_tensors.size(0)} nodes." + ) + return degree_tensors + def _sampling_worker_loop( rank: int, @@ -100,42 +174,15 @@ def _sampling_worker_loop( if sampling_config.seed is not None: seed_everything(sampling_config.seed) - # Shared args for all sampler types (positional args to DistNeighborSampler.__init__) - shared_sampler_args = ( - data, - sampling_config.num_neighbors, - sampling_config.with_edge, - sampling_config.with_neg, - sampling_config.with_weight, - sampling_config.edge_dir, - sampling_config.collect_features, - channel, - worker_options.use_all2all, - worker_options.worker_concurrency, - current_device, + dist_sampler = _create_dist_sampler( + data=data, + sampling_config=sampling_config, + worker_options=worker_options, + channel=channel, + sampler_options=sampler_options, + degree_tensors=degree_tensors, + current_device=current_device, ) - - if isinstance(sampler_options, KHopNeighborSamplerOptions): - dist_sampler = DistNeighborSampler( - *shared_sampler_args, - seed=sampling_config.seed, - ) - elif isinstance(sampler_options, PPRSamplerOptions): - assert degree_tensors is not None - dist_sampler = DistPPRNeighborSampler( - *shared_sampler_args, - seed=sampling_config.seed, - alpha=sampler_options.alpha, - eps=sampler_options.eps, - max_ppr_nodes=sampler_options.max_ppr_nodes, - num_neighbors_per_hop=sampler_options.num_neighbors_per_hop, - total_degree_dtype=sampler_options.total_degree_dtype, - degree_tensors=degree_tensors, - ) - else: - raise NotImplementedError( - f"Unsupported 
sampler options type: {type(sampler_options)}" - ) dist_sampler.start_loop() unshuffled_index_loader: Optional[DataLoader] @@ -186,16 +233,13 @@ def _sampling_worker_loop( dist_sampler.wait_all() with sampling_completed_worker_count.get_lock(): - sampling_completed_worker_count.value += ( - 1 # non-atomic, lock is necessary - ) + sampling_completed_worker_count.value += 1 elif command == MpCommand.STOP: keep_running = False else: raise RuntimeError("Unknown command type") except KeyboardInterrupt: - # Main process will raise KeyboardInterrupt anyways. pass if dist_sampler is not None: @@ -236,7 +280,7 @@ def init(self): self.num_workers * self.worker_options.worker_concurrency ) self._task_queues.append(task_queue) - w = mp_context.Process( + worker = mp_context.Process( target=_sampling_worker_loop, args=( rank, @@ -253,7 +297,7 @@ def init(self): self._degree_tensors, ), ) - w.daemon = True - w.start() - self._workers.append(w) + worker.daemon = True + worker.start() + self._workers.append(worker) barrier.wait() diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 63945719e..9b7825f27 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -1,7 +1,6 @@ import sys from collections import abc -from itertools import count -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from graphlearn_torch.channel import SampleMessage @@ -23,7 +22,6 @@ PPR_WEIGHT_METADATA_KEY, ) from gigl.distributed.dist_sampling_producer import DistSamplingProducer -from gigl.distributed.graph_store.dist_server import DistServer as GiglDistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.sampler_options import ( PPRSamplerOptions, @@ -61,11 +59,6 @@ def flush(): class DistNeighborLoader(BaseDistLoader): - # Counts instantiations of this class, per process. 
- # This is needed so we can generate unique worker key for each instance, for graph store mode. - # NOTE: This is per-class, not per-instance. - _counter = count(0) - def __init__( self, dataset: Union[DistDataset, RemoteDistDataset], @@ -208,7 +201,7 @@ def __init__( ) logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}") - self._instance_count = next(self._counter) + self._instance_count = next(BaseDistLoader._global_loader_counter) device = ( pin_memory_device if pin_memory_device @@ -271,22 +264,17 @@ def __init__( drop_last=drop_last, ) - # Build the producer: a pre-constructed producer for colocated mode, - # or an RPC callable for graph store mode. + producer: Optional[DistSamplingProducer] = None if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) - producer: Union[ - DistSamplingProducer, Callable[..., int] - ] = BaseDistLoader.create_mp_producer( + producer = BaseDistLoader.create_mp_producer( dataset=dataset, sampler_input=input_data, sampling_config=sampling_config, worker_options=worker_options, sampler_options=sampler_options, ) - else: - producer = GiglDistServer.create_sampling_producer # Call base class — handles metadata storage and connection initialization # (including staggered init for colocated mode). 
@@ -341,11 +329,11 @@ def _setup_for_graph_store( edge_types = dataset.fetch_edge_types() compute_rank = torch.distributed.get_rank() - worker_key = f"compute_rank_{compute_rank}_worker_{self._instance_count}" + self._backend_key = f"dist_neighbor_loader_{self._instance_count}" + worker_key = f"{self._backend_key}_compute_rank_{compute_rank}" logger.info(f"Rank {compute_rank} worker key: {worker_key}") worker_options = BaseDistLoader.create_graph_store_worker_options( dataset=dataset, - compute_rank=compute_rank, worker_key=worker_key, num_workers=num_workers, worker_concurrency=worker_concurrency, diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index cd0041f1f..1cd496a2a 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -11,36 +11,28 @@ import logging import threading import time -import warnings from collections import abc +from dataclasses import dataclass, field from typing import Any, Callable, Literal, Optional, TypeVar, Union import graphlearn_torch.distributed.dist_server as glt_dist_server import torch from graphlearn_torch.channel import QueueTimeoutError, SampleMessage, ShmChannel -from graphlearn_torch.distributed import ( - RemoteDistSamplingWorkerOptions, - barrier, - init_rpc, - shutdown_rpc, -) +from graphlearn_torch.distributed import barrier, init_rpc, shutdown_rpc from graphlearn_torch.partition import PartitionBook -from graphlearn_torch.sampler import ( - EdgeSamplerInput, - NodeSamplerInput, - RemoteSamplerInput, - SamplingConfig, -) +from graphlearn_torch.sampler import RemoteSamplerInput from gigl.common.logger import Logger from gigl.distributed.dist_dataset import DistDataset -from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.messages import ( FetchABLPInputRequest, FetchNodesRequest, + InitSamplingBackendRequest, + RegisterBackendRequest, +) +from 
gigl.distributed.graph_store.shared_dist_sampling_producer import ( + SharedDistSamplingBackend, ) -from gigl.distributed.sampler import ABLPNodeSamplerInput -from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.distributed.utils.neighborloader import shard_nodes_by_process from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.types.graph import FeatureInfo, select_label_edge_types @@ -49,12 +41,40 @@ SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0 r""" Interval (in seconds) to check exit status of server. """ +FETCH_SLOW_LOG_SECS = 1.0 logger = Logger() R = TypeVar("R") +@dataclass +class ChannelState: + backend_id: int + worker_key: str + channel: ShmChannel + epoch: int = -1 + lock: threading.RLock = field(default_factory=threading.RLock) + + +@dataclass +class SamplingBackendState: + backend_id: int + backend_key: str + runtime: SharedDistSamplingBackend + active_channels: set[int] = field(default_factory=set) + lock: threading.RLock = field(default_factory=threading.RLock) + + +@dataclass +class _ChannelFetchStats: + """Per-channel fetch timing stats for ``fetch_one_sampled_message``.""" + + fetch_count: int = 0 + fetch_total_elapsed: float = 0.0 + fetch_slow_count: int = 0 + + class DistServer: r"""A server that supports launching remote sampling workers for training clients. @@ -68,32 +88,33 @@ class DistServer: data and feature data, along with distributed patition books. """ - def __init__(self, dataset: DistDataset) -> None: + def __init__(self, dataset: DistDataset, log_every_n: int = 50) -> None: self.dataset = dataset - # Top-level lock used to safely allocate producer IDs and create per-producer - # locks. We need this because _producer_lock entries don't exist until a - # producer is first requested, so concurrent calls for the same worker_key - # could race on creating the entry. Once a per-producer lock exists, callers - # use it directly without holding _lock. 
self._lock = threading.RLock() self._exit = False - self._cur_producer_idx = 0 # auto incremental index (same as producer count) - # The mapping from the key in worker options (such as 'train', 'test') - # to producer id - self._worker_key2producer_id: dict[str, int] = {} - self._producer_pool: dict[int, DistSamplingProducer] = {} - self._msg_buffer_pool: dict[int, ShmChannel] = {} - self._epoch: dict[int, int] = {} # last epoch for the producer - # Per-producer locks that guard the lifecycle of individual producers - # (creation, epoch transitions, destruction). This avoids holding the - # top-level _lock during expensive operations like producer init. - self._producer_lock: dict[int, threading.RLock] = {} + self._next_backend_id = 0 + self._next_channel_id = 0 + self._backend_key_to_id: dict[str, int] = {} + self._backend_state_by_id: dict[int, SamplingBackendState] = {} + self._channel_state: dict[int, ChannelState] = {} + self._log_every_n = log_every_n + self._fetch_stats: dict[int, _ChannelFetchStats] = {} def shutdown(self) -> None: - for producer_id in list(self._producer_pool.keys()): - self.destroy_sampling_producer(producer_id) - assert len(self._producer_pool) == 0 - assert len(self._msg_buffer_pool) == 0 + with self._lock: + backends = list(self._backend_state_by_id.values()) + self._backend_key_to_id.clear() + self._backend_state_by_id.clear() + self._channel_state.clear() + for backend_state in backends: + try: + backend_state.runtime.shutdown() + except Exception: + logger.warning( + f"Failed to shut down backend backend_id={backend_state.backend_id} " + f"backend_key={backend_state.backend_key}", + exc_info=True, + ) def wait_for_exit(self) -> None: r"""Block until the exit flag been set to ``True``.""" @@ -429,145 +450,198 @@ def get_ablp_input( ) return anchors, positive_labels, negative_labels - def create_sampling_producer( - self, - sampler_input: Union[ - NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput, ABLPNodeSamplerInput - ], - 
sampling_config: SamplingConfig, - worker_options: RemoteDistSamplingWorkerOptions, - sampler_options: SamplerOptions, - ) -> int: - """Create and initialize an instance of ``DistSamplingProducer`` with - a group of subprocesses for distributed sampling. - - Supports both standard ``NodeSamplerInput`` and ``ABLPNodeSamplerInput`` - through ``BaseGiGLSampler`` subclasses (``DistNeighborSampler`` for k-hop, - ``DistPPRNeighborSampler`` for PPR). - - Args: - sampler_input (NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput, - or ABLPNodeSamplerInput): The input data for sampling. - sampling_config (SamplingConfig): Configuration of sampling meta info. - worker_options (RemoteDistSamplingWorkerOptions): Options for launching - remote sampling workers by this server. - sampler_options (SamplerOptions): Controls which sampler class - is instantiated. - - Returns: - int: A unique id of created sampling producer on this server. - """ + def init_sampling_backend(self, opts: InitSamplingBackendRequest) -> int: + """Create or reuse a shared sampling backend for one loader instance.""" + request_start_time = time.monotonic() + with self._lock: + backend_id = self._backend_key_to_id.get(opts.backend_key) + if backend_id is not None: + return backend_id + backend_id = self._next_backend_id + self._next_backend_id += 1 + backend_state = SamplingBackendState( + backend_id=backend_id, + backend_key=opts.backend_key, + runtime=SharedDistSamplingBackend( + data=self.dataset, + worker_options=opts.worker_options, + sampling_config=opts.sampling_config, + sampler_options=opts.sampler_options, + ), + ) + self._backend_key_to_id[opts.backend_key] = backend_id + self._backend_state_by_id[backend_id] = backend_state + init_start_time = time.monotonic() + try: + backend_state.runtime.init_backend() + except Exception: + with self._lock: + self._backend_key_to_id.pop(opts.backend_key, None) + self._backend_state_by_id.pop(backend_id, None) + raise + init_elapsed = time.monotonic() - 
init_start_time + total_elapsed = time.monotonic() - request_start_time + logger.info( + f"Initialized sampling backend backend_key={opts.backend_key} " + f"backend_id={backend_id} " + f"init_backend={init_elapsed:.2f}s total={total_elapsed:.2f}s" + ) + return backend_id + def register_sampling_input(self, opts: RegisterBackendRequest) -> int: + """Register one compute-rank input channel on an existing backend.""" request_start_time = time.monotonic() + with self._lock: + backend_state = self._backend_state_by_id[opts.backend_id] + channel_id = self._next_channel_id + self._next_channel_id += 1 + channel = ShmChannel(opts.buffer_capacity, opts.buffer_size) + channel_state = ChannelState( + backend_id=opts.backend_id, + worker_key=opts.worker_key, + channel=channel, + ) + self._channel_state[channel_id] = channel_state + backend_state.active_channels.add(channel_id) + + sampler_input = opts.sampler_input if isinstance(sampler_input, RemoteSamplerInput): sampler_input = sampler_input.to_local_sampler_input(dataset=self.dataset) - with self._lock: - producer_id = self._worker_key2producer_id.get(worker_options.worker_key) - if producer_id is None: - logger.info( - f"Creating new producer for worker key {worker_options.worker_key}" - ) - producer_id = self._cur_producer_idx - self._cur_producer_idx += 1 - else: - logger.info( - f"Reusing producer for worker key {worker_options.worker_key}, producer id {producer_id}" - ) - producer_lock = self._producer_lock.get(producer_id, None) - if producer_lock is None: - producer_lock = threading.RLock() - self._producer_lock[producer_id] = producer_lock - self._worker_key2producer_id[worker_options.worker_key] = producer_id - with producer_lock: - if producer_id not in self._producer_pool: - logger.info( - f"Creating new producer pool entry for producer id {producer_id}" - ) - buffer = ShmChannel( - worker_options.buffer_capacity, worker_options.buffer_size - ) - # Degree tensors for PPR must be computed before constructing - # 
the producer. The all_reduce inside degree_tensor requires - # all ranks to participate simultaneously and cannot run inside - # worker subprocesses (which only initialize RPC, not - # torch.distributed). - degree_tensors = ( - self.dataset.degree_tensor - if isinstance(sampler_options, PPRSamplerOptions) - else None - ) - producer = DistSamplingProducer( - data=self.dataset, + try: + with backend_state.lock: + backend_state.runtime.register_input( + channel_id=channel_id, + worker_key=opts.worker_key, sampler_input=sampler_input, - sampling_config=sampling_config, - worker_options=worker_options, - channel=buffer, - sampler_options=sampler_options, - degree_tensors=degree_tensors, + sampling_config=opts.sampling_config, + channel=channel, ) - producer_start_time = time.monotonic() - producer.init() - logger.info( - f"Producer {producer_id} initialized in {time.monotonic() - producer_start_time:.2f}s" - ) - self._producer_pool[producer_id] = producer - self._msg_buffer_pool[producer_id] = buffer - self._epoch[producer_id] = -1 - else: - logger.info( - f"Reusing producer pool entry for producer id {producer_id}" + except Exception: + with self._lock: + self._channel_state.pop(channel_id, None) + backend_state.active_channels.discard(channel_id) + raise + + logger.info( + f"Registered sampling input backend_id={opts.backend_id} " + f"channel_id={channel_id} worker_key={opts.worker_key} " + f"active_channels={len(backend_state.active_channels)} " + f"in {time.monotonic() - request_start_time:.2f}s" + ) + return channel_id + + def destroy_sampling_input(self, channel_id: int) -> None: + """Destroy one registered sampling channel and maybe its backend.""" + self._fetch_stats.pop(channel_id, None) + with self._lock: + channel_state = self._channel_state.pop(channel_id, None) + if channel_state is None: + return + backend_state = self._backend_state_by_id.get(channel_state.backend_id) + if backend_state is None: + return + + with backend_state.lock: + 
backend_state.runtime.unregister_input(channel_id) + + should_shutdown_backend = False + with self._lock: + backend_state.active_channels.discard(channel_id) + if not backend_state.active_channels: + self._backend_state_by_id.pop(backend_state.backend_id, None) + self._backend_key_to_id.pop(backend_state.backend_key, None) + should_shutdown_backend = True + if should_shutdown_backend: + backend_state.runtime.shutdown() + + def start_new_epoch_sampling(self, channel_id: int, epoch: int) -> None: + """Start one new epoch on one registered channel.""" + with self._lock: + channel_state = self._channel_state.get(channel_id) + if channel_state is None: + raise RuntimeError( + f"start_new_epoch_sampling: channel_id={channel_id} not found" ) - request_end_time = time.monotonic() + backend_state = self._backend_state_by_id.get(channel_state.backend_id) + if backend_state is None: + raise RuntimeError( + f"start_new_epoch_sampling: backend for channel_id={channel_id} " + f"backend_id={channel_state.backend_id} not found" + ) + + if channel_state.epoch >= epoch: + return + channel_state.epoch = epoch logger.info( - f"Request to create producer for worker key {worker_options.worker_key} took {request_end_time - request_start_time:.2f}s" + f"Starting epoch channel_id={channel_id} backend_id={channel_state.backend_id} " + f"epoch={epoch}" ) - return producer_id + backend_state.runtime.start_new_epoch_sampling(channel_id, epoch) - def destroy_sampling_producer(self, producer_id: int) -> None: - r"""Shutdown and destroy a sampling producer managed by this server with - its producer id. 
- """ - with self._producer_lock[producer_id]: - producer = self._producer_pool.get(producer_id, None) - if producer is not None: - producer.shutdown() - self._producer_pool.pop(producer_id) - self._msg_buffer_pool.pop(producer_id) - self._epoch.pop(producer_id) - - def start_new_epoch_sampling(self, producer_id: int, epoch: int) -> None: - r"""Start a new epoch sampling tasks for a specific sampling producer - with its producer id. + def _log_fetch_stats_if_due( + self, channel_id: int, worker_key: str, elapsed: float + ) -> None: + """Accumulate per-channel fetch timing and log aggregated stats every ``log_every_n`` calls. + + Stats are keyed by ``channel_id`` so each channel's RPC thread updates + its own counters without contention. """ - with self._producer_lock[producer_id]: - cur_epoch = self._epoch[producer_id] - if cur_epoch < epoch: - self._epoch[producer_id] = epoch - producer = self._producer_pool.get(producer_id, None) - if producer is not None: - producer.produce_all() + stats = self._fetch_stats.get(channel_id) + if stats is None: + stats = _ChannelFetchStats() + self._fetch_stats[channel_id] = stats + stats.fetch_count += 1 + stats.fetch_total_elapsed += elapsed + if elapsed >= FETCH_SLOW_LOG_SECS: + stats.fetch_slow_count += 1 + if stats.fetch_count >= self._log_every_n: + avg_elapsed = stats.fetch_total_elapsed / stats.fetch_count + logger.info( + f"fetch_one_sampled_message stats: worker_key={worker_key} " + f"avg_elapsed={avg_elapsed:.3f}s " + f"slow_count={stats.fetch_slow_count}/{stats.fetch_count}" + ) + stats.fetch_count = 0 + stats.fetch_total_elapsed = 0.0 + stats.fetch_slow_count = 0 def fetch_one_sampled_message( - self, producer_id: int + self, channel_id: int ) -> tuple[Optional[SampleMessage], bool]: - r"""Fetch a sampled message from the buffer of a specific sampling - producer with its producer id. 
- """ - producer = self._producer_pool.get(producer_id, None) - if producer is None: - warnings.warn("invalid producer_id {producer_id}") - return None, False - if producer.is_all_sampling_completed_and_consumed(): + """Fetch one sampled message from a registered channel.""" + request_start_time = time.monotonic() + with self._lock: + channel_state = self._channel_state.get(channel_id) + if channel_state is None: + return None, True + backend_state = self._backend_state_by_id.get(channel_state.backend_id) + if backend_state is None: return None, True - buffer = self._msg_buffer_pool.get(producer_id, None) - while True: - try: - msg = buffer.recv(timeout_ms=500) - return msg, False - except QueueTimeoutError as e: - if producer.is_all_sampling_completed(): - return None, True + + with channel_state.lock: + while True: + try: + msg = channel_state.channel.recv(timeout_ms=100) + self._log_fetch_stats_if_due( + channel_id, + channel_state.worker_key, + time.monotonic() - request_start_time, + ) + return msg, False + except QueueTimeoutError: + if ( + backend_state.runtime.is_channel_epoch_done( + channel_id, channel_state.epoch + ) + and channel_state.channel.empty() + ): + self._log_fetch_stats_if_due( + channel_id, + channel_state.worker_key, + time.monotonic() - request_start_time, + ) + return None, True _dist_server: Optional[DistServer] = None diff --git a/gigl/distributed/graph_store/messages.py b/gigl/distributed/graph_store/messages.py index 0ce043539..d92229afa 100644 --- a/gigl/distributed/graph_store/messages.py +++ b/gigl/distributed/graph_store/messages.py @@ -1,11 +1,45 @@ -"""RPC request messages for graph-store fetch operations.""" +"""RPC request messages for graph-store operations.""" from dataclasses import dataclass from typing import Literal, Optional, Union +from graphlearn_torch.distributed import RemoteDistSamplingWorkerOptions +from graphlearn_torch.sampler import ( + EdgeSamplerInput, + NodeSamplerInput, + RemoteSamplerInput, + 
SamplingConfig, +) + +from gigl.distributed.sampler import ABLPNodeSamplerInput +from gigl.distributed.sampler_options import SamplerOptions from gigl.src.common.types.graph_data import EdgeType, NodeType +@dataclass(frozen=True) +class InitSamplingBackendRequest: + """Request to initialize a shared sampling backend on a storage server.""" + + backend_key: str + worker_options: RemoteDistSamplingWorkerOptions + sampler_options: SamplerOptions + sampling_config: SamplingConfig + + +@dataclass(frozen=True) +class RegisterBackendRequest: + """Request to register one compute-rank input channel on a backend.""" + + backend_id: int + worker_key: str + sampler_input: Union[ + NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput, ABLPNodeSamplerInput + ] + sampling_config: SamplingConfig + buffer_capacity: int + buffer_size: Union[int, str] + + @dataclass(frozen=True) class FetchNodesRequest: """Request for fetching node IDs from a storage server. diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py new file mode 100644 index 000000000..68848aa44 --- /dev/null +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -0,0 +1,793 @@ +"""Shared graph-store sampling backend and fair-queued worker loop. + +This module implements the multi-channel sampling backend used in graph-store +mode. A single ``SharedDistSamplingBackend`` per loader instance manages a +pool of worker processes that service many compute-rank channels through a +fair-queued scheduler (``_shared_sampling_worker_loop``). 
+""" + +import datetime +import queue +import threading +import time +from collections import defaultdict, deque +from dataclasses import dataclass +from enum import Enum, auto +from multiprocessing.process import BaseProcess +from threading import Barrier +from typing import Optional, Union, cast + +import torch +import torch.multiprocessing as mp +from graphlearn_torch.channel import ChannelBase +from graphlearn_torch.distributed import ( + DistDataset, + RemoteDistSamplingWorkerOptions, + get_context, + init_rpc, + init_worker_group, + shutdown_rpc, +) +from graphlearn_torch.distributed.dist_sampling_producer import MP_STATUS_CHECK_INTERVAL +from graphlearn_torch.sampler import ( + EdgeSamplerInput, + NodeSamplerInput, + SamplingConfig, + SamplingType, +) +from graphlearn_torch.typing import EdgeType +from torch._C import _set_worker_signal_handlers + +from gigl.common.logger import Logger +from gigl.distributed.dist_sampling_producer import ( + SamplerInput, + SamplerRuntime, + _create_dist_sampler, + _prepare_degree_tensors, +) +from gigl.distributed.sampler_options import SamplerOptions + +logger = Logger() + +EPOCH_DONE_EVENT = "EPOCH_DONE" +SCHEDULER_TICK_SECS = 0.05 +SCHEDULER_STATE_LOG_INTERVAL_SECS = 10.0 +SCHEDULER_STATE_MAX_CHANNELS = 6 +SCHEDULER_SLOW_SUBMIT_SECS = 1.0 + + +class SharedMpCommand(Enum): + REGISTER_INPUT = auto() + UNREGISTER_INPUT = auto() + START_EPOCH = auto() + STOP = auto() + + +@dataclass(frozen=True) +class RegisterInputCmd: + channel_id: int + worker_key: str + sampler_input: SamplerInput + sampling_config: SamplingConfig + channel: ChannelBase + + +@dataclass(frozen=True) +class StartEpochCmd: + channel_id: int + epoch: int + seeds_index: Optional[torch.Tensor] + + +@dataclass +class ActiveEpochState: + channel_id: int + epoch: int + input_len: int + batch_size: int + drop_last: bool + seeds_index: Optional[torch.Tensor] + total_batches: int + submitted_batches: int = 0 + completed_batches: int = 0 + cancelled: bool = False + + 
+def _command_channel_id(command: SharedMpCommand, payload: object) -> Optional[int]: + """Extract the channel id from a worker command payload.""" + if command == SharedMpCommand.STOP: + return None + if isinstance(payload, RegisterInputCmd): + return payload.channel_id + if isinstance(payload, StartEpochCmd): + return payload.channel_id + if isinstance(payload, int): + return payload + return None + + +def _compute_num_batches(input_len: int, batch_size: int, drop_last: bool) -> int: + """Compute the number of batches emitted for an input length.""" + if input_len <= 0: + return 0 + if drop_last: + return input_len // batch_size + return (input_len + batch_size - 1) // batch_size + + +def _epoch_batch_indices(state: ActiveEpochState) -> Optional[torch.Tensor]: + """Return the next batch of indices for an active epoch. + + Returns the index tensor for the next batch, or None if no more batches + should be submitted (epoch cancelled, all batches already submitted, or + incomplete final batch with drop_last=True). 
+ """ + if state.cancelled or state.submitted_batches >= state.total_batches: + return None + + batch_start = state.submitted_batches * state.batch_size + batch_end = min(batch_start + state.batch_size, state.input_len) + if state.drop_last and batch_end - batch_start < state.batch_size: + return None + + if state.seeds_index is None: + return torch.arange(batch_start, batch_end, dtype=torch.long) + return state.seeds_index[batch_start:batch_end] + + +def _compute_worker_seeds_ranges( + input_len: int, batch_size: int, num_workers: int +) -> list[tuple[int, int]]: + """Distribute complete batches across workers like GLT's producer does.""" + num_worker_batches = [0] * num_workers + num_total_complete_batches = input_len // batch_size + for rank in range(num_workers): + num_worker_batches[rank] += num_total_complete_batches // num_workers + for rank in range(num_total_complete_batches % num_workers): + num_worker_batches[rank] += 1 + + index_ranges: list[tuple[int, int]] = [] + start = 0 + for rank in range(num_workers): + end = start + num_worker_batches[rank] * batch_size + if rank == num_workers - 1: + end = input_len + index_ranges.append((start, end)) + start = end + return index_ranges + + +def _shared_sampling_worker_loop( + rank: int, + data: DistDataset, + worker_options: RemoteDistSamplingWorkerOptions, + task_queue: mp.Queue, + event_queue: mp.Queue, + mp_barrier: Barrier, + sampler_options: SamplerOptions, + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], +) -> None: + """Run one shared graph-store worker that schedules many input channels. + + Each worker subprocess runs this function as a fair-queued batch scheduler. + Multiple input channels (each representing one compute rank's data stream) + share the same sampling worker processes and graph data. + + Algorithm: + 1. Initialize RPC, sampler infrastructure, and signal the parent via barrier. + 2. Enter the main event loop which alternates between: + a. 
def _shared_sampling_worker_loop(
    rank: int,
    data: DistDataset,
    worker_options: RemoteDistSamplingWorkerOptions,
    task_queue: mp.Queue,
    event_queue: mp.Queue,
    mp_barrier: Barrier,
    sampler_options: SamplerOptions,
    degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
) -> None:
    """Run one shared graph-store worker that schedules many input channels.

    Each worker subprocess runs this function as a fair-queued batch scheduler.
    Multiple input channels (each representing one compute rank's data stream)
    share the same sampling worker processes and graph data.

    Algorithm:
        1. Initialize RPC, sampler infrastructure, and signal the parent via barrier.
        2. Enter the main event loop which alternates between:
            a. Draining all pending commands from ``task_queue`` (register/unregister
               channels, start epochs, stop).
            b. Submitting batches round-robin from ``runnable_channels`` — a FIFO
               queue of channels that have pending work. Each channel gets one batch
               submitted per round to prevent starvation.
            c. If no commands were processed and no batches submitted, blocking on
               ``task_queue`` with a short timeout to avoid busy-waiting.
        3. Completion callbacks from the sampler update per-channel state and emit
           ``EPOCH_DONE_EVENT`` to ``event_queue`` when all batches for an epoch
           are finished.

    Args:
        rank: Index of this worker within the backend. Used to index
            ``worker_options.worker_ranks`` / ``worker_options.worker_devices``
            and attached to every event put on ``event_queue``.
        data: Dataset handed to every sampler created by this worker; its
            ``num_partitions`` bounds the default RPC thread count.
        worker_options: Worker/RPC configuration (world size, ranks, master
            address and port, RPC timeout and thread count, devices).
        task_queue: Source of ``(command, payload)`` tuples for this worker.
        event_queue: Sink for ``(EPOCH_DONE_EVENT, channel_id, epoch, rank)``
            tuples.
        mp_barrier: Waited on exactly once, after RPC initialization.
        sampler_options: Forwarded unchanged to ``_create_dist_sampler``.
        degree_tensors: Forwarded unchanged to ``_create_dist_sampler``.
    """
    # Per-channel registries, all keyed by channel_id and guarded by state_lock.
    samplers: dict[int, SamplerRuntime] = {}
    channels: dict[int, ChannelBase] = {}
    inputs: dict[int, SamplerInput] = {}
    cfgs: dict[int, SamplingConfig] = {}
    route_key_by_channel: dict[int, str] = {}
    # Highest epoch accepted so far per channel (-1 right after registration).
    started_epoch: dict[int, int] = {}
    # Only channels with an epoch currently in progress have an entry here.
    active_epochs_by_channel: dict[int, ActiveEpochState] = {}
    # FIFO of channels with batches left to submit; runnable_set mirrors it for
    # O(1) membership checks.
    runnable_channels: deque[int] = deque()
    runnable_set: set[int] = set()
    # Channels whose unregistration is deferred until in-flight batches finish.
    removing: set[int] = set()
    # Re-entrant because _on_batch_done calls _clear_registered_input_locked
    # while already holding the lock.
    state_lock = threading.RLock()
    last_state_log_time = 0.0
    # Assigned during initialization below; read by _handle_command.
    current_device: Optional[torch.device] = None

    # --- Scheduler helper functions ---

    def _enqueue_channel_if_runnable_locked(channel_id: int) -> None:
        """Add channel to the fair-queue if it has pending batches.

        No-op when the channel has no active epoch, is cancelled, has already
        submitted all of its batches, or is already queued.
        """
        state = active_epochs_by_channel.get(channel_id)
        if state is None:
            return
        if state.cancelled or state.submitted_batches >= state.total_batches:
            return
        if channel_id in runnable_set:
            return
        runnable_channels.append(channel_id)
        runnable_set.add(channel_id)

    def _clear_registered_input_locked(channel_id: int) -> None:
        """Remove a channel's registration and clean up all associated state.

        If the channel still has in-flight batches (submitted but not yet
        completed), marks it for deferred removal instead of cleaning up
        immediately.
        ``_on_batch_done`` will finish the cleanup once the last in-flight
        batch completes.
        """
        state = active_epochs_by_channel.get(channel_id)
        if state is not None and state.completed_batches < state.submitted_batches:
            # Defer: stop scheduling new batches and let the callbacks drain.
            removing.add(channel_id)
            state.cancelled = True
            return
        sampler = samplers.pop(channel_id, None)
        if sampler is not None:
            sampler.wait_all()
            sampler.shutdown_loop()
        channels.pop(channel_id, None)
        inputs.pop(channel_id, None)
        cfgs.pop(channel_id, None)
        route_key_by_channel.pop(channel_id, None)
        started_epoch.pop(channel_id, None)
        active_epochs_by_channel.pop(channel_id, None)
        # Only the set is cleared; a stale id left in ``runnable_channels`` is
        # harmless because _submit_one_batch returns False once the channel has
        # no active epoch state.
        runnable_set.discard(channel_id)
        removing.discard(channel_id)

    def _format_scheduler_state_locked() -> str:
        """Format a human-readable snapshot of the scheduler for logging.

        Must be called while holding ``state_lock``.

        Each previewed channel is rendered as ``id:idle`` or
        ``id:e<epoch>/<submitted>/<completed>/<total>``; at most
        ``SCHEDULER_STATE_MAX_CHANNELS`` channels are listed, followed by a
        ``+N`` suffix for the remainder.
        """
        channel_ids = sorted(channels.keys())
        preview = channel_ids[:SCHEDULER_STATE_MAX_CHANNELS]
        previews: list[str] = []
        for channel_id in preview:
            active_epoch = active_epochs_by_channel.get(channel_id)
            if active_epoch is None:
                previews.append(f"{channel_id}:idle")
            else:
                previews.append(
                    f"{channel_id}:e{active_epoch.epoch}"
                    f"/{active_epoch.submitted_batches}"
                    f"/{active_epoch.completed_batches}"
                    f"/{active_epoch.total_batches}"
                )
        extra = ""
        if len(channel_ids) > len(preview):
            extra = f" +{len(channel_ids) - len(preview)}"
        return (
            f"registered={len(channels)} active={len(active_epochs_by_channel)} "
            f"runnable={len(runnable_set)} removing={len(removing)} "
            f"channels=[{', '.join(previews)}]{extra}"
        )

    def _maybe_log_scheduler_state(reason: str, force: bool = False) -> None:
        """Log scheduler state at most once per ``SCHEDULER_STATE_LOG_INTERVAL_SECS``.

        Args:
            reason: Short tag included in the log line (e.g. "start_epoch").
            force: If True, log regardless of the time-based throttle.
        """
        nonlocal last_state_log_time
        now = time.monotonic()
        if not force and now - last_state_log_time < SCHEDULER_STATE_LOG_INTERVAL_SECS:
            return
        with state_lock:
            scheduler_state = _format_scheduler_state_locked()
        logger.info(
            f"shared_sampling_scheduler worker_rank={rank} reason={reason} "
            f"{scheduler_state}"
        )
        last_state_log_time = now

    def _on_batch_done(channel_id: int, epoch: int) -> None:
        """Sampler completion callback — invoked from sampler worker threads.

        Updates the channel's completed-batch counter.
        When all batches for the epoch are done, emits ``EPOCH_DONE_EVENT``
        to ``event_queue``.
        If the channel is pending removal, finishes cleanup via
        ``_clear_registered_input_locked``.
        """
        with state_lock:
            state = active_epochs_by_channel.get(channel_id)
            # Ignore completions for a channel that is gone or for an epoch that
            # has since been replaced by a newer one.
            if state is None or state.epoch != epoch:
                return
            state.completed_batches += 1
            if state.completed_batches == state.total_batches:
                active_epochs_by_channel.pop(channel_id, None)
                event_queue.put((EPOCH_DONE_EVENT, channel_id, epoch, rank))
            # NOTE(review): this path reaches sampler.wait_all() /
            # sampler.shutdown_loop() from inside a completion callback while
            # holding state_lock — confirm the sampler allows wait_all() to be
            # called from its own callback without blocking.
            if (
                channel_id in removing
                and state.completed_batches == state.submitted_batches
            ):
                _clear_registered_input_locked(channel_id)

    def _submit_one_batch(channel_id: int) -> bool:
        """Submit the next batch for a channel to its sampler.

        Re-enqueues the channel into ``runnable_channels`` if more batches
        remain.
        Returns True if a batch was submitted, False if the channel had no
        pending work.
        """
        with state_lock:
            state = active_epochs_by_channel.get(channel_id)
            if state is None:
                return False
            batch_indices = _epoch_batch_indices(state)
            if batch_indices is None:
                return False
            state.submitted_batches += 1
            # Snapshot everything needed for the submission so the sampler call
            # below can run without holding the lock.
            cfg = cfgs[channel_id]
            sampler = samplers[channel_id]
            channel_input = inputs[channel_id]
            current_epoch = state.epoch
            if state.submitted_batches < state.total_batches and not state.cancelled:
                runnable_channels.append(channel_id)
                runnable_set.add(channel_id)

        sampler_input = channel_input[batch_indices]

        # The callback receives the sampler's result, which is ignored here;
        # only the (channel, epoch) identity is needed for bookkeeping.
        callback = lambda _: _on_batch_done(channel_id, current_epoch)
        if cfg.sampling_type == SamplingType.NODE:
            sampler.sample_from_nodes(
                cast(NodeSamplerInput, sampler_input), callback=callback
            )
        elif cfg.sampling_type == SamplingType.LINK:
            sampler.sample_from_edges(
                cast(EdgeSamplerInput, sampler_input), callback=callback
            )
        elif cfg.sampling_type == SamplingType.SUBGRAPH:
            sampler.subgraph(cast(NodeSamplerInput, sampler_input), callback=callback)
        else:
            raise RuntimeError(f"Unsupported sampling type: {cfg.sampling_type}")
        return True

    def _pump_runnable_channels() -> bool:
        """Submit one batch per runnable channel in round-robin order.

        Returns True if at least one batch was submitted.
        """
        made_progress = False
        # Snapshot the queue length so channels re-enqueued by _submit_one_batch
        # during this round are served in the next round, not this one.
        with state_lock:
            num_candidates = len(runnable_channels)
        for _ in range(num_candidates):
            with state_lock:
                if not runnable_channels:
                    break
                channel_id = runnable_channels.popleft()
                runnable_set.discard(channel_id)
            made_progress = _submit_one_batch(channel_id) or made_progress
        return made_progress

    def _handle_command(command: SharedMpCommand, payload: object) -> bool:
        """Dispatch one command from the task queue.

        Returns True to keep running, False on ``STOP``.

        Raises:
            RuntimeError: If ``command`` is not a recognized command type.
        """
        channel_id = _command_channel_id(command, payload)
        if command == SharedMpCommand.REGISTER_INPUT:
            register = cast(RegisterInputCmd, payload)
            assert current_device is not None
            # Build and start the sampler before taking the lock; only the
            # registry updates are done under state_lock.
            sampler = _create_dist_sampler(
                data=data,
                sampling_config=register.sampling_config,
                worker_options=worker_options,
                channel=register.channel,
                sampler_options=sampler_options,
                degree_tensors=degree_tensors,
                current_device=current_device,
            )
            sampler.start_loop()
            with state_lock:
                samplers[register.channel_id] = sampler
                channels[register.channel_id] = register.channel
                inputs[register.channel_id] = register.sampler_input
                cfgs[register.channel_id] = register.sampling_config
                route_key_by_channel[register.channel_id] = register.worker_key
                started_epoch[register.channel_id] = -1
            _maybe_log_scheduler_state("register_input", force=True)
            return True

        if command == SharedMpCommand.START_EPOCH:
            start_epoch = cast(StartEpochCmd, payload)
            with state_lock:
                # Unknown channel: silently ignore the request.
                if channel_id not in channels:
                    return True
                # Duplicate or out-of-date epoch request: ignore.
                if started_epoch.get(channel_id, -1) >= start_epoch.epoch:
                    return True
                started_epoch[channel_id] = start_epoch.epoch
                sampling_config = cfgs[channel_id]
                # An explicit seeds_index defines this worker's share of the
                # epoch; otherwise the whole registered input is used.
                local_input_len = (
                    len(start_epoch.seeds_index)
                    if start_epoch.seeds_index is not None
                    else len(inputs[channel_id])
                )
                state = ActiveEpochState(
                    channel_id=channel_id,
                    epoch=start_epoch.epoch,
                    input_len=local_input_len,
                    batch_size=sampling_config.batch_size,
                    drop_last=sampling_config.drop_last,
                    seeds_index=start_epoch.seeds_index,
                    total_batches=_compute_num_batches(
                        local_input_len,
                        sampling_config.batch_size,
                        sampling_config.drop_last,
                    ),
                )
                active_epochs_by_channel[channel_id] = state
                if state.total_batches == 0:
                    # Nothing to sample: report the epoch as done immediately,
                    # since no completion callback will ever fire for it.
                    active_epochs_by_channel.pop(channel_id, None)
                    event_queue.put(
                        (EPOCH_DONE_EVENT, channel_id, start_epoch.epoch, rank)
                    )
                    return True
                _enqueue_channel_if_runnable_locked(channel_id)
            _maybe_log_scheduler_state("start_epoch", force=True)
            return True

        if command == SharedMpCommand.UNREGISTER_INPUT:
            assert channel_id is not None
            with state_lock:
                _clear_registered_input_locked(channel_id)
            _maybe_log_scheduler_state("unregister_input", force=True)
            return True

        if command == SharedMpCommand.STOP:
            return False

        raise RuntimeError(f"Unknown command type: {command}")

    try:
        init_worker_group(
            world_size=worker_options.worker_world_size,
            rank=worker_options.worker_ranks[rank],
            group_name="_sampling_worker_subprocess",
        )
        if worker_options.use_all2all:
            torch.distributed.init_process_group(
                backend="gloo",
                timeout=datetime.timedelta(seconds=worker_options.rpc_timeout),
                rank=worker_options.worker_ranks[rank],
                world_size=worker_options.worker_world_size,
                init_method="tcp://{}:{}".format(
                    worker_options.master_addr, worker_options.master_port
                ),
            )

        # Default RPC thread count: one per partition, capped at 16.
        if worker_options.num_rpc_threads is None:
            num_rpc_threads = min(data.num_partitions, 16)
        else:
            num_rpc_threads = worker_options.num_rpc_threads
        current_device = worker_options.worker_devices[rank]

        _set_worker_signal_handlers()
        torch.set_num_threads(num_rpc_threads + 1)

        init_rpc(
            master_addr=worker_options.master_addr,
            master_port=worker_options.master_port,
            num_rpc_threads=num_rpc_threads,
            rpc_timeout=worker_options.rpc_timeout,
        )
        # Signal that this worker finished initialization.
        mp_barrier.wait()

        # --- Main event loop ---
        keep_running = True
        while keep_running:
            # Phase 1: Drain all pending commands without blocking.
            processed_command = False
            while keep_running:
                try:
                    command, payload = task_queue.get_nowait()
                except queue.Empty:
                    break
                processed_command = True
                keep_running = _handle_command(command, payload)

            # Phase 2: Submit batches round-robin from runnable channels.
            # This still runs once after STOP was received in phase 1; the loop
            # exits right after.
            made_progress = _pump_runnable_channels()
            _maybe_log_scheduler_state("steady_state")
            if not keep_running:
                break

            # Phase 3: If idle (no commands, no batches), block until next command.
            if not (processed_command or made_progress):
                try:
                    command, payload = task_queue.get(timeout=SCHEDULER_TICK_SECS)
                except queue.Empty:
                    continue
                keep_running = _handle_command(command, payload)
    except KeyboardInterrupt:
        # Interrupts are treated as a normal stop; cleanup happens in ``finally``.
        pass
    finally:
        # Iterate over a copy of the sampler list while shutting samplers down.
        for sampler in list(samplers.values()):
            sampler.wait_all()
            sampler.shutdown_loop()
        shutdown_rpc(graceful=False)
class SharedDistSamplingBackend:
    """Shared graph-store sampling backend reused across many remote channels.

    One backend owns ``num_workers`` worker subprocesses (each running
    ``_shared_sampling_worker_loop``), one bounded command queue per worker and
    a single event queue shared by all workers. Channels are registered on the
    backend and then driven epoch by epoch; the backend tracks, per
    ``(channel_id, epoch)``, which workers have reported completion.

    All public methods serialize on one re-entrant lock.
    """

    def __init__(
        self,
        *,
        data: DistDataset,
        worker_options: RemoteDistSamplingWorkerOptions,
        sampling_config: SamplingConfig,
        sampler_options: SamplerOptions,
    ) -> None:
        """Store configuration; no processes are started until ``init_backend``.

        Args:
            data: Dataset passed to every worker subprocess.
            worker_options: Worker configuration; ``num_workers`` is read here.
            sampling_config: The sampling config every registered channel must
                match.
            sampler_options: Sampler options passed to every worker subprocess.
        """
        self.data = data
        self.worker_options = worker_options
        self.num_workers = worker_options.num_workers
        self._backend_sampling_config = sampling_config
        self._sampler_options = sampler_options
        self._task_queues: list[mp.Queue] = []
        self._workers: list[BaseProcess] = []
        self._event_queue: Optional[mp.Queue] = None
        self._shutdown = False
        self._initialized = False
        self._lock = threading.RLock()
        # Per-channel bookkeeping, keyed by channel_id.
        self._channel_sampling_config: dict[int, SamplingConfig] = {}
        self._channel_input_sizes: dict[int, list[int]] = {}
        self._channel_worker_seeds_ranges: dict[int, list[tuple[int, int]]] = {}
        self._channel_shuffle_generators: dict[int, Optional[torch.Generator]] = {}
        self._channel_epoch: dict[int, int] = {}
        # (channel_id, epoch) -> ranks of the workers that reported completion.
        self._completed_workers: defaultdict[tuple[int, int], set[int]] = defaultdict(
            set
        )

    def init_backend(self) -> None:
        """Initialize worker processes once for this backend.

        Subsequent calls are no-ops. Blocks until every worker has passed the
        startup barrier.

        Raises:
            RuntimeError: If there is no current GLT context or it is not a
                server context.
        """
        with self._lock:
            if self._initialized:
                return
            self.worker_options._assign_worker_devices()
            current_ctx = get_context()
            if current_ctx is None or not current_ctx.is_server():
                raise RuntimeError(
                    "SharedDistSamplingBackend.init_backend() requires a GLT server context."
                )
            self.worker_options._set_worker_ranks(current_ctx)
            degree_tensors = _prepare_degree_tensors(
                self.data,
                self._sampler_options,
            )
            mp_context = mp.get_context("spawn")
            # +1 party: this (parent) process also waits on the barrier.
            barrier = mp_context.Barrier(self.num_workers + 1)
            self._event_queue = mp_context.Queue()
            for rank in range(self.num_workers):
                task_queue = mp_context.Queue(
                    self.num_workers * self.worker_options.worker_concurrency
                )
                self._task_queues.append(task_queue)
                worker = mp_context.Process(
                    target=_shared_sampling_worker_loop,
                    args=(
                        rank,
                        self.data,
                        self.worker_options,
                        task_queue,
                        self._event_queue,
                        barrier,
                        self._sampler_options,
                        degree_tensors,
                    ),
                )
                worker.daemon = True
                worker.start()
                self._workers.append(worker)
            barrier.wait()
            self._initialized = True

    def _enqueue_worker_command(
        self,
        worker_rank: int,
        command: SharedMpCommand,
        payload: object,
    ) -> None:
        """Put one ``(command, payload)`` on a worker's task queue.

        The put blocks while the (bounded) queue is full; a warning is logged
        when it takes at least ``SCHEDULER_SLOW_SUBMIT_SECS``.
        """
        queue_ = self._task_queues[worker_rank]
        enqueue_start = time.monotonic()
        queue_.put((command, payload))
        elapsed = time.monotonic() - enqueue_start
        if elapsed >= SCHEDULER_SLOW_SUBMIT_SECS:
            logger.warning(
                f"task_queue enqueue_slow worker_rank={worker_rank} "
                f"command={command.name} elapsed_secs={elapsed:.2f}"
            )

    def register_input(
        self,
        channel_id: int,
        worker_key: str,
        sampler_input: SamplerInput,
        sampling_config: SamplingConfig,
        channel: ChannelBase,
    ) -> None:
        """Register a channel-specific input on all backend workers.

        Args:
            channel_id: Identifier of the channel being registered.
            worker_key: Key forwarded to the workers with the registration.
            sampler_input: Seeds for this channel; moved to shared memory
                before being sent to the workers.
            sampling_config: Must equal the backend's sampling config.
            channel: Output channel forwarded to the workers.

        Raises:
            RuntimeError: If the backend has not been initialized.
            ValueError: If ``channel_id`` is already registered or
                ``sampling_config`` differs from the backend's config.
        """
        with self._lock:
            if not self._initialized:
                raise RuntimeError("SharedDistSamplingBackend is not initialized.")
            if channel_id in self._channel_sampling_config:
                raise ValueError(f"channel_id {channel_id} is already registered.")
            if sampling_config != self._backend_sampling_config:
                raise ValueError(
                    "Sampling config must match the backend sampling config for shared backends."
                )

            shared_sampler_input = sampler_input.share_memory()
            worker_ranges = _compute_worker_seeds_ranges(
                len(shared_sampler_input),
                sampling_config.batch_size,
                self.num_workers,
            )
            self._channel_sampling_config[channel_id] = sampling_config
            self._channel_input_sizes[channel_id] = [
                end - start for start, end in worker_ranges
            ]
            self._channel_worker_seeds_ranges[channel_id] = worker_ranges
            if sampling_config.shuffle:
                generator = torch.Generator()
                if sampling_config.seed is None:
                    # Seed only this generator from a non-deterministic source.
                    # ``torch.seed()`` would additionally re-seed the
                    # process-wide default RNG as a side effect.
                    generator.seed()
                else:
                    generator.manual_seed(sampling_config.seed)
                self._channel_shuffle_generators[channel_id] = generator
            else:
                self._channel_shuffle_generators[channel_id] = None
            self._channel_epoch[channel_id] = -1
            for worker_rank in range(self.num_workers):
                self._enqueue_worker_command(
                    worker_rank,
                    SharedMpCommand.REGISTER_INPUT,
                    RegisterInputCmd(
                        channel_id=channel_id,
                        worker_key=worker_key,
                        sampler_input=shared_sampler_input,
                        sampling_config=sampling_config,
                        channel=channel,
                    ),
                )

    def _drain_events(self) -> None:
        """Drain worker completion events into the backend-local state.

        Non-blocking: returns as soon as the event queue is empty. Events other
        than ``EPOCH_DONE_EVENT`` are discarded.
        """
        if self._event_queue is None:
            return
        while True:
            try:
                event = self._event_queue.get_nowait()
            except queue.Empty:
                return
            if event[0] == EPOCH_DONE_EVENT:
                _, channel_id, epoch, worker_rank = event
                self._completed_workers[(channel_id, epoch)].add(worker_rank)

    def start_new_epoch_sampling(self, channel_id: int, epoch: int) -> None:
        """Start one new epoch for one registered channel.

        A request for an epoch that is not newer than the channel's current
        epoch is ignored. Each worker receives the slice of seed indices that
        falls into its registered range: a slice of a fresh permutation when
        shuffling is enabled, otherwise the contiguous range itself.

        Raises:
            KeyError: If ``channel_id`` is not registered.
        """
        with self._lock:
            self._drain_events()
            sampling_config = self._channel_sampling_config[channel_id]
            if self._channel_epoch[channel_id] >= epoch:
                return
            self._channel_epoch[channel_id] = epoch
            # Forget completion records for this channel up to and including
            # the epoch being started, so completion is counted from scratch.
            stale_keys = [
                k
                for k in self._completed_workers
                if k[0] == channel_id and k[1] <= epoch
            ]
            for k in stale_keys:
                del self._completed_workers[k]

            worker_ranges = self._channel_worker_seeds_ranges[channel_id]
            full_index: Optional[torch.Tensor] = None
            if sampling_config.shuffle:
                generator = self._channel_shuffle_generators[channel_id]
                assert generator is not None
                input_len = sum(self._channel_input_sizes[channel_id])
                # One permutation per epoch, drawn from the channel's own
                # generator, then split across workers by their ranges.
                full_index = torch.randperm(input_len, generator=generator)

            for worker_rank, (start, end) in enumerate(worker_ranges):
                if full_index is not None:
                    worker_index = full_index[start:end]
                else:
                    worker_index = torch.arange(start, end, dtype=torch.long)
                worker_index.share_memory_()
                self._enqueue_worker_command(
                    worker_rank,
                    SharedMpCommand.START_EPOCH,
                    StartEpochCmd(
                        channel_id=channel_id,
                        epoch=epoch,
                        seeds_index=worker_index,
                    ),
                )

    def unregister_input(self, channel_id: int) -> None:
        """Unregister a channel from the backend workers.

        No-op when the channel is not registered. Otherwise drops all local
        bookkeeping for the channel and sends ``UNREGISTER_INPUT`` to every
        worker.
        """
        with self._lock:
            if channel_id not in self._channel_sampling_config:
                return
            self._drain_events()
            self._channel_sampling_config.pop(channel_id, None)
            self._channel_input_sizes.pop(channel_id, None)
            self._channel_worker_seeds_ranges.pop(channel_id, None)
            self._channel_shuffle_generators.pop(channel_id, None)
            self._channel_epoch.pop(channel_id, None)
            stale_keys = [k for k in self._completed_workers if k[0] == channel_id]
            for k in stale_keys:
                del self._completed_workers[k]
            for worker_rank in range(self.num_workers):
                self._enqueue_worker_command(
                    worker_rank,
                    SharedMpCommand.UNREGISTER_INPUT,
                    channel_id,
                )

    def is_channel_epoch_done(self, channel_id: int, epoch: int) -> bool:
        """Return whether every worker finished the epoch for one channel."""
        with self._lock:
            self._drain_events()
            # ``.get`` (not indexing) so the defaultdict is not populated by a
            # mere status query.
            return (
                len(self._completed_workers.get((channel_id, epoch), set()))
                == self.num_workers
            )

    def describe_channel(self, channel_id: int) -> dict[str, object]:
        """Return lightweight diagnostics for one registered channel.

        Returns:
            A dict with ``"epoch"`` (current epoch, ``-1`` if unknown),
            ``"input_sizes"`` (per-worker seed counts, ``[]`` if unknown) and
            ``"completed_workers"`` (number of workers that reported the
            current epoch as done).
        """
        with self._lock:
            self._drain_events()
            epoch = self._channel_epoch.get(channel_id, -1)
            completed_workers = len(
                self._completed_workers.get((channel_id, epoch), set())
            )
            return {
                "epoch": epoch,
                "input_sizes": self._channel_input_sizes.get(channel_id, []),
                "completed_workers": completed_workers,
            }

    def shutdown(self) -> None:
        """Stop all worker processes and release backend resources.

        Idempotent. Sends ``STOP`` to every worker, waits up to
        ``MP_STATUS_CHECK_INTERVAL`` for each to exit, closes the queues and
        finally terminates any worker that is still alive.
        """
        with self._lock:
            if self._shutdown:
                return
            self._shutdown = True
            try:
                for worker_rank in range(len(self._task_queues)):
                    self._enqueue_worker_command(
                        worker_rank,
                        SharedMpCommand.STOP,
                        None,
                    )
                for worker in self._workers:
                    worker.join(timeout=MP_STATUS_CHECK_INTERVAL)
                for queue_ in self._task_queues:
                    # Do not let unflushed queue data block process exit.
                    queue_.cancel_join_thread()
                    queue_.close()
                if self._event_queue is not None:
                    self._event_queue.cancel_join_thread()
                    self._event_queue.close()
            finally:
                # Last resort, also reached when the graceful path raised.
                for worker in self._workers:
                    if worker.is_alive():
                        worker.terminate()
({ablp_loader_2._producer_id_list})" + f"Rank {rank} / {world_size} ablp_loader_2 backends/channels: " + f"({ablp_loader_2._backend_id_list}, {ablp_loader_2._channel_id_list})" ) neighbor_loader_1 = _build_neighbor_loader( remote_dist_dataset, random_negative_input ) logger.info( - f"Rank {rank} / {world_size} neighbor_loader_1 producers: ({neighbor_loader_1._producer_id_list})" + f"Rank {rank} / {world_size} neighbor_loader_1 backends/channels: " + f"({neighbor_loader_1._backend_id_list}, {neighbor_loader_1._channel_id_list})" ) neighbor_loader_2 = _build_neighbor_loader( remote_dist_dataset, random_negative_input ) logger.info( - f"Rank {rank} / {world_size} neighbor_loader_2 producers: ({neighbor_loader_2._producer_id_list})" + f"Rank {rank} / {world_size} neighbor_loader_2 backends/channels: " + f"({neighbor_loader_2._backend_id_list}, {neighbor_loader_2._channel_id_list})" ) + gathered_ablp_loader_1_backends = [None] * world_size + torch.distributed.all_gather_object( + gathered_ablp_loader_1_backends, tuple(ablp_loader_1._backend_id_list) + ) + assert all( + backend_ids == gathered_ablp_loader_1_backends[0] + for backend_ids in gathered_ablp_loader_1_backends + ), "All ranks should share the same backend ids for one logical loader." + gathered_neighbor_loader_1_backends = [None] * world_size + torch.distributed.all_gather_object( + gathered_neighbor_loader_1_backends, tuple(neighbor_loader_1._backend_id_list) + ) + assert all( + backend_ids == gathered_neighbor_loader_1_backends[0] + for backend_ids in gathered_neighbor_loader_1_backends + ), "All ranks should share the same backend ids for one logical loader." + assert ( + ablp_loader_1._backend_id_list != ablp_loader_2._backend_id_list + ), "Concurrent ABLP loaders must use distinct backends." + assert ( + neighbor_loader_1._backend_id_list != neighbor_loader_2._backend_id_list + ), "Concurrent neighbor loaders must use distinct backends." 
+ assert ( + ablp_loader_1._backend_id_list != neighbor_loader_1._backend_id_list + ), "ABLP and neighbor loaders must not share a backend." logger.info( f"Rank {rank} / {world_size} phase 1: loading batches from 4 parallel loaders" ) @@ -457,9 +486,8 @@ def _run_compute_multiple_loaders_test( local_expected=local_expected_negative_seeds, ) - # Shut down phase 1 loaders to free server-side producers and RPC resources - # before creating new loaders. This mirrors GLT's DistLoader.shutdown() which - # calls DistServer.destroy_sampling_producer for each remote producer. + # Shut down phase 1 loaders to free server-side channels and backend resources + # before creating new loaders. ablp_loader_1.shutdown() ablp_loader_2.shutdown() neighbor_loader_1.shutdown() @@ -472,13 +500,15 @@ def _run_compute_multiple_loaders_test( # ------------------------------------------------------------------ ablp_loader_3 = _build_ablp_loader(remote_dist_dataset, ablp_result) logger.info( - f"Rank {rank} / {world_size} ablp_loader_3 producers: ({ablp_loader_3._producer_id_list})" + f"Rank {rank} / {world_size} ablp_loader_3 backends/channels: " + f"({ablp_loader_3._backend_id_list}, {ablp_loader_3._channel_id_list})" ) neighbor_loader_3 = _build_neighbor_loader( remote_dist_dataset, random_negative_input ) logger.info( - f"Rank {rank} / {world_size} neighbor_loader_3 producers: ({neighbor_loader_3._producer_id_list})" + f"Rank {rank} / {world_size} neighbor_loader_3 backends/channels: " + f"({neighbor_loader_3._backend_id_list}, {neighbor_loader_3._channel_id_list})" ) logger.info( f"Rank {rank} / {world_size} phase 2: loading batches from 2 sequential loaders" diff --git a/tests/unit/distributed/dist_sampling_producer_test.py b/tests/unit/distributed/dist_sampling_producer_test.py new file mode 100644 index 000000000..9d19768da --- /dev/null +++ b/tests/unit/distributed/dist_sampling_producer_test.py @@ -0,0 +1,214 @@ +import queue +from typing import cast +from unittest.mock import 
MagicMock, patch + +import torch +import torch.multiprocessing as mp +from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType + +from gigl.distributed.graph_store.shared_dist_sampling_producer import ( + EPOCH_DONE_EVENT, + ActiveEpochState, + SharedDistSamplingBackend, + SharedMpCommand, + StartEpochCmd, + _compute_num_batches, + _compute_worker_seeds_ranges, + _epoch_batch_indices, +) +from gigl.distributed.sampler_options import KHopNeighborSamplerOptions +from tests.test_assets.test_case import TestCase + + +def _make_sampling_config(*, shuffle: bool = False) -> SamplingConfig: + return SamplingConfig( + sampling_type=SamplingType.NODE, + num_neighbors=[2], + batch_size=2, + shuffle=shuffle, + drop_last=False, + with_edge=True, + collect_features=True, + with_neg=False, + with_weight=False, + edge_dir="out", + seed=1234, + ) + + +class _FakeProcess: + def __init__(self, *args, **kwargs) -> None: + self.daemon = False + + def start(self) -> None: + return None + + def join(self, timeout: float | None = None) -> None: + return None + + def is_alive(self) -> bool: + return False + + def terminate(self) -> None: + return None + + +class _FakeMpContext: + def Barrier(self, parties: int): + return MagicMock(wait=MagicMock()) + + def Queue(self, maxsize: int = 0): + return MagicMock() + + def Process(self, *args, **kwargs): + return _FakeProcess(*args, **kwargs) + + +class DistSamplingProducerTest(TestCase): + def test_compute_num_batches(self) -> None: + self.assertEqual(_compute_num_batches(0, 2, False), 0) + self.assertEqual(_compute_num_batches(1, 2, True), 0) + self.assertEqual(_compute_num_batches(1, 2, False), 1) + self.assertEqual(_compute_num_batches(5, 2, False), 3) + self.assertEqual(_compute_num_batches(5, 2, True), 2) + + def test_epoch_batch_indices(self) -> None: + active_state = ActiveEpochState( + channel_id=0, + epoch=0, + input_len=6, + batch_size=2, + drop_last=False, + seeds_index=torch.arange(6), + total_batches=3, + 
submitted_batches=1, + cancelled=False, + ) + result = _epoch_batch_indices(active_state) + assert result is not None + self.assert_tensor_equality(result, torch.tensor([2, 3])) + + def test_compute_worker_seeds_ranges(self) -> None: + self.assertEqual( + _compute_worker_seeds_ranges(input_len=7, batch_size=2, num_workers=3), + [(0, 2), (2, 4), (4, 7)], + ) + + @patch("gigl.distributed.graph_store.shared_dist_sampling_producer.get_context") + @patch("gigl.distributed.graph_store.shared_dist_sampling_producer.mp.get_context") + @patch( + "gigl.distributed.graph_store.shared_dist_sampling_producer._prepare_degree_tensors" + ) + def test_init_backend_prepares_worker_options( + self, + mock_prepare_degree_tensors: MagicMock, + mock_get_mp_context: MagicMock, + mock_get_context: MagicMock, + ) -> None: + worker_options = MagicMock() + worker_options.num_workers = 2 + worker_options.worker_concurrency = 1 + mock_get_context.return_value = MagicMock( + is_server=MagicMock(return_value=True) + ) + mock_get_mp_context.return_value = _FakeMpContext() + backend = SharedDistSamplingBackend( + data=MagicMock(), + worker_options=worker_options, + sampling_config=_make_sampling_config(), + sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2]), + ) + + backend.init_backend() + + worker_options._assign_worker_devices.assert_called_once() + worker_options._set_worker_ranks.assert_called_once_with( + mock_get_context.return_value + ) + self.assertEqual(len(backend._task_queues), 2) + self.assertEqual(len(backend._workers), 2) + self.assertTrue(backend._initialized) + mock_prepare_degree_tensors.assert_called_once() + + def test_start_new_epoch_sampling_shuffle_refreshes_per_epoch(self) -> None: + worker_options = MagicMock() + worker_options.num_workers = 2 + worker_options.worker_concurrency = 1 + backend = SharedDistSamplingBackend( + data=MagicMock(), + worker_options=worker_options, + sampling_config=_make_sampling_config(shuffle=True), + 
sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2]), + ) + backend._initialized = True + recorded: list[tuple[int, SharedMpCommand, object]] = [] + backend._enqueue_worker_command = lambda worker_rank, command, payload: recorded.append( # type: ignore[method-assign] + (worker_rank, command, payload) + ) + + channel = MagicMock() + input_tensor = torch.arange(6, dtype=torch.long) + backend.register_input( + channel_id=1, + worker_key="loader_a_compute_rank_0", + sampler_input=NodeSamplerInput(node=input_tensor.clone()), + sampling_config=_make_sampling_config(shuffle=True), + channel=channel, + ) + backend.register_input( + channel_id=2, + worker_key="loader_b_compute_rank_0", + sampler_input=NodeSamplerInput(node=input_tensor.clone()), + sampling_config=_make_sampling_config(shuffle=True), + channel=channel, + ) + + def _collect_epoch_indices(channel_id: int, epoch: int) -> torch.Tensor: + recorded.clear() + backend.start_new_epoch_sampling(channel_id, epoch) + worker_payloads = { + worker_rank: cast(StartEpochCmd, payload).seeds_index + for worker_rank, command, payload in recorded + if command == SharedMpCommand.START_EPOCH + } + assert all( + seed_index is not None for seed_index in worker_payloads.values() + ) + return torch.cat( + [ + cast(torch.Tensor, worker_payloads[worker_rank]) + for worker_rank in sorted(worker_payloads) + ] + ) + + channel_1_epoch_0 = _collect_epoch_indices(1, 0) + channel_2_epoch_0 = _collect_epoch_indices(2, 0) + channel_1_epoch_1 = _collect_epoch_indices(1, 1) + + self.assert_tensor_equality(channel_1_epoch_0, channel_2_epoch_0) + self.assertNotEqual( + channel_1_epoch_0.tolist(), + channel_1_epoch_1.tolist(), + ) + + def test_describe_channel_reports_completed_workers(self) -> None: + worker_options = MagicMock() + worker_options.num_workers = 2 + worker_options.worker_concurrency = 1 + backend = SharedDistSamplingBackend( + data=MagicMock(), + worker_options=worker_options, + sampling_config=_make_sampling_config(), + 
sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2]), + ) + backend._initialized = True + backend._event_queue = cast(mp.Queue, queue.Queue()) + backend._channel_input_sizes[1] = [4, 2] + backend._channel_epoch[1] = 3 + cast(queue.Queue, backend._event_queue).put((EPOCH_DONE_EVENT, 1, 3, 0)) + + description = backend.describe_channel(1) + + self.assertEqual(description["epoch"], 3) + self.assertEqual(description["input_sizes"], [4, 2]) + self.assertEqual(description["completed_workers"], 1) diff --git a/tests/unit/distributed/dist_server_test.py b/tests/unit/distributed/dist_server_test.py index 641d933f9..bf60bc396 100644 --- a/tests/unit/distributed/dist_server_test.py +++ b/tests/unit/distributed/dist_server_test.py @@ -1,10 +1,15 @@ +from unittest.mock import MagicMock, patch + import torch from absl.testing import absltest +from graphlearn_torch.sampler import SamplingConfig, SamplingType from gigl.distributed.graph_store import dist_server from gigl.distributed.graph_store.messages import ( FetchABLPInputRequest, FetchNodesRequest, + InitSamplingBackendRequest, + RegisterBackendRequest, ) from gigl.src.common.types.graph_data import Relation from tests.test_assets.distributed.test_dataset import ( @@ -21,6 +26,22 @@ from tests.test_assets.test_case import TestCase +def _make_sampling_config() -> SamplingConfig: + return SamplingConfig( + sampling_type=SamplingType.NODE, + num_neighbors=[2], + batch_size=2, + shuffle=False, + drop_last=False, + with_edge=True, + collect_features=True, + with_neg=False, + with_weight=False, + edge_dir="out", + seed=None, + ) + + class TestRemoteDataset(TestCase): def setUp(self) -> None: """Reset the global dataset before each test.""" @@ -535,5 +556,181 @@ def test_get_ablp_input_without_negative_labels(self) -> None: self.assertIsNone(neg_labels) +class TestDistServerSampling(TestCase): + def setUp(self) -> None: + dist_server._dist_server = None + self.dataset = create_homogeneous_dataset( + 
edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, + ) + self.server = dist_server.DistServer(self.dataset) + self.worker_options = MagicMock() + self.worker_options.buffer_capacity = 2 + self.worker_options.buffer_size = "1MB" + self.sampling_config = _make_sampling_config() + self.sampler_options = MagicMock() + + def tearDown(self) -> None: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + @patch("gigl.distributed.graph_store.dist_server.SharedDistSamplingBackend") + def test_init_sampling_backend_idempotent( + self, mock_backend_cls: MagicMock + ) -> None: + runtime = mock_backend_cls.return_value + + backend_id_1 = self.server.init_sampling_backend( + InitSamplingBackendRequest( + backend_key="neighbor_loader_0", + worker_options=self.worker_options, + sampler_options=self.sampler_options, + sampling_config=self.sampling_config, + ) + ) + backend_id_2 = self.server.init_sampling_backend( + InitSamplingBackendRequest( + backend_key="neighbor_loader_0", + worker_options=self.worker_options, + sampler_options=self.sampler_options, + sampling_config=self.sampling_config, + ) + ) + + self.assertEqual(backend_id_1, backend_id_2) + mock_backend_cls.assert_called_once() + runtime.init_backend.assert_called_once() + + @patch("gigl.distributed.graph_store.dist_server.ShmChannel") + @patch("gigl.distributed.graph_store.dist_server.SharedDistSamplingBackend") + def test_register_creates_channel( + self, + mock_backend_cls: MagicMock, + mock_channel_cls: MagicMock, + ) -> None: + runtime = mock_backend_cls.return_value + backend_id = self.server.init_sampling_backend( + InitSamplingBackendRequest( + backend_key="neighbor_loader_0", + worker_options=self.worker_options, + sampler_options=self.sampler_options, + sampling_config=self.sampling_config, + ) + ) + + channel_id = self.server.register_sampling_input( + RegisterBackendRequest( + backend_id=backend_id, + worker_key="neighbor_loader_0_compute_rank_0", + sampler_input=MagicMock(), + 
sampling_config=self.sampling_config, + buffer_capacity=2, + buffer_size="1MB", + ) + ) + + self.assertEqual(channel_id, 0) + runtime.register_input.assert_called_once() + mock_channel_cls.assert_called_once_with(2, "1MB") + + @patch("gigl.distributed.graph_store.dist_server.ShmChannel") + @patch("gigl.distributed.graph_store.dist_server.SharedDistSamplingBackend") + def test_destroy_last_channel_shuts_down_backend( + self, + mock_backend_cls: MagicMock, + _mock_channel_cls: MagicMock, + ) -> None: + runtime = mock_backend_cls.return_value + backend_id = self.server.init_sampling_backend( + InitSamplingBackendRequest( + backend_key="neighbor_loader_0", + worker_options=self.worker_options, + sampler_options=self.sampler_options, + sampling_config=self.sampling_config, + ) + ) + channel_id = self.server.register_sampling_input( + RegisterBackendRequest( + backend_id=backend_id, + worker_key="neighbor_loader_0_compute_rank_0", + sampler_input=MagicMock(), + sampling_config=self.sampling_config, + buffer_capacity=2, + buffer_size="1MB", + ) + ) + + self.server.destroy_sampling_input(channel_id) + + runtime.unregister_input.assert_called_once_with(channel_id) + runtime.shutdown.assert_called_once() + self.assertEqual(self.server._backend_state_by_id, {}) + + def test_destroy_unknown_channel_noop(self) -> None: + self.server.destroy_sampling_input(999) + self.assertEqual(self.server._backend_state_by_id, {}) + + @patch("gigl.distributed.graph_store.dist_server.ShmChannel") + @patch("gigl.distributed.graph_store.dist_server.SharedDistSamplingBackend") + def test_start_epoch_idempotent( + self, + mock_backend_cls: MagicMock, + _mock_channel_cls: MagicMock, + ) -> None: + runtime = mock_backend_cls.return_value + backend_id = self.server.init_sampling_backend( + InitSamplingBackendRequest( + backend_key="neighbor_loader_0", + worker_options=self.worker_options, + sampler_options=self.sampler_options, + sampling_config=self.sampling_config, + ) + ) + channel_id = 
self.server.register_sampling_input( + RegisterBackendRequest( + backend_id=backend_id, + worker_key="neighbor_loader_0_compute_rank_0", + sampler_input=MagicMock(), + sampling_config=self.sampling_config, + buffer_capacity=2, + buffer_size="1MB", + ) + ) + + self.server.start_new_epoch_sampling(channel_id, 0) + self.server.start_new_epoch_sampling(channel_id, 0) + + runtime.start_new_epoch_sampling.assert_called_once_with(channel_id, 0) + + def test_shutdown_cleans_all_backends(self) -> None: + runtime_1 = MagicMock() + runtime_2 = MagicMock() + self.server._backend_state_by_id = { + 0: dist_server.SamplingBackendState( + backend_id=0, + backend_key="neighbor_loader_0", + runtime=runtime_1, + ), + 1: dist_server.SamplingBackendState( + backend_id=1, + backend_key="neighbor_loader_1", + runtime=runtime_2, + ), + } + self.server._backend_key_to_id = { + "neighbor_loader_0": 0, + "neighbor_loader_1": 1, + } + + self.server.shutdown() + + runtime_1.shutdown.assert_called_once() + runtime_2.shutdown.assert_called_once() + self.assertEqual(self.server._backend_state_by_id, {}) + + def test_create_sampling_producer_removed(self) -> None: + self.assertFalse(hasattr(dist_server.DistServer, "create_sampling_producer")) + self.assertFalse(hasattr(dist_server.DistServer, "destroy_sampling_producer")) + + if __name__ == "__main__": absltest.main()