-
Notifications
You must be signed in to change notification settings - Fork 52
MIGraphX support for Flash Decoding + KVCache #2174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
|
Note: There is still some testing with MIGraphX that is going on (need the second kernel that isn't exposed to rocMLIR). I will convert this to a regular PR when those results come back clean |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR implements Flash Decoding with KVCache support for MIGraphX, enabling the combination of these two optimizations for efficient attention operations. The implementation also includes support for prefix causal attention when used with Flash Decoding and KVCache.
Key changes:
- Generalized splitKV detection to work with K and V tensors (previously only V)
- Added support for 4D and 5D attention layouts with splitKV dimension handling
- Implemented transformation logic for currentSeqLen and prefixOffset tensors to slice away the splitKV dimension
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/test/fusion/pr-e2e/attention/mixr-attention-flash-decoding-kvcache.mlir | End-to-end test for Flash Decoding with KVCache combination |
| mlir/test/fusion/pr-e2e/attention/mixr-attention-flash-decoding-kvcache-prefix-causal.mlir | End-to-end test for Flash Decoding with KVCache and prefix causal attention |
| mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-flash-decoding-kvcache.mlir | TOSA to Rock conversion test for Flash Decoding with KVCache |
| mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-flash-decoding-kvcache-prefix-causal.mlir | TOSA to Rock conversion test with prefix causal attention support |
| mlir/test/Conversion/DetectFlashDecoding/detect-flash-decoding-kvcache.mlir | Flash Decoding detection pass test for KVCache scenario |
| mlir/test/Conversion/DetectFlashDecoding/detect-flash-decoding-kvcache-prefix-causal.mlir | Flash Decoding detection pass test with prefix causal attention |
| mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp | Core detection logic: renamed detectSplitKVFromV to detectSplitKVFromKV, added sliceSplitKVFromBatch helper, extended detection to K tensor, and added currentSeqLen/prefixOffset handling |
| mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | Updated addBroadcastForBlockArg to handle both 4D and 5D attention layouts with 2 or 3 collapsed dimensions |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
mlir/test/fusion/pr-e2e/attention/mixr-attention-flash-decoding-kvcache.mlir
Outdated
Show resolved
Hide resolved
mlir/test/fusion/pr-e2e/attention/mixr-attention-flash-decoding-kvcache-prefix-causal.mlir
Outdated
Show resolved
Hide resolved
| // 2D case: [batch, numHeads] for 4D attention layout | ||
| // 3D case: [batch, numHeads, splitKV] for 5D attention layout | ||
| if (reassocIndices.empty() || | ||
| (reassocIndices[0].size() != 2 && reassocIndices[0].size() != 3)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we have a flag for flash decoding? could we check it's 3 only if flash decoding is enabled?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a MIGraphX environment variable, but there is no flag that gets passed to rocMLIR as of yet. Is this something that we want to coordinate with MIGraphX for? It could allow us to do some additional safety checks to make sure that we are actually pattern matching flash decoding.
mlir/test/Conversion/DetectFlashDecoding/detect-flash-decoding-kvcache-prefix-causal.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Conversion/DetectFlashDecoding/detect-flash-decoding-kvcache.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Conversion/DetectFlashDecoding/detect-flash-decoding-feature-combination.mlir
Show resolved
Hide resolved
mlir/test/fusion/pr-e2e/attention/mixr-attention-flash-decoding-kvcache-prefix-causal.mlir
Outdated
Show resolved
Hide resolved
.../test/Conversion/TosaToRock/tosa-to-rock-attention-flash-decoding-kvcache-prefix-causal.mlir
Show resolved
Hide resolved
mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-flash-decoding-kvcache.mlir
Show resolved
Hide resolved
mlir/test/fusion/pr-e2e/attention/mixr-attention-flash-decoding-kvcache.mlir
Outdated
Show resolved
Hide resolved
mlir/test/fusion/pr-e2e/attention/mixr-attention-flash-decoding-kvcache.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-flash-decoding-kvcache.mlir
Outdated
Show resolved
Hide resolved
.../test/Conversion/TosaToRock/tosa-to-rock-attention-flash-decoding-kvcache-prefix-causal.mlir
Outdated
Show resolved
Hide resolved
fe3e70e to
5878fd6
Compare
Motivation
This PR implements the lowering changes required for MIGraphX to use Flash Decoding + KVCache together. Additionally, since KVCache and prefix causal attention are quite similar, I opted to also include the support for that combination as well.
This implements: https://github.com/ROCm/rocMLIR-internal/issues/2183
Technical Details
Changes to TosaToRock:
addBroadcastForBlockArgto handle additional 4D layoutsChanges to DetectFlashDecoding:
detectSplitKVFromV->detectSplitKVFromKVsliceSplitKVFromBatchthat removes the splitKV dimension by slicing at index 0DetectFlashDecodingPattern:Test Plan
Test Result
Submission Checklist