
Conversation

@buptzyb
Contributor

@buptzyb buptzyb commented Dec 16, 2025

Description

Support cudagraph recomputation with two changes:

  1. Replace autograd.grad with autograd.backward during cudagraph capture (see the sketch below).
  2. Get the default RNG states in a graph-safe manner if the tracker states are also graph-safe.
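A minimal sketch of the first change, assuming the inputs are leaf tensors from a static input surface (the names below are illustrative, not the exact code in graph.py):

```python
import torch

def backward_for_capture(outputs, inputs, static_grad_outputs):
    # autograd.grad would return freshly allocated gradient tensors:
    #   grads = torch.autograd.grad(outputs, inputs, static_grad_outputs)
    # autograd.backward instead accumulates into each input's .grad buffer,
    # which the capture code can read back afterwards.
    torch.autograd.backward(outputs, grad_tensors=static_grad_outputs)
    # Inputs that did not participate in the graph keep .grad = None;
    # the warmup-phase handling in this PR covers that case.
    return tuple(inp.grad for inp in inputs)
```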

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Dec 16, 2025

Greptile Summary

This PR enables cudagraph recomputation support by making two key changes: replacing torch.autograd.grad with torch.autograd.backward in graph capturing, and conditionally using graph-safe RNG states based on tracker state capabilities.

  • Replaced torch.autograd.grad with torch.autograd.backward in all graph capture locations (graph.py:454-460, graph.py:638-643, graph.py:729-734)
  • Added _none_grad_context_wrapper to manage gradient accumulation by temporarily clearing and restoring input.grad values
  • Introduced is_graph_safe_rng_state() helper to detect if RNG states support graph-safe operations
  • Modified checkpoint functions to automatically determine and use graph-safe RNG state when tracker states are graph-safe
  • Added logic to handle None gradients for unused inputs during warmup phase

The changes align with PyTorch's cudagraph requirements where backward() is preferred over grad() for graph capture scenarios. The RNG state changes ensure compatibility with graph-safe random number generation when available.
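As a rough illustration of the gradient bookkeeping described above (the real _none_grad_context_wrapper in graph.py may differ in detail), the wrapper can be thought of as a context manager that stashes any pre-existing .grad values, lets backward() accumulate into clean buffers, and then puts the originals back:

```python
import contextlib
import torch

@contextlib.contextmanager
def none_grad_context(inputs):
    """Clear .grad on every input tensor for the duration of the context,
    then restore the original gradient tensors on exit."""
    saved = [inp.grad for inp in inputs]
    for inp in inputs:
        inp.grad = None
    try:
        yield
    finally:
        for inp, grad in zip(inputs, saved):
            inp.grad = grad
```

Inside the context, torch.autograd.backward() runs and grad_inputs are collected from each input.grad before the originals are restored; an unused input simply keeps .grad as None, which is what the warmup-phase None-gradient handling covers.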

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk, pending verification of gradient restoration behavior
  • The implementation follows PyTorch best practices for cudagraph support. The main changes are well-scoped: replacing autograd.grad with autograd.backward and adding graph-safe RNG state detection. The logic for handling gradients appears correct, though the gradient restoration in _none_grad_context_wrapper could potentially benefit from cloning to avoid aliasing issues in edge cases. The RNG state changes are conservative and only use graph-safe mode when detected as available.
  • Pay attention to transformer_engine/pytorch/graph.py for the gradient accumulation wrapper behavior

Important Files Changed

  • transformer_engine/pytorch/graph.py: Replaced autograd.grad with autograd.backward and added a context wrapper to manage gradient accumulation during cudagraph capture
  • transformer_engine/pytorch/distributed.py: Added graph-safe RNG state detection and conditional usage throughout the checkpoint functions
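The distributed.py entry above only names the helper; as a hypothetical sketch (both the state-type check and the graphsafe_get_state call are assumptions on my part and require a recent PyTorch), detection and retrieval might look like:

```python
import torch

def is_graph_safe_rng_state(state):
    # Assumption: a graph-safe state is kept as a torch.Generator handle,
    # while the classic torch.cuda.get_rng_state() path stores a cloned
    # ByteTensor snapshot that cannot be replayed inside a CUDA graph.
    return isinstance(state, torch.Generator)

def get_cuda_rng_state(graph_safe: bool):
    if graph_safe:
        gen = torch.cuda.default_generators[torch.cuda.current_device()]
        return gen.graphsafe_get_state()  # capture/replay-friendly handle
    return torch.cuda.get_rng_state()     # plain state tensor
```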

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant make_graphed_callables
    participant _CheckpointFunction
    participant _none_grad_context_wrapper
    participant RNGTracker
    participant CUDAGraph
    
    User->>make_graphed_callables: Call with forward function
    
    Note over make_graphed_callables: Warmup Phase
    make_graphed_callables->>RNGTracker: get_states()
    RNGTracker-->>make_graphed_callables: tracker_states
    make_graphed_callables->>make_graphed_callables: Check if graph_safe via is_graph_safe_rng_state()
    make_graphed_callables->>make_graphed_callables: Get CUDA RNG state (graph_safe=True/False)
    make_graphed_callables->>make_graphed_callables: Execute forward pass
    
    alt Training Mode
        make_graphed_callables->>_none_grad_context_wrapper: Enter with inputs
        Note over _none_grad_context_wrapper: Save original grads<br/>Set input.grad = None
        make_graphed_callables->>make_graphed_callables: torch.autograd.backward()
        Note over make_graphed_callables: Grads accumulate in input.grad
        make_graphed_callables->>make_graphed_callables: Collect grad_inputs from input.grad
        _none_grad_context_wrapper->>_none_grad_context_wrapper: Restore original grads
    end
    
    Note over make_graphed_callables: Capture Phase
    make_graphed_callables->>CUDAGraph: Capture forward graph
    CUDAGraph-->>make_graphed_callables: fwd_graph
    
    alt Training Mode
        make_graphed_callables->>CUDAGraph: Capture backward graph
        make_graphed_callables->>_none_grad_context_wrapper: Enter with inputs
        make_graphed_callables->>make_graphed_callables: torch.autograd.backward()
        make_graphed_callables->>make_graphed_callables: Collect grad_inputs from input.grad
        _none_grad_context_wrapper->>_none_grad_context_wrapper: Restore original grads
        CUDAGraph-->>make_graphed_callables: bwd_graph
    end
    
    Note over _CheckpointFunction: During Recomputation
    _CheckpointFunction->>RNGTracker: get_states()
    RNGTracker-->>_CheckpointFunction: tracker_states
    _CheckpointFunction->>_CheckpointFunction: Determine graph_safe_rng_state
    _CheckpointFunction->>_CheckpointFunction: Get/Set RNG states (graph_safe=True/False)
    _CheckpointFunction->>_CheckpointFunction: Recompute forward
    _CheckpointFunction->>_CheckpointFunction: Restore RNG states (graph_safe=True/False)
```

Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments


Signed-off-by: Robin Zhang <[email protected]>
@buptzyb buptzyb changed the title from "[PyTorch] Use autograd.backward to capture cudagraph backward" to "[PyTorch] Support cudagraph recomputation" on Dec 16, 2025
Contributor

@greptile-apps greptile-apps bot left a comment


2 files reviewed, no comments


Signed-off-by: Robin Zhang <[email protected]>
Contributor

@greptile-apps greptile-apps bot left a comment


Additional Comments (1)

  1. transformer_engine/pytorch/graph.py, line 66-77 (link)

    logic: potential race condition if inputs tuple contains non-tensor items

    The code assumes all items in inputs are tensors with .grad attributes, but the wrapper doesn't validate this. If non-tensor items slip through, accessing .grad would raise AttributeError

    Are all items in the inputs tuple guaranteed to be tensors at the call sites, or could there be edge cases with non-tensor inputs?

2 files reviewed, 1 comment


@buptzyb
Contributor Author

buptzyb commented Dec 31, 2025

Additional Comments (1)

  1. transformer_engine/pytorch/graph.py, line 66-77 (link)
    logic: potential race condition if inputs tuple contains non-tensor items
    The code assumes all items in inputs are tensors with .grad attributes, but the wrapper doesn't validate this. If non-tensor items slip through, accessing .grad would raise AttributeError
    Are all items in the inputs tuple guaranteed to be tensors at the call sites, or could there be edge cases with non-tensor inputs?

2 files reviewed, 1 comment


The inputs are from per_callable_static_input_surfaces, which is guaranteed to be all tensors.

@buptzyb
Contributor Author

buptzyb commented Dec 31, 2025

Hi @ksivaman, could you help review? Thanks!

@ksivaman
Member

/te-ci pytorch L0 L1

Contributor

@greptile-apps greptile-apps bot left a comment


Additional Comments (1)

  1. transformer_engine/pytorch/graph.py, line 71-77 (link)

    style: The original gradient values are restored after the context exits. Verify this works correctly when gradients are already present (accumulation scenario).
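One way to exercise the accumulation scenario this comment asks about, reusing the none_grad_context sketch from earlier (illustrative only; the real wrapper and its call sites live in graph.py):

```python
import torch

x = torch.randn(4, requires_grad=True)
x.grad = torch.ones_like(x)      # pre-existing accumulated gradient
original = x.grad

with none_grad_context((x,)):    # .grad is cleared inside the context
    (x * 2).sum().backward()
    captured = x.grad            # gradient from this backward pass only

assert x.grad is original        # pre-existing gradient restored on exit
assert torch.equal(captured, torch.full_like(x, 2.0))
```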

2 files reviewed, 1 comment


@ksivaman ksivaman merged commit 324be33 into NVIDIA:main Dec 31, 2025
28 of 32 checks passed