-
Notifications
You must be signed in to change notification settings - Fork 52
Add squeeze, gather, and scatter_none ops to MIGraphX dialect #2176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
|
Note for reviewers, this PR is operating under the assumption that we will have the upstream changes in soon so that we can have this change: llvm/llvm-project#167894 (adding i64 type support for indices in scatter/gather). To be conservative, I can keep the conversion from i64 -> i32 and just add a TODO and file a ticket to make that simple change once the upstream PR makes its way in. Let me know what you guys want. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds three new operations to the MIGraphX dialect as part of the paged attention work: squeeze, gather, and scatter_none. These ops are based on ONNX specifications and include complete implementations with verifiers, TOSA lowering patterns, and comprehensive test coverage.
- Implements ONNX-compatible squeeze (removing size-1 dimensions), gather (collecting slices along an axis), and scatter_none (updating elements at specified indices)
- Adds verifiers for axis bounds checking, shape validation, and index constraints
- Provides TOSA lowering implementations that reshape tensors to match TOSA's 3D gather/scatter semantics
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td | Adds TableGen definitions for squeeze, gather, and scatter_none ops with documentation and assembly formats |
| mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp | Implements verify() methods for the three new ops including axis bounds checking and shape validation |
| mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp | Adds converter patterns to lower each new op to TOSA operations via reshape/gather/scatter transformations |
| mlir/test/Dialect/MIGraphX/invalid.mlir | Adds negative test cases for axis out of bounds, rank mismatches, and invalid indices |
| mlir/test/Conversion/MIGraphXToTosa/migraphx-to-tosa.mlir | Adds positive conversion tests for squeeze, gather, and scatter_none with various axis configurations |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| }]; | ||
| } | ||
|
|
||
| def MIGraphX_SqueezeOp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You dont need to add squeeze and unsqueeze as we already lower this to reshape when going to mlir.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The IR that I was basing this off of (which came from MIGraphX) was using the unlowered squeeze/unsqueeze. I think it's a fairly trivial lowering, so it wouldn't hurt if rocMLIR could also do something like this.
|
Outside of paged attention, will this enable us to fuse gather and scatter into gemms and convolutions? |
This PR in of itself won't get us to fusing gather/scatter operations. We would still need an additional backend lowering PR for this. |
| let hasVerifier = 1; | ||
| } | ||
|
|
||
| def MIGraphX_ScatterNoneOp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is it called ScatterNone?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MIGraphX was passing us a scatter_none op which directly correlates to the ScatterElements ONNX op with reduction set to none: https://onnx.ai/onnx/operators/onnx__ScatterElements.html
| } | ||
| } | ||
|
|
||
| // Build the new shape by excluding the squeezed axes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we check that axesToSqueeze are all dimension=1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The verifier already does this validation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or is that not enough? Since if I recall correctly the verifier isn't turned on by default?
| auto outputType = cast<RankedTensorType>(outputTy); | ||
| Type elemType = dataType.getElementType(); | ||
|
|
||
| // Lowering strategy for migraphx.gather -> tosa.gather: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not skip tosa instead of this workaround?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking about needing a lowering path through TOSA for the CPU path in the future (there currently isn't a lowering path that goes past Tosa). I've already opened up a ticket to address this: https://github.com/ROCm/rocMLIR-internal/issues/2205. It looks like IREE has already encountered this before and they have a custom lowering that maybe we could port/use.
| // Output: [N, W, C] where each [n, w, :] = reshaped_data[n, indices[n, w], :] | ||
| SmallVector<int64_t> gatherOutputShape = {N, W, C}; | ||
| auto gatherOutputType = RankedTensorType::get(gatherOutputShape, elemType); | ||
| Value gatherResult = tosa::GatherOp::create(rewriter, loc, gatherOutputType, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand we are missing tosa.gather -> rock? are we going to support this for paged attention for now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tosa.gather -> rock was going to come in a future PR that handled the TosaToRock changes. However, in recent conversations regarding paged attention, it seems like we are moving away from the scatter/gather implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need this PR then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pfultz2 Are there any plans for MIGraphX to require rocMLIR supporting gather/scatter outside of paged attention (think I might have heard something in one of the meetings yesterday)? If not, then I think we can close this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fusing gather for gemms and convolutions will be useful as well, especially for resize.
| if (axis < 0) | ||
| axis += dataRank; | ||
|
|
||
| // TOSA scatter requires that indices be constant across the "C" dimension |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here, why not skip tosa?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comment above about CPU lowering.
37703a4 to
a4462ca
Compare
Motivation
This PR adds three new ops to the MIGraphX dialect (squeeze, gather, and scatter_none). These new ops are required as part of the paged attention work.
This implements https://github.com/ROCm/rocMLIR-internal/issues/2200
Technical Details
This PR can be broken down into the following changes:
MIGraphX.td:migraphx.squeeze(https://onnx.ai/onnx/operators/onnx__Squeeze.html): Remove dimensions of size 1migraphx.gather(https://onnx.ai/onnx/operators/onnx__Gather.html): Gather slices from data along an axismigraphx.scatter_none(https://onnx.ai/onnx/operators/onnx__ScatterElements.html): Updates into data at specified indicesMIGraphX.cpp: Add verifiers for each of the new opsMIGraphXToTosa: Lowers each of the new ops to their set of TOSA opsTest Plan
Test Result
Submission Checklist