Skip to content

Conversation

@pabloantoniom
Copy link
Contributor

@pabloantoniom pabloantoniom commented Dec 22, 2025

Motivation

In #2072 we introduced support for async DirectToLDS for gfx1250. However, DirectToLDS is not supported for WMMA, so gfx1250 would still be unable to use async DirectToLDS.

This PR implements it, plus other required features, in order to give full support for async DirectToLDS in gfx1250.

Technical Details

This PR adds the following:

  • Support for DirectToLDS for WMMA (gfx1250 is the only GPU available that can exploit it).
  • Changed the way Async DirectToLDS loads from memory: Previously we assumed it behaved like a gather (i.e., like traditional DirectToLDS), but it actually works like a normal load, so we must take into account the thread ID when computing the destination indices for the op.
  • Generate WaitAsynccntOp when lowering rock::AsyncWaitOp. Similarly to how SWaitcntOp is needed for traditional DirectToLDS, WaitAsynccntOp is required to wait for async load ops.
  • Added support for out-of-bounds checks (emitOobChecks) in GlobalLoadToLDS lowering in SugarToLoops. The idea is to follow GlobalLoad lowering, however is not easy to do so.

More details about the emitOobChecks support

  1. We cannot reuse code from GlobalLoad lowering in GlobalLoadToLDS lowering because the ops have one radical difference: the former returns an SSA value with the result, whereas the latter does not. We might want to cleanup this in the future.
  2. In the else condition that we generate when out-of-bound checks are needed, we have to store zeros into LDS. However, GlobalLoadToLDS has transferType, meaning that we might need to write multiple elements (e.g., if transferType is f64 but LDS buffer is f16). My approach is to use InBoundsStoreOp and add support to such op to write multiple elements. Note that we might want to separate this into another PR.

Test Plan

  • Added E2E test (GemmAsyncDirectToLDS).
  • Adapted LIT tests due to change in rock::AsyncWaitOp lowering.
  • Added LIT test in lowering_global_load_to_lds to exercise the case of OOB checks.

Test Result

All new E2E test pass on the emulator.

Submission Checklist

int64_t mPerWave = tuningParams.getMPerWave();
int64_t nPerWave = tuningParams.getNPerWave();
int64_t kPack = tuningParams.getKpack();
// TODO: gfx10 supports directToLDS. Implement it.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for reviewers: This comment is wrong. gfx10 supports directoLDS, but it does not support WMMA, so the comment does not make sense here.

// instruction
ldsIndex = arith::AddIOp::create(b, loc, ldsIndex, ldsIndexWave);

if (isAsyncDirectToLDSSupported(maybeArch.value())) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for reviewers: Is this the right place to implement this?

@@ -1,21 +1,27 @@
// RUN: rocmlir-opt %s --rock-to-rocdl
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😅

@pabloantoniom pabloantoniom marked this pull request as ready for review December 22, 2025 15:35
} else {
b.replaceOpWithNewOp<memref::StoreOp>(op, op.getData(), op.getDest(),
op.getCoords());
Location loc = op.getLoc();
Copy link
Contributor Author

@pabloantoniom pabloantoniom Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for reviewers: This is needed because of the changes in GlobalLoadToLDSRewritePattern. It is extending the InBoundsStore capabilities, but it's not directly related to this PR, so maybe we want to move it into a separate PR to have a cleaner git history.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I'd prefer this to be an independent PR.

config = "-g 1 -m 512 -k 1 -n 512"

[[suite.test]]
config = "-g 1 -m 512 -k 32 -n 512"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for reviwers: DirecToLDS counterpart also has:

[[suite.test]]
config = "-g 3 -m 1024 -k 768 -n 1024"

but on emulator this takes too long (>30m) so I removed it.

b.setInsertionPointToEnd(elseBlock);
Type transferType = op.getTransferType();
Value zeroValue = createZeroConstantOp(b, loc, transferType);
InBoundsStoreOp::create(b, loc, zeroValue, dest, destCoords);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for reviewers: Here I use InBoundsStoreOp to store zeros. Let's assume a case like the following: transferType=f64 and LDS buffer is f16. We would need to write 4 x f16 of zeros. But what if actually the first 32 bits are in bounds and the last 32 are out-of-bounds, meaning that we should actually read 32bits from LDS and the set to zero the remaining 32 bits? This is not handled here. Can that happen? How should it be handled?

@pabloantoniom pabloantoniom changed the title [WIP] (Async)DirectToLDS support in WMMA (Async)DirectToLDS support in WMMA Dec 23, 2025
<< "AsyncWaitOpConversion: arch supports AsyncDirectToLDS\n");
unsigned asyncCnt = std::min(63u, 0u);
ROCDL::WaitAsynccntOp::create(rewriter, loc, asyncCnt);
ROCDL::SBarrierOp::create(rewriter, loc);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need sbarrier here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to get rid of the barriers that RockPipeline will introduce for gfx1250 as well (I think WaitAsynccntOp is enough). Is there a ticket to do that?

if (supportsAsyncDirectToLDS) {
LLVM_DEBUG(llvm::dbgs()
<< "AsyncWaitOpConversion: arch supports AsyncDirectToLDS\n");
unsigned asyncCnt = std::min(63u, 0u);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the max 63 for WaitAsynccntOp as well?

coords.push_back(b.createOrFold<ConstantIndexOp>(loc, 0));
}

Type originalLoadedType = op.getTransferType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be another PR? it's not related to gfx1250, right?

} else {
b.replaceOpWithNewOp<memref::StoreOp>(op, op.getData(), op.getDest(),
op.getCoords());
Location loc = op.getLoc();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I'd prefer this to be an independent PR.

return emitError(loc)
<< "128 bits direct to LDS is not supported by the hardware";
}
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this else (and the above if) should only run for direct to lds, not async.

// the same output index), async DirectToLDS actually works like a
// traditional load, so we must take into account the thread-specific
// offset here.
if (loadTypeByteWidth == 16) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of this, it'd be better to have a transform for the output tensor as well. So, instead of having linear indexing of the output we could have anything.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that would allow us to store in LDS with [kpackperblock, dperblock, kpack] layout.

if not config.arch.startswith("gfx1250"):
config.unsupported = True

# This is useful when running on the emulator, to propagate the environment variables
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably don't want to merge this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have attention tests as well?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is this different than direct to lds tests?

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR implements support for Async DirectToLDS in WMMA operations for gfx1250, completing the work started in #2072 which introduced basic async load support but lacked WMMA compatibility.

Key changes include:

  • Enabled DirectToLDS for WMMA operations with support for multiple load widths (8, 32, 64, and 128 bits) on gfx1250
  • Modified async DirectToLDS to use normal load semantics instead of gather semantics, accounting for thread-specific offsets
  • Added WaitAsynccntOp generation for async operations and out-of-bounds checking for GlobalLoadToLDS

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
mlir/test/e2e/GemmAsyncDirectToLDS.toml E2E test configuration for async DirectToLDS with multiple GEMM configurations
mlir/test/e2e/GemmAsyncDirectToLDS.cfg Test filter restricting tests to gfx1250 architecture
mlir/test/e2e/CMakeLists.txt Added GemmAsyncDirectToLDS to test configuration list
mlir/test/Dialect/Rock/async_load_to_lds.mlir LIT test verifying async_load_to_lds operation generation
mlir/test/Dialect/Rock/lowering_global_load_to_lds.mlir Added test case for OOB checks in async DirectToLDS
mlir/test/Dialect/Rock/async_wait_lowering.mlir Updated to test WaitAsynccntOp lowering for gfx1250
mlir/lib/Dialect/Rock/utility/loweringUtils.cpp Extended DirectToLDS logic to support all load widths for async operations
mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp Implemented WMMA LDS buffer wrapping for DirectToLDS with proper transform maps
mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp Added thread-specific offset calculations for async DirectToLDS
mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp Implemented OOB checks and enhanced InBoundsStoreOp to handle type width mismatches
mlir/lib/Dialect/Rock/Transforms/LowerRockOpsToROCDLOps.cpp Added WaitAsynccntOp lowering path for async DirectToLDS architectures
mlir/lib/Dialect/Rock/IR/RockDialect.cpp Extended GlobalLoadToLDSOp verification to support 8, 32, 64, and 128-bit loads

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

if (supportsAsyncDirectToLDS) {
LLVM_DEBUG(llvm::dbgs()
<< "AsyncWaitOpConversion: arch supports AsyncDirectToLDS\n");
unsigned asyncCnt = std::min(63u, 0u);
Copy link

Copilot AI Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The asyncCnt calculation uses std::min(63u, 0u) which always returns 0, ignoring the operation's numInst attribute. This should be std::min(63u, op.getNumInst()) to properly clamp the instruction count, similar to the vmCnt calculation in the else branch.

Suggested change
unsigned asyncCnt = std::min(63u, 0u);
unsigned asyncCnt = std::min(63u, op.getNumInst());

Copilot uses AI. Check for mistakes.
Comment on lines +1751 to +1752
"Source element type is larger than destination, but not a "
"multiple of the destination element type");
Copy link

Copilot AI Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message states "Source element type is larger than destination" but this error is triggered in the else clause that handles cases where either srcBits <= destBits OR srcBits is not a multiple of destBits. The message should be more accurate to cover all these cases, for example: "Source and destination element types have incompatible bit widths (source must equal destination or be a multiple of destination)".

Suggested change
"Source element type is larger than destination, but not a "
"multiple of the destination element type");
"Source and destination element types have incompatible bit "
"widths (source bit width must equal destination bit width or be "
"an integer multiple of it)");

Copilot uses AI. Check for mistakes.
// handle both cases.
kVec = kPack;
kPerBlock *= kPack;
assert(!rotateDWithK && "rotateDWithK must not be enabled for directToLds");
Copy link

Copilot AI Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent capitalization: "directToLds" should be "directToLDS" to match the casing used throughout the codebase (e.g., "DirectToLDS", "directToLDS", "asyncDirectToLDS").

Suggested change
assert(!rotateDWithK && "rotateDWithK must not be enabled for directToLds");
assert(!rotateDWithK && "rotateDWithK must not be enabled for directToLDS");

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants