diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 1894408ab..e01b00d2b 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -57,7 +57,7 @@ reverse_edge_type, select_label_edge_types, ) -from gigl.utils.data_splitters import get_labels_for_anchor_nodes +from gigl.utils.data_splitters import PADDING_NODE, get_labels_for_anchor_nodes from gigl.utils.sampling import ABLPInputNodes logger = Logger() @@ -755,6 +755,58 @@ def _setup_for_graph_store( ), ) + def _compute_label_matches( + self, + local_to_global: torch.Tensor, + label_tensor: torch.Tensor, + num_anchors: int, + ) -> dict[int, torch.Tensor]: + """ + Compute label matches using fully vectorized operations. + + Args: + local_to_global: [N] tensor mapping local node idx to global node ID + label_tensor: [A, M] tensor of label global node IDs (padded with PADDING_NODE) + num_anchors: Number of anchor nodes (A) + + Returns: + dict[int, torch.Tensor]: Mapping from anchor_idx to tensor of matching local node indices + """ + # Vectorized broadcast comparison: [A, N, M] + # local_to_global: [N] -> [1, N, 1] + # label_tensor: [A, M] -> [A, 1, M] + matches = local_to_global.view(1, -1, 1) == label_tensor.unsqueeze(1) + + # Mask out padding matches (PADDING_NODE should not match any real node) + padding_mask = (label_tensor == PADDING_NODE).unsqueeze(1) # [A, 1, M] + matches = matches & ~padding_mask + + # Reduce: any match across labels dimension -> [A, N] + any_match = matches.any(dim=2) + + # Single nonzero call on full tensor to get all (anchor_idx, node_idx) pairs + # Returns tuple of (anchor_indices, node_indices) tensors + match_coords = torch.nonzero(any_match, as_tuple=True) + anchor_indices = match_coords[0] + node_indices = match_coords[1] + + # Count matches per anchor using bincount for efficient splitting + if anchor_indices.numel() > 0: + counts = torch.bincount(anchor_indices, minlength=num_anchors) + else: + counts = torch.zeros(num_anchors, dtype=torch.long, device=any_match.device) + + # Transfer node_indices to target device ONCE before splitting + # This avoids num_anchors small device transfers which have significant overhead + node_indices_on_device = node_indices.to(self.to_device) + + # Split on device - torch.split returns a tuple of views (no copy) + split_sizes = counts.tolist() + split_indices = torch.split(node_indices_on_device, split_sizes) + + # Build output dict using dict comprehension with enumerate (faster than loop) + return dict(enumerate(split_indices)) + def _set_labels( self, data: Union[Data, HeteroData], @@ -765,6 +817,10 @@ def _set_labels( Sets the labels and relevant fields in the torch_geometric Data object, converting the global node ids for labels to their local index. Removes inserted supervision edge type from the data variables, since this is an implementation detail and should not be exposed in the final HeteroData/Data object. + + This method uses fully vectorized operations to efficiently process all anchor nodes in a batch simultaneously, + including a single nonzero call for all anchors followed by efficient splitting. + Args: data (Union[Data, HeteroData]): Graph to provide labels for positive_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): Dict[positive label edge type, label ID tensor], @@ -784,48 +840,42 @@ def _set_labels( node_type_to_local_node_to_global_node[ DEFAULT_HOMOGENEOUS_NODE_TYPE ] = data.node + output_positive_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict( dict ) output_negative_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict( dict ) + # We always have supervision edge types of the form (anchor_node_type, to, supervision_node_type) # So we can index into the edge type accordingly. edge_index = 2 + + # Process positive labels with fully vectorized operations for edge_type, label_tensor in positive_labels_by_label_edge_type.items(): - for local_anchor_node_id in range(label_tensor.size(0)): - positive_mask = ( - node_type_to_local_node_to_global_node[ - edge_type[edge_index] - ].unsqueeze(1) - == label_tensor[local_anchor_node_id] - ) # shape [N, P], where N is the number of nodes and P is the number of positive labels for the current anchor node - - # Gets the indexes of the items in local_node_to_global_node which match any of the positive labels for the current anchor node - output_positive_labels[ - label_edge_type_to_message_passing_edge_type(edge_type) - ][local_anchor_node_id] = torch.nonzero(positive_mask)[:, 0].to( - self.to_device - ) - # Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the positive labels for the current anchor node + local_to_global = node_type_to_local_node_to_global_node[ + edge_type[edge_index] + ] + num_anchors = label_tensor.size(0) + + mp_edge_type = label_edge_type_to_message_passing_edge_type(edge_type) + output_positive_labels[mp_edge_type] = self._compute_label_matches( + local_to_global, label_tensor, num_anchors + ) + # Process negative labels with fully vectorized operations for edge_type, label_tensor in negative_labels_by_label_edge_type.items(): - for local_anchor_node_id in range(label_tensor.size(0)): - negative_mask = ( - node_type_to_local_node_to_global_node[ - edge_type[edge_index] - ].unsqueeze(1) - == label_tensor[local_anchor_node_id] - ) # shape [N, M], where N is the number of nodes and M is the number of negative labels for the current anchor node - - # Gets the indexes of the items in local_node_to_global_node which match any of the negative labels for the current anchor node - output_negative_labels[ - label_edge_type_to_message_passing_edge_type(edge_type) - ][local_anchor_node_id] = torch.nonzero(negative_mask)[:, 0].to( - self.to_device - ) - # Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the negative labels for the current anchor node + local_to_global = node_type_to_local_node_to_global_node[ + edge_type[edge_index] + ] + num_anchors = label_tensor.size(0) + + mp_edge_type = label_edge_type_to_message_passing_edge_type(edge_type) + output_negative_labels[mp_edge_type] = self._compute_label_matches( + local_to_global, label_tensor, num_anchors + ) + if not output_positive_labels: raise ValueError("No positive labels were found in the data!") elif len(output_positive_labels) == 1: @@ -837,6 +887,7 @@ def _set_labels( data.y_negative = next(iter(output_negative_labels.values())) elif len(output_negative_labels) > 0: data.y_negative = output_negative_labels + return data def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: