[PyTorch] Support cudagraph recomputation #2518
Conversation
Signed-off-by: Robin Zhang <[email protected]>
for more information, see https://pre-commit.ci
Greptile Summary

This PR enables cudagraph recomputation support by making two key changes: replacing `autograd.grad` with `autograd.backward` during cudagraph capture, and handling graph-safe RNG states during checkpoint recomputation.

The changes align with PyTorch's cudagraph requirements.

Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant make_graphed_callables
    participant _CheckpointFunction
    participant _none_grad_context_wrapper
    participant RNGTracker
    participant CUDAGraph
    User->>make_graphed_callables: Call with forward function
    Note over make_graphed_callables: Warmup Phase
    make_graphed_callables->>RNGTracker: get_states()
    RNGTracker-->>make_graphed_callables: tracker_states
    make_graphed_callables->>make_graphed_callables: Check if graph_safe via is_graph_safe_rng_state()
    make_graphed_callables->>make_graphed_callables: Get CUDA RNG state (graph_safe=True/False)
    make_graphed_callables->>make_graphed_callables: Execute forward pass
    alt Training Mode
        make_graphed_callables->>_none_grad_context_wrapper: Enter with inputs
        Note over _none_grad_context_wrapper: Save original grads<br/>Set input.grad = None
        make_graphed_callables->>make_graphed_callables: torch.autograd.backward()
        Note over make_graphed_callables: Grads accumulate in input.grad
        make_graphed_callables->>make_graphed_callables: Collect grad_inputs from input.grad
        _none_grad_context_wrapper->>_none_grad_context_wrapper: Restore original grads
    end
    Note over make_graphed_callables: Capture Phase
    make_graphed_callables->>CUDAGraph: Capture forward graph
    CUDAGraph-->>make_graphed_callables: fwd_graph
    alt Training Mode
        make_graphed_callables->>CUDAGraph: Capture backward graph
        make_graphed_callables->>_none_grad_context_wrapper: Enter with inputs
        make_graphed_callables->>make_graphed_callables: torch.autograd.backward()
        make_graphed_callables->>make_graphed_callables: Collect grad_inputs from input.grad
        _none_grad_context_wrapper->>_none_grad_context_wrapper: Restore original grads
        CUDAGraph-->>make_graphed_callables: bwd_graph
    end
    Note over _CheckpointFunction: During Recomputation
    _CheckpointFunction->>RNGTracker: get_states()
    RNGTracker-->>_CheckpointFunction: tracker_states
    _CheckpointFunction->>_CheckpointFunction: Determine graph_safe_rng_state
    _CheckpointFunction->>_CheckpointFunction: Get/Set RNG states (graph_safe=True/False)
    _CheckpointFunction->>_CheckpointFunction: Recompute forward
    _CheckpointFunction->>_CheckpointFunction: Restore RNG states (graph_safe=True/False)
```
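As a rough illustration of the capture-time pattern shown above, here is a minimal sketch of saving each input's existing `.grad`, clearing it, running `torch.autograd.backward` so fresh gradients accumulate into `.grad`, collecting them, and then restoring the originals. The helper names are hypothetical; this is not the PR's actual `_none_grad_context_wrapper`.

```python
import contextlib
import torch


@contextlib.contextmanager
def none_grad_context(inputs):
    # Save the caller's existing .grad values and clear them so that
    # torch.autograd.backward accumulates fresh gradients into input.grad.
    saved = [t.grad for t in inputs]
    for t in inputs:
        t.grad = None
    try:
        yield
    finally:
        # Hand back whatever gradients the caller had before warmup/capture.
        for t, g in zip(inputs, saved):
            t.grad = g


def backward_and_collect(outputs, grad_outputs, inputs):
    """Run backward and read gradients from .grad instead of autograd.grad."""
    with none_grad_context(inputs):
        torch.autograd.backward(outputs, grad_tensors=grad_outputs)
        # Must collect before the context restores the original .grad values.
        grad_inputs = tuple(t.grad for t in inputs)
    return grad_inputs
```

Note that the collection has to happen before the restore step, since restoring overwrites `.grad`; the collected tensors themselves remain valid references afterwards.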
1 file reviewed, no comments
Signed-off-by: Robin Zhang <[email protected]>
2 files reviewed, no comments
Signed-off-by: Robin Zhang <[email protected]>
Additional Comments (1)
- transformer_engine/pytorch/graph.py, lines 66-77 (link), logic: potential race condition if the `inputs` tuple contains non-tensor items. The code assumes all items in `inputs` are tensors with `.grad` attributes, but the wrapper doesn't validate this. If non-tensor items slip through, accessing `.grad` would raise `AttributeError`. Are all items in the `inputs` tuple guaranteed to be tensors at the call sites, or could there be edge cases with non-tensor inputs?
2 files reviewed, 1 comment
The inputs are from
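If non-tensor items could ever reach the wrapper, a defensive variant (purely illustrative, not part of this PR) might filter on `isinstance` before touching `.grad`:

```python
import torch


def saved_grads(inputs):
    # Read .grad only from actual tensors; non-tensor items map to None.
    return tuple(
        t.grad if isinstance(t, torch.Tensor) else None for t in inputs
    )
```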
Hi @ksivaman, could you help review? Thanks!
/te-ci pytorch L0 L1
Additional Comments (1)
- transformer_engine/pytorch/graph.py, lines 71-77 (link), style: The original gradient values are restored after the context exits. Verify this works correctly when gradients are already present (accumulation scenario).
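One way to sanity-check that concern is an illustrative CPU-only test (not from the PR): accumulate a gradient, run the save/clear/backward/restore sequence the wrapper performs, and confirm the original gradient survives.

```python
import torch


def test_grad_restore_preserves_accumulation():
    x = torch.randn(4, requires_grad=True)
    # Simulate a previously accumulated gradient (2.0 per element).
    (x * 2).sum().backward()
    accumulated = x.grad.clone()

    # Save, clear, run an unrelated backward, then restore, mirroring the
    # wrapper's behavior during graph warmup/capture.
    saved = x.grad
    x.grad = None
    (x * 3).sum().backward()
    fresh = x.grad          # gradient of the captured pass only (3.0 each)
    x.grad = saved          # restore the user's accumulated gradient

    assert torch.equal(x.grad, accumulated)
    assert torch.equal(fresh, torch.full_like(x, 3.0))
```

This exercises only the `.grad` bookkeeping, not graph capture, but it shows the restore step handing back the accumulated gradient untouched.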
2 files reviewed, 1 comment
Description
Support cudagraph recomputation with two changes:
- Replace `autograd.grad` with `autograd.backward` in cudagraph capturing.
- Use graph-safe RNG states when saving and restoring the CUDA RNG state during checkpoint recomputation (see the sketch below).
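A minimal sketch of the recomputation-side RNG handling, assuming a plain `torch.autograd.Function` checkpoint and a CUDA device. This is not TE's `_CheckpointFunction`, and it shows only the non-graph-safe path; the real code also queries the RNG tracker and switches to graph-safe generator states under CUDA graph capture.

```python
import torch


class RNGCheckpoint(torch.autograd.Function):
    """Illustrative checkpoint that replays the CUDA RNG state on recompute."""

    @staticmethod
    def forward(ctx, run_fn, x):
        ctx.run_fn = run_fn
        # Snapshot the RNG state the original forward pass will consume.
        ctx.cuda_rng_state = torch.cuda.get_rng_state()
        ctx.save_for_backward(x)
        with torch.no_grad():
            return run_fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        detached = x.detach().requires_grad_(True)
        # Restore the saved state so stochastic ops (e.g. dropout) produce
        # the same masks during recomputation as in the original forward.
        current = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(ctx.cuda_rng_state)
        try:
            with torch.enable_grad():
                out = ctx.run_fn(detached)
        finally:
            torch.cuda.set_rng_state(current)
        out.backward(grad_out)
        return None, detached.grad
```

Calling `RNGCheckpoint.apply(dropout_block, x)` (with a hypothetical `dropout_block`) would then replay identical dropout masks during the backward-time recomputation.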
Type of change

Changes
Please list the changes introduced in this PR:
Checklist: