(Async)DirectToLDS support in WMMA #2182
base: develop
| int64_t mPerWave = tuningParams.getMPerWave(); | ||
| int64_t nPerWave = tuningParams.getNPerWave(); | ||
| int64_t kPack = tuningParams.getKpack(); | ||
| // TODO: gfx10 supports directToLDS. Implement it. |
Note for reviewers: This comment is wrong. gfx10 supports DirectToLDS, but it does not support WMMA, so the comment does not make sense here.
| // instruction | ||
| ldsIndex = arith::AddIOp::create(b, loc, ldsIndex, ldsIndexWave); | ||
|
|
||
| if (isAsyncDirectToLDSSupported(maybeArch.value())) { |
Note for reviewers: Is this the right place to implement this?
| @@ -1,21 +1,27 @@ | |||
| // RUN: rocmlir-opt %s --rock-to-rocdl | |||
😅
| } else { | ||
| b.replaceOpWithNewOp<memref::StoreOp>(op, op.getData(), op.getDest(), | ||
| op.getCoords()); | ||
| Location loc = op.getLoc(); |
Note for reviewers: This is needed because of the changes in GlobalLoadToLDSRewritePattern. It is extending the InBoundsStore capabilities, but it's not directly related to this PR, so maybe we want to move it into a separate PR to have a cleaner git history.
yes, I'd prefer this to be an independent PR.
| config = "-g 1 -m 512 -k 1 -n 512" | ||
|
|
||
| [[suite.test]] | ||
| config = "-g 1 -m 512 -k 32 -n 512" |
Note for reviewers: the DirectToLDS counterpart also has:
[[suite.test]]
config = "-g 3 -m 1024 -k 768 -n 1024"
but on emulator this takes too long (>30m) so I removed it.
| b.setInsertionPointToEnd(elseBlock); | ||
| Type transferType = op.getTransferType(); | ||
| Value zeroValue = createZeroConstantOp(b, loc, transferType); | ||
| InBoundsStoreOp::create(b, loc, zeroValue, dest, destCoords); |
Note for reviewers: Here I use InBoundsStoreOp to store zeros. Consider a case where transferType=f64 and the LDS buffer is f16: we would need to write 4 x f16 zeros. But what if the first 32 bits are in bounds and the last 32 are out of bounds, meaning that we should actually read 32 bits from LDS and then set the remaining 32 bits to zero? This is not handled here. Can that happen? How should it be handled?
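To make the partial out-of-bounds question above concrete, here is a minimal standalone sketch of how one wide transfer could be split into destination-sized elements when only a leading part is in bounds. `SplitTransfer`, `splitTransfer`, and the `validElems` parameter are hypothetical names for illustration only; nothing here is rocMLIR code, and it does not describe what the PR currently does.

```cpp
#include <cstdint>

// Hypothetical helper (not part of rocMLIR): describes how one wide transfer
// splits into destination-sized elements when only a prefix is in bounds.
struct SplitTransfer {
  int64_t total;    // destination elements covered by one transfer
  int64_t inBounds; // leading elements that would keep real data
  int64_t zeroFill; // trailing elements that would be written as zero
};

// transferBits: width of one transfer (e.g. 64 for f64)
// destBits:     width of one LDS element (e.g. 16 for f16)
// validElems:   assumed count of destination-sized elements that are still
//               in bounds, measured from the start of the transfer
constexpr SplitTransfer splitTransfer(int64_t transferBits, int64_t destBits,
                                      int64_t validElems) {
  const int64_t total = transferBits / destBits;
  // Clamp the in-bounds count to the range [0, total].
  const int64_t inBounds =
      validElems < 0 ? 0 : (validElems < total ? validElems : total);
  return SplitTransfer{total, inBounds, total - inBounds};
}
```

For the f64-into-f16 case from the note, `splitTransfer(64, 16, 2)` describes four f16 elements of which the first two are in bounds and the last two would need zeros. Whether such a partially in-bounds transfer can actually occur is exactly the open question.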
| << "AsyncWaitOpConversion: arch supports AsyncDirectToLDS\n"); | ||
| unsigned asyncCnt = std::min(63u, 0u); | ||
| ROCDL::WaitAsynccntOp::create(rewriter, loc, asyncCnt); | ||
| ROCDL::SBarrierOp::create(rewriter, loc); |
why do we need sbarrier here?
I think we need to get rid of the barriers that RockPipeline will introduce for gfx1250 as well (I think WaitAsynccntOp is enough). Is there a ticket to do that?
| if (supportsAsyncDirectToLDS) { | ||
| LLVM_DEBUG(llvm::dbgs() | ||
| << "AsyncWaitOpConversion: arch supports AsyncDirectToLDS\n"); | ||
| unsigned asyncCnt = std::min(63u, 0u); |
is the max 63 for WaitAsynccntOp as well?
| coords.push_back(b.createOrFold<ConstantIndexOp>(loc, 0)); | ||
| } | ||
|
|
||
| Type originalLoadedType = op.getTransferType(); |
I think this could be another PR? it's not related to gfx1250, right?
| return emitError(loc) | ||
| << "128 bits direct to LDS is not supported by the hardware"; | ||
| } | ||
| } else { |
I think this else (and the above if) should only run for direct to lds, not async.
| // the same output index), async DirectToLDS actually works like a | ||
| // traditional load, so we must take into account the thread-specific | ||
| // offset here. | ||
| if (loadTypeByteWidth == 16) { |
instead of this, it'd be better to have a transform for the output tensor as well. So, instead of having linear indexing of the output we could have anything.
that would allow us to store in LDS with [kpackperblock, dperblock, kpack] layout.
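As a rough illustration of the [kpackperblock, dperblock, kpack] layout mentioned above, the sketch below linearizes a three-dimensional index into a flat LDS element offset. It assumes a plain row-major ordering with kpack as the fastest-varying dimension; the function name and parameters are hypothetical, and the real implementation would express this through a transform map on the output rather than a hand-written formula.

```cpp
#include <cstdint>

// Hypothetical illustration (not rocMLIR code): flat element offset for an
// LDS buffer laid out as [kpackPerBlock, dPerBlock, kpack], assuming
// row-major order with the kpack dimension varying fastest.
//   kOuter: index along kpackPerBlock
//   d:      index along dPerBlock
//   kInner: index along kpack
constexpr int64_t ldsLinearOffset(int64_t kOuter, int64_t d, int64_t kInner,
                                  int64_t dPerBlock, int64_t kpack) {
  return (kOuter * dPerBlock + d) * kpack + kInner;
}
```

With `dPerBlock = 4` and `kpack = 8`, consecutive `kInner` values are adjacent, stepping `d` moves by 8 elements, and stepping `kOuter` moves by 32 elements.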
| if not config.arch.startswith("gfx1250"): | ||
| config.unsupported = True | ||
|
|
||
| # This is useful when running on the emulator, to propagate the environment variables |
we probably don't want to merge this
can we have attention tests as well?
how is this different than direct to lds tests?
Pull request overview
This PR implements support for Async DirectToLDS in WMMA operations for gfx1250, completing the work started in #2072 which introduced basic async load support but lacked WMMA compatibility.
Key changes include:
- Enabled DirectToLDS for WMMA operations with support for multiple load widths (8, 32, 64, and 128 bits) on gfx1250
- Modified async DirectToLDS to use normal load semantics instead of gather semantics, accounting for thread-specific offsets
- Added WaitAsynccntOp generation for async operations and out-of-bounds checking for GlobalLoadToLDS
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/test/e2e/GemmAsyncDirectToLDS.toml | E2E test configuration for async DirectToLDS with multiple GEMM configurations |
| mlir/test/e2e/GemmAsyncDirectToLDS.cfg | Test filter restricting tests to gfx1250 architecture |
| mlir/test/e2e/CMakeLists.txt | Added GemmAsyncDirectToLDS to test configuration list |
| mlir/test/Dialect/Rock/async_load_to_lds.mlir | LIT test verifying async_load_to_lds operation generation |
| mlir/test/Dialect/Rock/lowering_global_load_to_lds.mlir | Added test case for OOB checks in async DirectToLDS |
| mlir/test/Dialect/Rock/async_wait_lowering.mlir | Updated to test WaitAsynccntOp lowering for gfx1250 |
| mlir/lib/Dialect/Rock/utility/loweringUtils.cpp | Extended DirectToLDS logic to support all load widths for async operations |
| mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp | Implemented WMMA LDS buffer wrapping for DirectToLDS with proper transform maps |
| mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp | Added thread-specific offset calculations for async DirectToLDS |
| mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp | Implemented OOB checks and enhanced InBoundsStoreOp to handle type width mismatches |
| mlir/lib/Dialect/Rock/Transforms/LowerRockOpsToROCDLOps.cpp | Added WaitAsynccntOp lowering path for async DirectToLDS architectures |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Extended GlobalLoadToLDSOp verification to support 8, 32, 64, and 128-bit loads |
| if (supportsAsyncDirectToLDS) { | ||
| LLVM_DEBUG(llvm::dbgs() | ||
| << "AsyncWaitOpConversion: arch supports AsyncDirectToLDS\n"); | ||
| unsigned asyncCnt = std::min(63u, 0u); |
Copilot
AI
Jan 7, 2026
The asyncCnt calculation uses std::min(63u, 0u) which always returns 0, ignoring the operation's numInst attribute. This should be std::min(63u, op.getNumInst()) to properly clamp the instruction count, similar to the vmCnt calculation in the else branch.
Suggested change:
- unsigned asyncCnt = std::min(63u, 0u);
+ unsigned asyncCnt = std::min(63u, op.getNumInst());
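The effect of the clamp can be checked in isolation with a small standalone sketch. `kMaxWaitCount` and `clampWaitCount` are hypothetical names; the limit of 63 is taken from the literal in the snippet, and whether that limit also applies to the async counter is still an open question in this review.

```cpp
#include <algorithm>

// Assumed upper bound, copied from the literal 63u in the reviewed snippet.
// Whether WaitAsynccntOp shares this limit is not confirmed here.
constexpr unsigned kMaxWaitCount = 63u;

// Hypothetical helper: clamp a requested instruction count to the assumed
// maximum so it fits the counter field.
constexpr unsigned clampWaitCount(unsigned numInst) {
  return std::min(kMaxWaitCount, numInst);
}
```

By contrast, `std::min(63u, 0u)` is always 0 regardless of the requested count, which is the behavior the comment flags. Waiting for the counter to reach 0 is conservative, so it may be intentional, but it ignores the op's attribute.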
| "Source element type is larger than destination, but not a " | ||
| "multiple of the destination element type"); |
Copilot
AI
Jan 7, 2026
The error message states "Source element type is larger than destination" but this error is triggered in the else clause that handles cases where either srcBits <= destBits OR srcBits is not a multiple of destBits. The message should be more accurate to cover all these cases, for example: "Source and destination element types have incompatible bit widths (source must equal destination or be a multiple of destination)".
Suggested change:
- "Source element type is larger than destination, but not a "
- "multiple of the destination element type");
+ "Source and destination element types have incompatible bit "
+ "widths (source bit width must equal destination bit width or be "
+ "an integer multiple of it)");
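The condition described by the suggested message can be written down as a small standalone check. This is only a sketch of the rule as worded in the message; `destElemsPerSource` is a hypothetical name and does not reflect how SugarToLoops.cpp structures its branches.

```cpp
#include <cstdint>

// Hypothetical helper: number of destination elements covered by one source
// element, or 0 when the bit widths are incompatible. Widths are compatible
// when the source width equals the destination width or is an integer
// multiple of it.
constexpr int64_t destElemsPerSource(int64_t srcBits, int64_t destBits) {
  return (srcBits > 0 && destBits > 0 && srcBits >= destBits &&
          srcBits % destBits == 0)
             ? srcBits / destBits
             : 0;
}
```

For example, an f64 source over an f16 destination covers four elements, equal widths cover one, and a narrower or non-multiple source is rejected.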
| // handle both cases. | ||
| kVec = kPack; | ||
| kPerBlock *= kPack; | ||
| assert(!rotateDWithK && "rotateDWithK must not be enabled for directToLds"); |
Copilot
AI
Jan 7, 2026
Inconsistent capitalization: "directToLds" should be "directToLDS" to match the casing used throughout the codebase (e.g., "DirectToLDS", "directToLDS", "asyncDirectToLDS").
Suggested change:
- assert(!rotateDWithK && "rotateDWithK must not be enabled for directToLds");
+ assert(!rotateDWithK && "rotateDWithK must not be enabled for directToLDS");
Motivation
In #2072 we introduced support for async DirectToLDS on gfx1250. However, DirectToLDS is not supported for WMMA, so gfx1250 would still be unable to use async DirectToLDS.
This PR implements it, plus other required features, to give full support for async DirectToLDS on gfx1250.
Technical Details
This PR adds the following:
- WaitAsynccntOp when lowering rock::AsyncWaitOp. Similarly to how SWaitcntOp is needed for traditional DirectToLDS, WaitAsynccntOp is required to wait for async load ops.
- OOB checks (emitOobChecks) in GlobalLoadToLDS lowering in SugarToLoops. The idea is to follow GlobalLoad lowering; however, it is not easy to do so.
More details about the emitOobChecks support
- It is not easy to reuse GlobalLoad lowering in GlobalLoadToLDS lowering because the ops have one radical difference: the former returns an SSA value with the result, whereas the latter does not. We might want to clean this up in the future.
- In the else condition that we generate when out-of-bounds checks are needed, we have to store zeros into LDS. However, GlobalLoadToLDS has transferType, meaning that we might need to write multiple elements (e.g., if transferType is f64 but the LDS buffer is f16). My approach is to use InBoundsStoreOp and add support to that op for writing multiple elements. Note that we might want to separate this into another PR.
Test Plan
- E2E tests (GemmAsyncDirectToLDS).
- Tests for rock::AsyncWaitOp lowering.
- lowering_global_load_to_lds extended to exercise the case of OOB checks.
Test Result
All new E2E tests pass on the emulator.