Contributor

@justinrosner justinrosner commented Dec 16, 2025

Motivation

This PR implements the lowering changes required for MIGraphX to use Flash Decoding and KVCache together. Since KVCache and prefix causal attention are quite similar, I also included support for that combination.

This implements: https://github.com/ROCm/rocMLIR-internal/issues/2183

Technical Details

Changes to TosaToRock:

  • Update addBroadcastForBlockArg to handle additional 4D layouts

Changes to DetectFlashDecoding:

  • Renamed and generalized detectSplitKVFromV -> detectSplitKVFromKV
  • New helper sliceSplitKVFromBatch that removes the splitKV dimension by slicing at index 0
  • Updated detection logic in DetectFlashDecodingPattern:
    • Now detects splitKV from all three tensors and requires that one of K or V match with Q
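One way to read the last rule, as a minimal sketch only: each tensor yields an optional splitKV factor, and the pattern is accepted when Q has a factor and K or V reports the same one. The function name `agreeOnSplitKV` and the use of `std::optional` are assumptions for illustration; the actual pass operates on MLIR values and may differ.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical helper, not the code from this PR. Each argument is the
// splitKV factor inferred from one tensor, or std::nullopt if none was found.
// Returns the agreed factor when Q has one and K or V reports the same value.
std::optional<int64_t> agreeOnSplitKV(std::optional<int64_t> fromQ,
                                      std::optional<int64_t> fromK,
                                      std::optional<int64_t> fromV) {
  if (!fromQ)
    return std::nullopt; // Q must carry a splitKV factor.
  if (fromK == fromQ || fromV == fromQ)
    return fromQ; // K or V agrees with Q.
  return std::nullopt; // Neither K nor V matches Q.
}
```

For example, `agreeOnSplitKV(4, std::nullopt, 4)` yields 4 because V agrees with Q, while `agreeOnSplitKV(4, 2, 8)` yields no value.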

Test Plan

  • E2E example from MIGraphX is working as intended
  • PR CI

Test Result

  • E2E MIGraphX example
  • PR CI

Submission Checklist

@justinrosner
Contributor Author

Note: Some MIGraphX testing is still in progress (it needs the second kernel, which isn't exposed to rocMLIR). I will convert this to a regular PR once those results come back clean.

@justinrosner justinrosner changed the title from "[WIP] MIGraphX support for Flash Decoding + KVCache" to "MIGraphX support for Flash Decoding + KVCache" Dec 17, 2025
@justinrosner justinrosner marked this pull request as ready for review December 17, 2025 16:11
Contributor

Copilot AI left a comment

Pull request overview

This PR implements Flash Decoding with KVCache support for MIGraphX, enabling the combination of these two optimizations for efficient attention operations. The implementation also includes support for prefix causal attention when used with Flash Decoding and KVCache.

Key changes:

  • Generalized splitKV detection to work with K and V tensors (previously only V)
  • Added support for 4D and 5D attention layouts with splitKV dimension handling
  • Implemented transformation logic for currentSeqLen and prefixOffset tensors to slice away the splitKV dimension
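Slicing away the splitKV dimension can be pictured with a plain C++ sketch. It assumes a row-major tensor whose innermost dimension is splitKV and whose values are the same across splits, so index 0 is representative. The name `sliceSplitKVAtZero` and the flat `std::vector` layout are illustrative assumptions, not the MLIR transformation in this PR.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical illustration. `data` is a row-major tensor of shape
// [outer, splitKV]. Keeps only the element at splitKV index 0 of each row,
// producing a tensor of shape [outer].
std::vector<int32_t> sliceSplitKVAtZero(const std::vector<int32_t> &data,
                                        int64_t outer, int64_t splitKV) {
  assert(splitKV > 0 &&
         static_cast<int64_t>(data.size()) == outer * splitKV);
  std::vector<int32_t> result;
  result.reserve(outer);
  for (int64_t i = 0; i < outer; ++i)
    result.push_back(data[i * splitKV]); // element at splitKV index 0
  return result;
}
```

With `outer = 2` and `splitKV = 3`, the input `{7, 7, 7, 9, 9, 9}` becomes `{7, 9}`.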

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| mlir/test/fusion/pr-e2e/attention/mixr-attention-flash-decoding-kvcache.mlir | End-to-end test for Flash Decoding with KVCache combination |
| mlir/test/fusion/pr-e2e/attention/mixr-attention-flash-decoding-kvcache-prefix-causal.mlir | End-to-end test for Flash Decoding with KVCache and prefix causal attention |
| mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-flash-decoding-kvcache.mlir | TOSA to Rock conversion test for Flash Decoding with KVCache |
| mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-flash-decoding-kvcache-prefix-causal.mlir | TOSA to Rock conversion test with prefix causal attention support |
| mlir/test/Conversion/DetectFlashDecoding/detect-flash-decoding-kvcache.mlir | Flash Decoding detection pass test for KVCache scenario |
| mlir/test/Conversion/DetectFlashDecoding/detect-flash-decoding-kvcache-prefix-causal.mlir | Flash Decoding detection pass test with prefix causal attention |
| mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp | Core detection logic: renamed detectSplitKVFromV to detectSplitKVFromKV, added sliceSplitKVFromBatch helper, extended detection to K tensor, and added currentSeqLen/prefixOffset handling |
| mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | Updated addBroadcastForBlockArg to handle both 4D and 5D attention layouts with 2 or 3 collapsed dimensions |

// 2D case: [batch, numHeads] for 4D attention layout
// 3D case: [batch, numHeads, splitKV] for 5D attention layout
if (reassocIndices.empty() ||
    (reassocIndices[0].size() != 2 && reassocIndices[0].size() != 3))
Contributor

Do we have a flag for flash decoding? Could we allow a size of 3 only when flash decoding is enabled?

Contributor Author

There is a MIGraphX environment variable, but there is no flag that gets passed to rocMLIR as of yet. Is this something that we want to coordinate with MIGraphX for? It could allow us to do some additional safety checks to make sure that we are actually pattern matching flash decoding.
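If such a flag were plumbed through, the stricter check discussed here might look roughly like the sketch below. `flashDecodingEnabled` is hypothetical (rocMLIR receives no such flag today), and `isSupportedCollapse` is an illustrative name rather than a function in TosaToRock.cpp.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch of the gated check. `flashDecodingEnabled` stands in
// for a flag that is not currently passed to rocMLIR.
bool isSupportedCollapse(
    const std::vector<std::vector<int64_t>> &reassocIndices,
    bool flashDecodingEnabled) {
  if (reassocIndices.empty())
    return false;
  const std::size_t collapsed = reassocIndices[0].size();
  // 2D case: [batch, numHeads] for the 4D attention layout.
  if (collapsed == 2)
    return true;
  // 3D case: [batch, numHeads, splitKV] for the 5D attention layout,
  // accepted only when flash decoding is known to be enabled.
  return collapsed == 3 && flashDecodingEnabled;
}
```

Without the flag, the check in this PR has to accept both sizes unconditionally, which is the trade-off raised above.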

@justinrosner justinrosner force-pushed the 2183-flash-decoding-kvcache branch from fe3e70e to 5878fd6 Compare January 12, 2026 20:53
3 participants