diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 0147d6929..cecb52a51 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -575,11 +575,26 @@ def register_sampling_input(self, opts: RegisterBackendRequest) -> int: The unique channel ID for this input. """ request_start_time = time.monotonic() + sampler_input = opts.sampler_input + + if isinstance(sampler_input, RemoteSamplerInput): + sampler_input = sampler_input.to_local_sampler_input(dataset=self.dataset) + with self._lock: backend_state = self._backend_state_by_backend_id[opts.backend_id] channel_id = self._next_channel_id self._next_channel_id += 1 - channel = ShmChannel(opts.buffer_capacity, opts.buffer_size) + # If the sampler input is empty, we create a channel with 1 slot and 1MB size. + # We do this to save on memory usage for empty inputs. + # NOTE: We must keep creating these channels as we need to "register input" for + # all nodes on the storage cluster, as the `NeighborSampler` is responsible for + # serving incoming sampling requests as well as sending them out. + # TODO(kmonte): Look into either supporting truly empty channels or having a shared + # DistSampler. + if len(sampler_input) == 0: + channel = ShmChannel(1, "1MB") + else: + channel = ShmChannel(opts.buffer_capacity, opts.buffer_size) channel_state = ChannelState( backend_id=opts.backend_id, worker_key=opts.worker_key, @@ -592,10 +607,6 @@ def register_sampling_input(self, opts: RegisterBackendRequest) -> int: # value that could be mutated by concurrent register/destroy. 
active_channels_at_register = len(backend_state.active_channels) - sampler_input = opts.sampler_input - if isinstance(sampler_input, RemoteSamplerInput): - sampler_input = sampler_input.to_local_sampler_input(dataset=self.dataset) - try: with backend_state.lock: backend_state.runtime.register_input( diff --git a/tests/unit/distributed/dist_server_test.py b/tests/unit/distributed/dist_server_test.py index 48eb95fa4..2703391f7 100644 --- a/tests/unit/distributed/dist_server_test.py +++ b/tests/unit/distributed/dist_server_test.py @@ -3,7 +3,7 @@ import torch from absl.testing import absltest -from graphlearn_torch.sampler import SamplingConfig, SamplingType +from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType from gigl.distributed.graph_store import dist_server from gigl.distributed.graph_store.messages import ( @@ -676,16 +676,48 @@ def test_register_creates_channel( RegisterBackendRequest( backend_id=backend_id, worker_key="neighbor_loader_0_compute_rank_0", - sampler_input=MagicMock(), + sampler_input=NodeSamplerInput(torch.arange(10)), sampling_config=self.sampling_config, buffer_capacity=2, - buffer_size="1MB", + buffer_size="2MB", + ) + ) + + self.assertEqual(channel_id, 0) + runtime.register_input.assert_called_once() + mock_channel_cls.assert_called_once_with(2, "2MB") + + @patch("gigl.distributed.graph_store.dist_server.ShmChannel") + @patch("gigl.distributed.graph_store.dist_server.SharedDistSamplingBackend") + def test_register_empty_channel( + self, + mock_backend_cls: MagicMock, + mock_channel_cls: MagicMock, + ) -> None: + runtime = mock_backend_cls.return_value + backend_id = self.server.init_sampling_backend( + InitSamplingBackendRequest( + backend_key="neighbor_loader_0", + worker_options=self.worker_options, + sampler_options=self.sampler_options, + sampling_config=self.sampling_config, + ) + ) + + channel_id = self.server.register_sampling_input( + RegisterBackendRequest( + backend_id=backend_id, + 
worker_key="neighbor_loader_0_compute_rank_0", + sampler_input=NodeSamplerInput(torch.tensor([])), + sampling_config=self.sampling_config, + buffer_capacity=2, # Should be overridden to 1 for empty channels + buffer_size="2MB", # Should be overridden to 1MB for empty channels ) ) self.assertEqual(channel_id, 0) runtime.register_input.assert_called_once() - mock_channel_cls.assert_called_once_with(2, "1MB") + mock_channel_cls.assert_called_once_with(1, "1MB") @patch("gigl.distributed.graph_store.dist_server.ShmChannel") @patch("gigl.distributed.graph_store.dist_server.SharedDistSamplingBackend")