[AIROCMLIR-498] Attention scheduling improvements #2267

Open
justinrosner wants to merge 18 commits into develop from 498-attention-scheduling

Conversation

@justinrosner
Contributor

@justinrosner justinrosner commented Mar 3, 2026

Motivation

This PR implements some improvements to the final scheduling of attention ops: https://amd-hub.atlassian.net/browse/AIROCMLIR-498

Technical Details

V Load Prefetching

  • The first V tile's global reads are issued before softmax begins, via a new split-phase load mechanism. An amdgpu.sched_barrier prevents the backend from sinking these loads past softmax. The regs -> LDS write is placed either before or after the sum reduction, depending on whether the combined LDS footprint fits in hardware limits. The GEMM1 loop is restructured to peel the first (prefetched) iteration.

Test Plan

  • Nightly CI
  • Benchmarking on gfx950
  • Benchmarking on gfx942
  • Benchmarking on Navi3x

Test Result

  • Nightly CI
  • See linked JIRA ticket for performance result discussions

Submission Checklist

@justinrosner justinrosner force-pushed the 498-attention-scheduling branch from 94e5d3b to 79dd5db Compare March 9, 2026 21:16
@justinrosner justinrosner force-pushed the 498-attention-scheduling branch from 83934d4 to b4ac474 Compare March 12, 2026 17:34
@justinrosner
Contributor Author

justinrosner commented Mar 17, 2026

Note for reviewers: greedy tuning is currently running for this PR.

Update: results posted here https://amd-hub.atlassian.net/jira/software/c/projects/AIROCMLIR/boards/2346?assignee=712020%3Aa4f5bae8-51c2-4c51-b4a5-20261e9bcc63&selectedIssue=AIROCMLIR-498

@justinrosner justinrosner marked this pull request as ready for review March 25, 2026 15:45
@justinrosner justinrosner requested a review from causten as a code owner March 25, 2026 15:45
Contributor

Copilot AI left a comment


Pull request overview

This PR improves Rock attention scheduling by adding a split-phase “V” tile prefetch (global-read before softmax, deferred LDS write, then LDS read for the peeled GEMM1 iteration) and by tightening post-lowering barrier cleanup so the generated pipelines remain well-scheduled after subsequent lowering/unrolling passes.

Changes:

  • Add new GemmLoadTileType modes (GlobalReadOnly, LDSWriteFromRegs, LDSReadOnly) to support V prefetch split-phases, plus updated lowering and tests.
  • Restructure attention GEMM1 lowering to peel the prefetched iteration and use amdgpu.sched_barrier to prevent backend sinking of the V global loads.
  • Re-run rock-pipeline after rock-sugar-to-loops (and always run back-to-back barrier removal) to clean up adjacent LDS barriers introduced by unrolling.

Reviewed changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
mlir/test/rocmlir-driver/pipelines.mlir Updates expected rocmlir-driver GPU pipeline to include an extra rock-pipeline run post rock-sugar-to-loops.
mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir Updates checks for the new V prefetch split-phase and peeled GEMM1 structure.
mlir/test/Dialect/Rock/test_rock_pipeline.mlir Adjusts pipeline tests to match updated multibuffer extraction behavior.
mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir Updates checks for revised reduction lowering (including vector.reduction-with-acc patterns and -0.0 init).
mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir Adds checks for V prefetch behavior in GQA attention lowering.
mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir Updates checks for V prefetch stage naming and barrier placement.
mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp Fixes/stabilizes stores in threadwise_read_into rewrite by using full dest coordinates.
mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp Improves barrier motion/removal logic and ensures back-to-back barrier removal always runs.
mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Implements V prefetch emission and GEMM1 loop restructuring; adds amdgpu.sched_barrier.
mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp Lowers new split-phase BlockwiseLoadTileOp modes and introduces distinct V stage names.
mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp Updates blockwise reduction lowering to use vector.reduction with accumulator and branchless rThread reductions.
mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp Adds a second rock-pipeline run after rock-sugar-to-loops to remove newly-adjacent barriers.
mlir/lib/Dialect/Rock/IR/RockDialect.cpp Updates BlockwiseLoadTileOp memory effects for LDSReadOnly (but needs more updates for other new modes).
mlir/include/mlir/Dialect/Rock/Passes.td Adds AMDGPU dialect dependency for gridwise->blockwise pass; documents rock-pipeline option.
mlir/include/mlir/Dialect/Rock/IR/RockOps.td Documents the new GemmLoadTileType modes.
mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td Extends the GemmLoadTileType enum with the new split-phase variants.
Comments suppressed due to low confidence (1)

mlir/lib/Dialect/Rock/IR/RockDialect.cpp:2499

  • The unconditional effects.emplace_back(write, &getDestRegistersMutable()[0]) for all non-Default load types is not correct for GemmLoadTileType::LDSWriteFromRegs, which reads from destRegisters but does not modify it. This should be modeled as a read effect on destRegisters (and no write effect) for LDSWriteFromRegs to keep MemoryEffects accurate.
    if (doubleBuffer || ldsReadOnly)
      effects.emplace_back(read, &getDestLDSMutable()[0]);
  }
  if (!singleBuffer) {
    assert(getDestRegisters() != nullptr);


Comment on lines +2487 to 2491
// LDSReadOnly does not read from global source.
if (!ldsReadOnly)
effects.emplace_back(read, &getSourceMutable());
if (loadType != GemmLoadTileType::BypassLDS) {
assert(getDestLDS() != nullptr);

Copilot AI Mar 25, 2026


BlockwiseLoadTileOp::getEffects doesn’t account for the new split-phase load types. For GemmLoadTileType::GlobalReadOnly the op does not write to destLDS, and for GemmLoadTileType::LDSWriteFromRegs the op should not read from source at all (it should read from destRegisters and write to destLDS). As written, MemoryEffects will incorrectly report global/LDS accesses, which can mislead scheduling and optimization passes that rely on effects. Please add explicit cases for GlobalReadOnly/LDSWriteFromRegs (and ensure LDSReadOnly remains LDS-read + regs-write only).

Comment on lines +826 to +830
Block::BlockArgListType allDestCoords =
    loadLoop.getLowerCoords(/*domain=*/1);
size_t dropCount = allDestCoords.size() - dstRank;
SmallVector<Value> destCoords(allDestCoords.begin() + dropCount,
                              allDestCoords.end());

Copilot AI Mar 25, 2026


dropCount is computed as allDestCoords.size() - dstRank with an unsigned size_t. If dstRank > allDestCoords.size() this underflows and will produce invalid iterators (likely crashing the compiler). Please add an explicit check/early failure (or assert with a clear message) that dstRank <= allDestCoords.size() before computing dropCount.
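The hazard and the suggested guard can be shown with a minimal standalone helper (names mirror the snippet above but this is an illustration, not the PR's code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// size() - dstRank is computed in size_t; if dstRank exceeds the
// coordinate count the subtraction wraps to a huge value and the
// iterator offset becomes invalid. Assert before subtracting.
size_t safeDropCount(const std::vector<int> &allDestCoords,
                     size_t dstRank) {
  assert(dstRank <= allDestCoords.size() &&
         "dstRank must not exceed the number of destination coords");
  return allDestCoords.size() - dstRank;
}
```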

Comment on lines +179 to +183
// Check if the operation accesses LDS.
// We can move past LDS store-only operations because independent
// writes don't need ordering between them, the next barrier will
// ensure all writes complete before any subsequent reads.
// We must stop at LDS reads.

Copilot AI Mar 25, 2026


PushBarrierDownRewritePattern now allows pushing an rock.lds_barrier past LDS store ops. This can change the synchronization semantics: a barrier that previously ensured all threads finished prior LDS reads before an overwrite can end up after the overwrite, allowing races between reads and writes across threads. Unless you can prove the stores are to disjoint/unused LDS regions (e.g., different multibuffer slots), the barrier should not be moved past any op that touches workgroup memory. Consider restricting this to ops that do not access LDS at all, or add a more precise dependence check before allowing the swap.

@justinrosner justinrosner changed the title [AIROCMLIR-498] [DRAFT] Attention scheduling improvements [AIROCMLIR-498] Attention scheduling improvements Mar 25, 2026
Comment on lines +931 to +935
// Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the
// true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0
// and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the
// redundant `v_add_f32 v, 0, v` that +0.0 generates via
// llvm.vector.reduce.fadd.
Member


If this PR takes too long to merge, move these changes into a separate PR, and also create tests to make sure it doesn't generate `v_add_f32 v, 0, v`.
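The signed-zero reasoning behind the change is easy to check in isolation (a plain C++ check of IEEE 754 addition, unrelated to the PR's code):

```cpp
#include <cmath>

// (-0.0) + (-0.0) yields -0.0, so -0.0 is an identity even for signed
// zero, and LLVM can fold `fadd -0.0, x -> x`. With +0.0 the sum
// (+0.0) + (-0.0) is +0.0, so +0.0 is not a universal identity and
// the redundant add survives. (NaN propagates through fadd either
// way, as a later review comment notes.)
inline bool addPreservesNegativeZero(double identity) {
  return std::signbit(identity + (-0.0));
}
```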

// where offset is a power of 2.
// Initially, power = ceil(|rtid| to a power of 2) / 2.
// Then keep reducing the power.
// Branchless reduction: each thread reads all rTidDim partial
Member


This also seems like an independent change compared to scheduling the V tile.

Comment on lines +243 to +245
// Re-run the pipeline pass to remove back-to-back LDS barriers
// that may appear after SugarToLoops unrolls TransformingForOps.
funcPm.addPass(rock::createRockPipelinePass());
Member


I know we add LDS barriers in gridwiseToBlockwise conservatively, thinking rock-pipeline will take care of them.
Is it possible to add barriers such that we don't need to run the rock-pipeline pass again?
Is it possible to enhance the logic for back-to-back barriers? Unrolling will create back-to-back barriers if there is a case like this, I think:

LDSBarrier (1)
Scf.for {
LDSBarrier (2)
....
LDSBarrier (3)
}
LDSBarrier (4)

(1) This barrier may not be necessary if the loop body starts with a barrier and there's a barrier after exiting the loop

(2) Possibly needed for loop-carried deps

(3) Can be eliminated if there is a barrier at the exit of the loop

(4) Exit barrier
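The back-to-back case is the simplest of these to handle mechanically; a sketch over a flattened op stream (op names are illustrative stand-ins, not real Rock IR):

```cpp
#include <string>
#include <vector>

// Collapse runs of adjacent "barrier" tokens to a single one, which
// is what a back-to-back barrier cleanup does for barriers made
// adjacent by unrolling. Cases (1)/(3) above need loop-structure
// awareness and are not modeled here.
std::vector<std::string>
removeBackToBackBarriers(const std::vector<std::string> &ops) {
  std::vector<std::string> out;
  for (const auto &op : ops) {
    if (op == "barrier" && !out.empty() && out.back() == "barrier")
      continue; // second of an adjacent pair is redundant
    out.push_back(op);
  }
  return out;
}
```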

@umangyadav
Member

@stefankoncarevic can you review this PR especially the reduction changes ?

Comment on lines +1426 to +1427
// TODO: We may have to use a heuristic to determine whether or not to
// use this depending on the size of rTidCount.
Contributor


Agree this needs a heuristic. The branchless approach is O(N) LDS reads vs O(log N) in the tree, so for small rTidCount (2-4) it's clearly better, but for larger values (8, 16) the extra LDS reads may outweigh the branch elimination benefit.

Suggestion: benchmark both approaches for representative configs with rTidCount = 2, 4, 8, 16 on target architectures to find the empirical crossover point, then add a threshold and keep the old tree path as a fallback.
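The cost model behind that crossover argument, written out (illustrative counts only; the real crossover must be measured per target architecture):

```cpp
// Branchless: every thread reads all rTidCount partials from LDS in
// one barrier-free pass, so reads grow linearly with rTidCount.
int branchlessReadsPerThread(int rTidCount) { return rTidCount; }

// Tree: ceil(log2(rTidCount)) rounds, each needing one LDS barrier,
// with one read per active thread per round.
int treeRounds(int rTidCount) {
  int rounds = 0;
  while ((1 << rounds) < rTidCount)
    ++rounds;
  return rounds;
}
```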

Contributor

@stefankoncarevic stefankoncarevic left a comment


Good approach: eliminating scf::IfOp branches keeps softmax in a single basic block, improving backend scheduling. The reduction from log2(rTidCount) LDS barriers to a single barrier is a clear win. See the inline comment about the TODO for a threshold heuristic.

vector::ReductionOp with accumulator: a clean optimization that combines the vector reduction and scalar accumulation into a single op.
One thing to note: this optimization benefits the tree path's threadwise reduction (which uses vectorized loads, rIterVectorLen > 1), but does NOT benefit the DPP path's threadwise reduction, because DPP forces localIterVectorLen = 1 (scalar loads).
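The accumulator-fused shape of the reduction can be illustrated with a scalar stand-in (std::accumulate's init parameter plays the role of the vector.reduction accumulator operand; this is not the MLIR op itself):

```cpp
#include <numeric>
#include <vector>

// Fusing the running scalar into the reduction avoids a separate add
// after the vector reduce: the accumulator is the reduction's seed.
float reduceWithAcc(const std::vector<float> &partials, float acc) {
  return std::accumulate(partials.begin(), partials.end(), acc);
}
```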

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 16 out of 16 changed files in this pull request and generated 3 comments.



auto accRegType = MemRefType::get({1}, elemType, AffineMap{},
                                  privateMemoryAddressSpace);
Value accReg = GpuAllocOp::create(rewriter, loc, accRegType);
FillOp::create(rewriter, loc, accReg, initVal);

Copilot AI Mar 31, 2026


initVal is used to initialize accReg in the branchless reduction path, but it is declared inside the preceding if (threadViewShape[rIterDim] > 1) block. As written, this won’t compile (and even conceptually, the branchless reduction should be able to run when rIterDim <= 1). Move the initVal definition outside the conditional (or recompute it in the branchless block), or remove the FillOp entirely since the i==0 iteration overwrites the accumulator.

Suggested change
FillOp::create(rewriter, loc, accReg, initVal);

Comment on lines 2480 to 2497
GemmLoadTileType loadType = getLoadType();
bool doubleBuffer = loadType == GemmLoadTileType::DoubleBuffer ||
                    loadType == GemmLoadTileType::DirectToLDSDoubleBuffer;
bool singleBuffer = loadType == GemmLoadTileType::Default ||
                    loadType == GemmLoadTileType::DirectToLDSDefault;
bool ldsReadOnly = loadType == GemmLoadTileType::LDSReadOnly;

// LDSReadOnly does not read from global source.
if (!ldsReadOnly)
  effects.emplace_back(read, &getSourceMutable());
if (loadType != GemmLoadTileType::BypassLDS) {
  assert(getDestLDS() != nullptr);
  // LDSReadOnly only reads from LDS, it does not write to it.
  if (!ldsReadOnly)
    effects.emplace_back(write, &getDestLDSMutable()[0]);
  if (doubleBuffer || ldsReadOnly)
    effects.emplace_back(read, &getDestLDSMutable()[0]);
}

Copilot AI Mar 31, 2026


BlockwiseLoadTileOp::getEffects() was only updated for LDSReadOnly, but the new split-phase load types (GlobalReadOnly, LDSWriteFromRegs) also change the memory behavior:

  • GlobalReadOnly should not be marked as writing/reading destLDS.
  • LDSWriteFromRegs should not be marked as reading source, and destRegisters is read (not written).
    With the current effects, alias/dependence analyses (and scheduling/barrier logic that relies on hasEffect) will be incorrect for these new modes. Please update effects to reflect the actual reads/writes for each GemmLoadTileType case.
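The per-mode effects the review asks for could be tabulated roughly as follows (mode and resource names mirror the PR's GemmLoadTileType variants, but this sketch is not the actual RockDialect.cpp code and omits the double-buffer and direct-to-LDS cases):

```cpp
#include <string>
#include <vector>

// Hypothetical effect table for the split-phase load modes: each
// phase touches only the resources it actually moves data between.
enum class LoadTileType { Default, GlobalReadOnly, LDSWriteFromRegs,
                          LDSReadOnly };

std::vector<std::string> effectsFor(LoadTileType t) {
  switch (t) {
  case LoadTileType::GlobalReadOnly:   // global -> regs only
    return {"read:source", "write:destRegisters"};
  case LoadTileType::LDSWriteFromRegs: // regs -> LDS only
    return {"read:destRegisters", "write:destLDS"};
  case LoadTileType::LDSReadOnly:      // LDS -> regs only
    return {"read:destLDS", "write:destRegisters"};
  case LoadTileType::Default:          // global -> LDS (single buffer)
    return {"read:source", "write:destLDS"};
  }
  return {};
}
```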

Comment on lines +931 to +935
// Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the
// true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0
// and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the
// redundant `v_add_f32 v, 0, v` that +0.0 generates via
// llvm.vector.reduce.fadd.

Copilot AI Mar 31, 2026


The new comment claims -0.0 is an additive identity “including NaN”, but IEEE-754 defines (-0.0) + NaN = NaN (same for +0.0). Consider adjusting the wording to avoid stating the identity property holds for NaNs; the optimization rationale about LLVM folding fadd -0.0, x -> x can still stand without that claim.

Suggested change
// Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the
// true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0
// and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the
// redundant `v_add_f32 v, 0, v` that +0.0 generates via
// llvm.vector.reduce.fadd.
// Use -0.0 (negative zero) instead of +0.0. LLVM can fold
// `fadd -0.0, x → x`, eliminating the redundant
// `v_add_f32 v, 0, v` that +0.0 generates via
// llvm.vector.reduce.fadd. (Note: IEEE 754 still propagates NaNs,
// i.e., x + NaN = NaN for any x.)
