Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions gigl/distributed/graph_store/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,11 +575,26 @@ def register_sampling_input(self, opts: RegisterBackendRequest) -> int:
The unique channel ID for this input.
"""
request_start_time = time.monotonic()
sampler_input = opts.sampler_input

if isinstance(sampler_input, RemoteSamplerInput):
sampler_input = sampler_input.to_local_sampler_input(dataset=self.dataset)

with self._lock:
backend_state = self._backend_state_by_backend_id[opts.backend_id]
channel_id = self._next_channel_id
self._next_channel_id += 1
channel = ShmChannel(opts.buffer_capacity, opts.buffer_size)
# If the sampler input is empty, we create a channel with 1 slot and 1MB size.
# We do this to save on memory usage for empty inputs.
# NOTE: We must keep creating these channels as we need to "register input" for
# all nodes on the storage cluster, as the `NeighborSampler` is responsible for
# serving incoming sampling requests as well as sending them out.
# TODO(kmonte): Look into either supporting truly empty channels or having a shared
# DistSampler.
if len(sampler_input) == 0:
channel = ShmChannel(1, "1MB")
else:
channel = ShmChannel(opts.buffer_capacity, opts.buffer_size)
channel_state = ChannelState(
backend_id=opts.backend_id,
worker_key=opts.worker_key,
Expand All @@ -592,10 +607,6 @@ def register_sampling_input(self, opts: RegisterBackendRequest) -> int:
# value that could be mutated by concurrent register/destroy.
active_channels_at_register = len(backend_state.active_channels)

sampler_input = opts.sampler_input
if isinstance(sampler_input, RemoteSamplerInput):
sampler_input = sampler_input.to_local_sampler_input(dataset=self.dataset)

try:
with backend_state.lock:
backend_state.runtime.register_input(
Expand Down
40 changes: 36 additions & 4 deletions tests/unit/distributed/dist_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
from absl.testing import absltest
from graphlearn_torch.sampler import SamplingConfig, SamplingType
from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType

from gigl.distributed.graph_store import dist_server
from gigl.distributed.graph_store.messages import (
Expand Down Expand Up @@ -676,16 +676,48 @@ def test_register_creates_channel(
RegisterBackendRequest(
backend_id=backend_id,
worker_key="neighbor_loader_0_compute_rank_0",
sampler_input=MagicMock(),
sampler_input=NodeSamplerInput(torch.arange(10)),
sampling_config=self.sampling_config,
buffer_capacity=2,
buffer_size="1MB",
buffer_size="2MB",
)
)

self.assertEqual(channel_id, 0)
runtime.register_input.assert_called_once()
mock_channel_cls.assert_called_once_with(2, "2MB")

@patch("gigl.distributed.graph_store.dist_server.ShmChannel")
@patch("gigl.distributed.graph_store.dist_server.SharedDistSamplingBackend")
def test_register_empty_channel(
self,
mock_backend_cls: MagicMock,
mock_channel_cls: MagicMock,
) -> None:
runtime = mock_backend_cls.return_value
backend_id = self.server.init_sampling_backend(
InitSamplingBackendRequest(
backend_key="neighbor_loader_0",
worker_options=self.worker_options,
sampler_options=self.sampler_options,
sampling_config=self.sampling_config,
)
)

channel_id = self.server.register_sampling_input(
RegisterBackendRequest(
backend_id=backend_id,
worker_key="neighbor_loader_0_compute_rank_0",
sampler_input=NodeSamplerInput(torch.tensor([])),
sampling_config=self.sampling_config,
buffer_capacity=2, # Should be overridden to 1 for empty channels
buffer_size="2MB", # Should be overridden to 1MB for empty channels
)
)

self.assertEqual(channel_id, 0)
runtime.register_input.assert_called_once()
mock_channel_cls.assert_called_once_with(2, "1MB")
mock_channel_cls.assert_called_once_with(1, "1MB")

@patch("gigl.distributed.graph_store.dist_server.ShmChannel")
@patch("gigl.distributed.graph_store.dist_server.SharedDistSamplingBackend")
Expand Down