_reduce silently returns unreduced tensor when input is non-contiguous #5337

@Pearblossom-M

Describe the bug

@NVIDIA/mcore-oncall

In megatron/core/tensor_parallel/mappings.py, function _reduce silently returns the original unreduced tensor when the input is non-contiguous.

Current code:

def _reduce(input_, group):
    """All-reduce the input tensor across model parallel group."""
    assert group is not None, "group should not be None"

    # Bypass the function if we are using only 1 GPU.
    if group.size() == 1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_.contiguous(), group=group)

    return input_

tensor.contiguous() returns self if the tensor is already contiguous (same storage, no copy), but returns a new tensor with independent storage if it is not. When input_ is non-contiguous:

  1. input_.contiguous() creates temp — a new contiguous tensor
  2. all_reduce(temp) all-reduces temp in-place ✓
  3. return input_ returns the original, never-reduced tensor

The presence of .contiguous() implies the function is designed to handle non-contiguous inputs gracefully, but the actual behavior is to silently return unreduced data in that case.
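
The aliasing behavior described above can be checked in a single process, without any process group. The following is a minimal sketch, assuming only standard PyTorch tensor methods; the helper name `contiguous_aliases_input` is hypothetical, and the in-place `add_` merely stands in for the in-place write that `all_reduce` would perform.

```python
import torch


def contiguous_aliases_input(t):
    """Return True if t.contiguous() shares memory with t (no copy was made)."""
    return t.contiguous().data_ptr() == t.data_ptr()


base = torch.arange(8, dtype=torch.float32).reshape(2, 4)
view = base.t()  # transposed view: same storage, non-contiguous strides

# Contiguous input: .contiguous() hands back the same memory.
aliases_contig = contiguous_aliases_input(base)

# Non-contiguous input: .contiguous() allocates an independent copy.
aliases_noncontig = contiguous_aliases_input(view)

# An in-place update on that copy (a stand-in for all_reduce) never
# reaches the original tensor.
snapshot = view.clone()
view.contiguous().add_(100.0)
original_unchanged = torch.equal(view, snapshot)

print(aliases_contig, aliases_noncontig, original_unchanged)
```

If the description above holds, the in-place update is lost only in the non-contiguous case, which is exactly the situation in which `_reduce` returns unreduced data.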

Steps/Code to reproduce bug

While this bug is not triggered by any current call path (see Additional context), it can be demonstrated with the following test:

import torch
import torch.distributed as dist

def _reduce(input_, group):
    """All-reduce the input tensor across model parallel group."""
    assert group is not None, "group should not be None"

    # Bypass the function if we are using only 1 GPU.
    if group.size() == 1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_.contiguous(), group=group)

    return input_

def main():
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    group = dist.group.WORLD

    if rank == 0:
        input_tensor = torch.tensor([[1, 2, 3, 4],
                                     [5, 6, 7, 8]], dtype=torch.float32)
    else:
        input_tensor = torch.tensor([[10, 20, 30, 40],
                                     [50, 60, 70, 80]], dtype=torch.float32)

    input_noncontig = input_tensor.t() # transpose to create a non-contiguous tensor
    assert not input_noncontig.is_contiguous()

    reference = torch.tensor([[11, 55],
                              [22, 66],
                              [33, 77],
                              [44, 88]], dtype=torch.float32)

    result = _reduce(input_noncontig, group)
    dist.barrier()

    if rank == 0:
        correct_match = torch.allclose(result, reference)
        print(f"correct_match = {correct_match}")
        print("result:")
        print(result)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

To run the code:

torchrun --nproc_per_node=2 reproduce_example.py

Running the code above prints correct_match = False: the result tensor is still unreduced (equal to the rank's original local input) instead of matching the reference tensor.

Expected behavior

_reduce should return the reduced tensor for all inputs, regardless of contiguity.

  • Suggested fix:

    def _reduce(input_, group):
        """All-reduce the input tensor across model parallel group."""
        assert group is not None, "group should not be None"
    
        # Bypass the function if we are using only 1 GPU.
        if group.size() == 1:
            return input_
    
        # All-reduce.
        input_ = input_.contiguous()  # reassign so all_reduce target == returned tensor
        torch.distributed.all_reduce(input_, group=group)
    
        return input_

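    This is a minimal change: reassign input_ before the all-reduce so that the tensor being reduced is the tensor being returned. It makes the function correct for all inputs and removes the false sense of safety created by the current .contiguous() placement.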

  • Alternatively, if the contract is that callers must provide contiguous input, the .contiguous() call should be removed entirely and replaced with an assertion:

    def _reduce(input_, group):
        """All-reduce the input tensor across model parallel group."""
        assert group is not None, "group should not be None"
    
        # Bypass the function if we are using only 1 GPU.
        if group.size() == 1:
            return input_
    
        # All-reduce.
        assert input_.is_contiguous(), "_reduce requires contiguous input"
        torch.distributed.all_reduce(input_, group=group)
    
        return input_
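
The difference between the current code and the two proposals above can be sketched in a single process by swapping the collective for a stand-in. Everything here is an illustration, not Megatron code: `fake_all_reduce_` is a hypothetical replacement for `torch.distributed.all_reduce` that doubles the tensor in place (as if a second rank held identical data), and the three `reduce_*` functions mirror only the `.contiguous()` placement of each variant.

```python
import torch


def fake_all_reduce_(tensor):
    """Hypothetical stand-in for all_reduce: doubles the tensor in place."""
    tensor.mul_(2)


def reduce_buggy(input_):
    # Current placement: the reduced buffer may be a discarded temporary.
    fake_all_reduce_(input_.contiguous())
    return input_


def reduce_fixed(input_):
    # Suggested fix: the tensor that is reduced is the tensor that is returned.
    input_ = input_.contiguous()
    fake_all_reduce_(input_)
    return input_


def reduce_strict(input_):
    # Alternative: reject non-contiguous input instead of copying it.
    assert input_.is_contiguous(), "_reduce requires contiguous input"
    fake_all_reduce_(input_)
    return input_


def make_noncontig():
    """Build a fresh non-contiguous 4x2 tensor (transpose of a 2x4 tensor)."""
    return torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]).t()


expected = make_noncontig() * 2

buggy_ok = torch.equal(reduce_buggy(make_noncontig()), expected)
fixed_ok = torch.equal(reduce_fixed(make_noncontig()), expected)

try:
    reduce_strict(make_noncontig())
    strict_raised = False
except AssertionError:
    strict_raised = True

print(buggy_ok, fixed_ok, strict_raised)
```

One design point worth keeping in mind: with the reassignment fix, a non-contiguous input produces a new tensor, so the caller's original tensor is still not modified in place. Callers need to use the return value, which the callers named in this issue appear to do.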

Additional context

Why it hasn't caused problems in practice:

All current callers pass contiguous tensors: _ReduceFromModelParallelRegion.forward receives the output of a linear layer, and _CopyToModelParallelRegion.backward receives a gradient tensor from autograd. Both are contiguous in these call paths, so .contiguous() is always a no-op and the bug is never triggered.

I have already prepared a fix and will submit a PR shortly.
