[AIROCMLIR-498] Attention scheduling improvements#2267
justinrosner wants to merge 18 commits into develop from
Conversation
Note for reviewers: greedy tuning is currently running for this PR. Update: results posted here: https://amd-hub.atlassian.net/jira/software/c/projects/AIROCMLIR/boards/2346?assignee=712020%3Aa4f5bae8-51c2-4c51-b4a5-20261e9bcc63&selectedIssue=AIROCMLIR-498
Pull request overview
This PR improves Rock attention scheduling by adding a split-phase “V” tile prefetch (global-read before softmax, deferred LDS write, then LDS read for the peeled GEMM1 iteration) and by tightening post-lowering barrier cleanup so the generated pipelines remain well-scheduled after subsequent lowering/unrolling passes.
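As a rough schematic of the emission order (hedged: the `GemmLoadTileType` modes are from this PR, but `emitLoadVTile` and its exact placement are illustrative stand-ins, not the actual code):

```cpp
// Split-phase V prefetch, in emission order (illustrative only).
emitLoadVTile(b, loc, GemmLoadTileType::GlobalReadOnly);   // global -> regs, before softmax
// ... softmax: row max, exp, sum reduction ...
emitLoadVTile(b, loc, GemmLoadTileType::LDSWriteFromRegs); // regs -> LDS; placed before or
                                                           // after the sum reduction per LDS budget
emitLoadVTile(b, loc, GemmLoadTileType::LDSReadOnly);      // LDS -> regs for the peeled
                                                           // first GEMM1 iteration
```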
Changes:
- Add new `GemmLoadTileType` modes (`GlobalReadOnly`, `LDSWriteFromRegs`, `LDSReadOnly`) to support the V prefetch split-phases, plus updated lowering and tests.
- Restructure attention GEMM1 lowering to peel the prefetched iteration and use `amdgpu.sched_barrier` to prevent backend sinking of the V global loads.
- Re-run `rock-pipeline` after `rock-sugar-to-loops` (and always run back-to-back barrier removal) to clean up adjacent LDS barriers introduced by unrolling.
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| mlir/test/rocmlir-driver/pipelines.mlir | Updates expected rocmlir-driver GPU pipeline to include an extra rock-pipeline run post rock-sugar-to-loops. |
| mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir | Updates checks for the new V prefetch split-phase and peeled GEMM1 structure. |
| mlir/test/Dialect/Rock/test_rock_pipeline.mlir | Adjusts pipeline tests to match updated multibuffer extraction behavior. |
| mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir | Updates checks for revised reduction lowering (including vector.reduction-with-acc patterns and -0.0 init). |
| mlir/test/Dialect/Rock/gridwise_attention_accel_lowering_gqa.mlir | Adds checks for V prefetch behavior in GQA attention lowering. |
| mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir | Updates checks for V prefetch stage naming and barrier placement. |
| mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp | Fixes/stabilizes stores in threadwise_read_into rewrite by using full dest coordinates. |
| mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp | Improves barrier motion/removal logic and ensures back-to-back barrier removal always runs. |
| mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Implements V prefetch emission and GEMM1 loop restructuring; adds amdgpu.sched_barrier. |
| mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp | Lowers new split-phase BlockwiseLoadTileOp modes and introduces distinct V stage names. |
| mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp | Updates blockwise reduction lowering to use vector.reduction with accumulator and branchless rThread reductions. |
| mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp | Adds a second rock-pipeline run after rock-sugar-to-loops to remove newly-adjacent barriers. |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Updates BlockwiseLoadTileOp memory effects for LDSReadOnly (but needs more updates for other new modes). |
| mlir/include/mlir/Dialect/Rock/Passes.td | Adds AMDGPU dialect dependency for gridwise->blockwise pass; documents rock-pipeline option. |
| mlir/include/mlir/Dialect/Rock/IR/RockOps.td | Documents the new GemmLoadTileType modes. |
| mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td | Extends the GemmLoadTileType enum with the new split-phase variants. |
Comments suppressed due to low confidence (1)
mlir/lib/Dialect/Rock/IR/RockDialect.cpp:2499
- The unconditional `effects.emplace_back(write, &getDestRegistersMutable()[0])` for all non-`Default` load types is not correct for `GemmLoadTileType::LDSWriteFromRegs`, which reads from `destRegisters` but does not modify it. This should be modeled as a read effect on `destRegisters` (and no write effect) for `LDSWriteFromRegs` to keep MemoryEffects accurate.
```cpp
  if (doubleBuffer || ldsReadOnly)
    effects.emplace_back(read, &getDestLDSMutable()[0]);
}
if (!singleBuffer) {
  assert(getDestRegisters() != nullptr);
```
```cpp
// LDSReadOnly does not read from global source.
if (!ldsReadOnly)
  effects.emplace_back(read, &getSourceMutable());
if (loadType != GemmLoadTileType::BypassLDS) {
  assert(getDestLDS() != nullptr);
```
`BlockwiseLoadTileOp::getEffects` doesn't account for the new split-phase load types. For `GemmLoadTileType::GlobalReadOnly` the op does not write to `destLDS`, and for `GemmLoadTileType::LDSWriteFromRegs` the op should not read from `source` at all (it should read from `destRegisters` and write to `destLDS`). As written, MemoryEffects will incorrectly report global/LDS accesses, which can mislead scheduling and optimization passes that rely on effects. Please add explicit cases for `GlobalReadOnly`/`LDSWriteFromRegs` (and ensure `LDSReadOnly` remains LDS-read + regs-write only).
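A minimal sketch of what explicit per-mode handling could look like, assuming the accessor names visible in the quoted snippets (not the actual patch):

```cpp
switch (getLoadType()) {
case GemmLoadTileType::GlobalReadOnly:
  // global -> regs only: read source, write destRegisters, no LDS access.
  effects.emplace_back(read, &getSourceMutable());
  effects.emplace_back(write, &getDestRegistersMutable()[0]);
  break;
case GemmLoadTileType::LDSWriteFromRegs:
  // regs -> LDS only: read destRegisters, write destLDS, no global access.
  effects.emplace_back(read, &getDestRegistersMutable()[0]);
  effects.emplace_back(write, &getDestLDSMutable()[0]);
  break;
case GemmLoadTileType::LDSReadOnly:
  // LDS -> regs only: read destLDS, write destRegisters.
  effects.emplace_back(read, &getDestLDSMutable()[0]);
  effects.emplace_back(write, &getDestRegistersMutable()[0]);
  break;
default:
  // Existing Default / DoubleBuffer / BypassLDS handling stays as-is.
  break;
}
```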
```cpp
Block::BlockArgListType allDestCoords =
    loadLoop.getLowerCoords(/*domain=*/1);
size_t dropCount = allDestCoords.size() - dstRank;
SmallVector<Value> destCoords(allDestCoords.begin() + dropCount,
                              allDestCoords.end());
```
`dropCount` is computed as `allDestCoords.size() - dstRank` with an unsigned `size_t`. If `dstRank > allDestCoords.size()` this underflows and will produce invalid iterators (likely crashing the compiler). Please add an explicit check/early failure (or assert with a clear message) that `dstRank <= allDestCoords.size()` before computing `dropCount`.
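For example, a sketch using the names from the quoted snippet:

```cpp
assert(dstRank <= allDestCoords.size() &&
       "dstRank exceeds the number of destination coordinates");
size_t dropCount = allDestCoords.size() - dstRank;
```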
```cpp
// Check if the operation accesses LDS.
// We can move past LDS store-only operations because independent
// writes don't need ordering between them, the next barrier will
// ensure all writes complete before any subsequent reads.
// We must stop at LDS reads.
```
`PushBarrierDownRewritePattern` now allows pushing a `rock.lds_barrier` past LDS store ops. This can change the synchronization semantics: a barrier that previously ensured all threads finished prior LDS reads before an overwrite can end up after the overwrite, allowing races between reads and writes across threads. Unless you can prove the stores are to disjoint/unused LDS regions (e.g., different multibuffer slots), the barrier should not be moved past any op that touches workgroup memory. Consider restricting this to ops that do not access LDS at all, or add a more precise dependence check before allowing the swap.
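One possible shape for a tighter guard, as a hedged sketch (assumes MLIR's `MemoryEffectOpInterface` and the upstream `gpu::GPUDialect::hasWorkgroupMemoryAddressSpace` helper; not the actual pass code):

```cpp
// Conservatively refuse to push the barrier past any op that may *read*
// workgroup (LDS) memory; unknown ops are treated as readers.
static bool mayReadLDS(Operation *op) {
  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!iface)
    return true;
  SmallVector<MemoryEffects::EffectInstance> effects;
  iface.getEffects(effects);
  for (const MemoryEffects::EffectInstance &e : effects) {
    if (!isa<MemoryEffects::Read>(e.getEffect()))
      continue;
    Value v = e.getValue();
    auto type = v ? dyn_cast<MemRefType>(v.getType()) : MemRefType();
    // No value / non-memref effect: assume it could touch LDS.
    if (!type || gpu::GPUDialect::hasWorkgroupMemoryAddressSpace(type))
      return true;
  }
  return false;
}
```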
```cpp
// Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the
// true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0
// and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the
// redundant `v_add_f32 v, 0, v` that +0.0 generates via
// llvm.vector.reduce.fadd.
```
If this PR takes too long to merge, move these changes into a separate PR, and also create tests to make sure it doesn't generate `v_add_f32 v, 0, v`.
```cpp
// where offset is a power of 2.
// Initial it starts with power = ceil(|rtid|, power of 2) / 2
// Then keep on reducing the power.
// Branchless reduction: each thread reads all rTidDim partial
```
This also seems like an independent change compared to the V tile scheduling.
```cpp
// Re-run the pipeline pass to remove back-to-back LDS barriers
// that may appear after SugarToLoops unrolls TransformingForOps.
funcPm.addPass(rock::createRockPipelinePass());
```
I know we add LDS barriers in GridwiseGemmToBlockwise conservatively, thinking `rock-pipeline` will take care of them.
Is it possible to add barriers such that we don't need to run the `rock-pipeline` pass again?
Is it possible to enhance the logic for back-to-back barriers? Unrolling will create back-to-back barriers in a case like this, I think:

```
LDSBarrier (1)
scf.for {
  LDSBarrier (2)
  ...
  LDSBarrier (3)
}
LDSBarrier (4)
```

(1) This barrier may not be necessary if the loop body starts with a barrier and there is a barrier after exiting the loop.
(2) Needed for loop-carried deps, possibly.
(3) Can be eliminated if there is a barrier at the exit of the loop.
(4) Exit barrier.
@stefankoncarevic can you review this PR, especially the reduction changes?
```cpp
// TODO: We may have to use a heuristic to determine whether or not to
// use this depending on the size of rTidCount.
```
Agree, this needs a heuristic. The branchless approach is O(N) LDS reads vs. O(log N) in the tree, so for small `rTidCount` (2-4) it's clearly better, but for larger values (8, 16) the extra LDS reads may outweigh the branch-elimination benefit.
Suggestion: benchmark both approaches for representative configs with `rTidCount` = 2, 4, 8, 16 on target architectures to find the empirical crossover point, then add a threshold and keep the old tree path as a fallback.
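For instance, a placeholder threshold until benchmarks pick the real crossover (names and value hypothetical):

```cpp
// Crossover chosen by benchmarking rTidCount = 2, 4, 8, 16 per target;
// 4 here is a stand-in, not a measured value.
constexpr int64_t kMaxBranchlessRTidCount = 4;
bool useBranchlessReduction = rTidCount <= kMaxBranchlessRTidCount;
```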
stefankoncarevic left a comment
Good approach, eliminating scf::IfOp branches keeps softmax in a single basic block, improving backend scheduling. The reduction from log2(rTidCount) LDS barriers to a single barrier is a clear win. See inline comment about the TODO for a threshold heuristic.
`vector::ReductionOp` with accumulator: a clean optimization that combines vector reduction + scalar accumulation into a single op.
One thing to note: this optimization benefits the Tree path's threadwise reduction (which uses vectorized loads, `rIterVectorLen > 1`), but does NOT benefit the DPP path's threadwise reduction, because DPP forces `localIterVectorLen = 1` (scalar loads).
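In builder form this is roughly the following (hedged: `loadedVector` and `runningAcc` are illustrative names; the accumulator operand is upstream `vector::ReductionOp` API):

```cpp
// One vector.reduction <add> folds the running scalar accumulator in,
// replacing a reduce followed by a separate arith.addf.
Value partial = vector::ReductionOp::create(
    rewriter, loc, vector::CombiningKind::ADD, loadedVector, runningAcc);
```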
```cpp
auto accRegType = MemRefType::get({1}, elemType, AffineMap{},
                                  privateMemoryAddressSpace);
Value accReg = GpuAllocOp::create(rewriter, loc, accRegType);
FillOp::create(rewriter, loc, accReg, initVal);
```
`initVal` is used to initialize `accReg` in the branchless reduction path, but it is declared inside the preceding `if (threadViewShape[rIterDim] > 1)` block. As written, this won't compile (and even conceptually, the branchless reduction should be able to run when `threadViewShape[rIterDim] <= 1`). Move the `initVal` definition outside the conditional (or recompute it in the branchless block), or remove the `FillOp` entirely, since the `i == 0` iteration overwrites the accumulator.
Suggested change:
```diff
- FillOp::create(rewriter, loc, accReg, initVal);
```
```diff
  GemmLoadTileType loadType = getLoadType();
  bool doubleBuffer = loadType == GemmLoadTileType::DoubleBuffer ||
                      loadType == GemmLoadTileType::DirectToLDSDoubleBuffer;
  bool singleBuffer = loadType == GemmLoadTileType::Default ||
                      loadType == GemmLoadTileType::DirectToLDSDefault;
+ bool ldsReadOnly = loadType == GemmLoadTileType::LDSReadOnly;

- effects.emplace_back(read, &getSourceMutable());
+ // LDSReadOnly does not read from global source.
+ if (!ldsReadOnly)
+   effects.emplace_back(read, &getSourceMutable());
  if (loadType != GemmLoadTileType::BypassLDS) {
    assert(getDestLDS() != nullptr);
-   effects.emplace_back(write, &getDestLDSMutable()[0]);
-   // DoubleBuffer means we write to LDS and then, load from it
-   if (doubleBuffer)
+   // LDSReadOnly only reads from LDS, it does not write to it.
+   if (!ldsReadOnly)
+     effects.emplace_back(write, &getDestLDSMutable()[0]);
+   if (doubleBuffer || ldsReadOnly)
      effects.emplace_back(read, &getDestLDSMutable()[0]);
  }
```
`BlockwiseLoadTileOp::getEffects()` was only updated for `LDSReadOnly`, but the new split-phase load types (`GlobalReadOnly`, `LDSWriteFromRegs`) also change the memory behavior:

- `GlobalReadOnly` should not be marked as writing/reading `destLDS`.
- `LDSWriteFromRegs` should not be marked as reading `source`, and `destRegisters` is read (not written).

With the current effects, alias/dependence analyses (and scheduling/barrier logic that relies on `hasEffect`) will be incorrect for these new modes. Please update the effects to reflect the actual reads/writes for each `GemmLoadTileType` case.
```cpp
// Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the
// true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0
// and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the
// redundant `v_add_f32 v, 0, v` that +0.0 generates via
// llvm.vector.reduce.fadd.
```
The new comment claims -0.0 is an additive identity "including NaN", but IEEE 754 defines `(-0.0) + NaN = NaN` (same for +0.0). Consider adjusting the wording to avoid stating the identity property holds for NaNs; the optimization rationale about LLVM folding `fadd -0.0, x -> x` can still stand without that claim.
Suggested change:
```diff
- // Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the
- // true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0
- // and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the
- // redundant `v_add_f32 v, 0, v` that +0.0 generates via
- // llvm.vector.reduce.fadd.
+ // Use -0.0 (negative zero) instead of +0.0. LLVM can fold
+ // `fadd -0.0, x → x`, eliminating the redundant
+ // `v_add_f32 v, 0, v` that +0.0 generates via
+ // llvm.vector.reduce.fadd. (Note: IEEE 754 still propagates NaNs,
+ // i.e., x + NaN = NaN for any x.)
```
Motivation
This PR implements some improvements to the final scheduling of attention ops: https://amd-hub.atlassian.net/browse/AIROCMLIR-498
Technical Details
V Load Prefetching
- An `amdgpu.sched_barrier` prevents the backend from sinking the V tile global loads past softmax.
- The regs -> LDS write is placed either before or after the sum reduction, depending on whether the combined LDS footprint fits in hardware limits.
- The GEMM1 loop is restructured to peel the first (prefetched) iteration.

Test Plan
Test Result
Submission Checklist