
Conversation

@justinrosner (Contributor) commented Dec 17, 2025

Motivation

This PR adds three new ops to the MIGraphX dialect: squeeze, gather, and scatter_none. These ops are required as part of the paged attention work.

This implements https://github.com/ROCm/rocMLIR-internal/issues/2200

Technical Details

This PR can be broken down into the following changes:

Test Plan

  • Add new LIT tests for each of the new ops that were added

Test Result

  • PR CI

Submission Checklist

@justinrosner (Contributor, Author) commented Dec 17, 2025

Note for reviewers: this PR operates under the assumption that the upstream change llvm/llvm-project#167894 (adding i64 type support for indices in scatter/gather) will land soon. To be conservative, I can keep the i64 -> i32 conversion, add a TODO, and file a ticket to remove it once the upstream PR makes its way in. Let me know what you prefer.

Copilot AI (Contributor) left a comment

Pull request overview

This PR adds three new operations to the MIGraphX dialect as part of the paged attention work: squeeze, gather, and scatter_none. These ops are based on ONNX specifications and include complete implementations with verifiers, TOSA lowering patterns, and comprehensive test coverage.

  • Implements ONNX-compatible squeeze (removing size-1 dimensions), gather (collecting slices along an axis), and scatter_none (updating elements at specified indices)
  • Adds verifiers for axis bounds checking, shape validation, and index constraints
  • Provides TOSA lowering implementations that reshape tensors to match TOSA's 3D gather/scatter semantics
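For reference, the ONNX-style semantics of the three ops can be illustrated with NumPy (an illustrative sketch only, using np.squeeze, np.take, and np.put_along_axis as stand-ins; none of this is the dialect implementation itself):

```python
import numpy as np

# squeeze: remove size-1 dimensions at the given axes.
x = np.zeros((1, 3, 1, 5))
assert np.squeeze(x, axis=(0, 2)).shape == (3, 5)

# gather: collect slices along an axis using an index tensor.
data = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])
assert np.take(data, idx, axis=0).tolist() == [[8, 9, 10, 11], [0, 1, 2, 3]]

# scatter_none (ONNX ScatterElements, reduction="none"):
# for axis=0, out[indices[i][j]][j] = updates[i][j].
out = np.zeros((3, 3), dtype=int)
indices = np.array([[0, 1, 2]])
updates = np.array([[7, 8, 9]])
np.put_along_axis(out, indices, updates, axis=0)
assert out.tolist() == [[7, 0, 0], [0, 8, 0], [0, 0, 9]]
```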

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.

Summary per file:

  • mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td: Adds TableGen definitions for squeeze, gather, and scatter_none ops with documentation and assembly formats
  • mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp: Implements verify() methods for the three new ops, including axis bounds checking and shape validation
  • mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp: Adds converter patterns to lower each new op to TOSA operations via reshape/gather/scatter transformations
  • mlir/test/Dialect/MIGraphX/invalid.mlir: Adds negative test cases for axis out of bounds, rank mismatches, and invalid indices
  • mlir/test/Conversion/MIGraphXToTosa/migraphx-to-tosa.mlir: Adds positive conversion tests for squeeze, gather, and scatter_none with various axis configurations


@justinrosner justinrosner changed the title [WIP] Add squeeze, gather, and scatter_none ops to MIGraphX dialect Add squeeze, gather, and scatter_none ops to MIGraphX dialect Jan 2, 2026
}];
}

def MIGraphX_SqueezeOp

You don't need to add squeeze and unsqueeze, as we already lower these to reshape when going to MLIR.

Contributor Author

The IR I was basing this on (which came from MIGraphX) used the unlowered squeeze/unsqueeze. It's a fairly trivial lowering, so it wouldn't hurt for rocMLIR to support something like this as well.
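For what it's worth, the lowering really is small; a sketch in NumPy terms (squeeze_as_reshape is a hypothetical helper, with the same size-1 check the verifier performs):

```python
import numpy as np

def squeeze_as_reshape(x, axes):
    # Verifier-style check: every squeezed axis must have extent 1.
    assert all(x.shape[a] == 1 for a in axes)
    # Squeeze is then just a reshape that drops those dimensions.
    new_shape = [d for i, d in enumerate(x.shape) if i not in set(axes)]
    return x.reshape(new_shape)

x = np.ones((1, 3, 1, 5))
assert squeeze_as_reshape(x, [0, 2]).shape == (3, 5)
```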

pfultz2 commented Jan 6, 2026

Outside of paged attention, will this enable us to fuse gather and scatter into gemms and convolutions?

@justinrosner (Contributor, Author) replied:
Outside of paged attention, will this enable us to fuse gather and scatter into gemms and convolutions?

This PR by itself won't get us to fusing gather/scatter operations; we would still need an additional backend lowering PR for that.

let hasVerifier = 1;
}

def MIGraphX_ScatterNoneOp
Contributor

Why is it called ScatterNone?

Contributor Author

MIGraphX was passing us a scatter_none op, which corresponds directly to the ONNX ScatterElements op with reduction set to none: https://onnx.ai/onnx/operators/onnx__ScatterElements.html
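A sketch of those semantics for the 2-D, axis=0 case (scatter_none_axis0 is an illustrative helper, not the MIGraphX or ONNX implementation): each element of updates is written into the output at the row given by indices, in the same column.

```python
import numpy as np

def scatter_none_axis0(data, indices, updates):
    # ONNX ScatterElements with reduction="none", axis=0:
    # out[indices[i][j]][j] = updates[i][j]; later writes win on collision.
    out = data.copy()
    for i in range(indices.shape[0]):
        for j in range(indices.shape[1]):
            out[indices[i, j], j] = updates[i, j]
    return out

data = np.zeros((3, 3), dtype=int)
indices = np.array([[1, 0, 2], [0, 2, 1]])
updates = np.array([[1, 2, 3], [4, 5, 6]])
result = scatter_none_axis0(data, indices, updates)
assert result.tolist() == [[4, 2, 0], [1, 0, 6], [0, 5, 3]]
```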

}
}

// Build the new shape by excluding the squeezed axes
Contributor

Should we check that axesToSqueeze are all of dimension 1?

Contributor Author

The verifier already does this validation.

Contributor Author

Or is that not enough? If I recall correctly, the verifier isn't turned on by default.

auto outputType = cast<RankedTensorType>(outputTy);
Type elemType = dataType.getElementType();

// Lowering strategy for migraphx.gather -> tosa.gather:
Contributor

Why not skip TOSA instead of this workaround?

Contributor Author

I was thinking we would need a lowering path through TOSA for the CPU path in the future (there currently isn't a lowering path that goes past TOSA). I've already opened a ticket to address this: https://github.com/ROCm/rocMLIR-internal/issues/2205. It looks like IREE has encountered this before and has a custom lowering that we could perhaps port or reuse.
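For context, my reading of the reshape strategy in this lowering, sketched in NumPy (the [N, K, C] factorization follows the comments in the patch; the helper name and the per-batch tiling of the shared ONNX indices are my assumptions):

```python
import numpy as np

def gather_via_tosa_form(data, indices, axis):
    # Factor the shape around `axis` into [N, K, C]:
    # N = batch dims before axis, K = gathered dim, C = dims after axis.
    n = int(np.prod(data.shape[:axis], dtype=np.int64))
    k = data.shape[axis]
    c = int(np.prod(data.shape[axis + 1:], dtype=np.int64))
    reshaped = data.reshape(n, k, c)
    # ONNX gather shares one index set across batches; tile it to [N, W]
    # to match tosa.gather's per-batch indices.
    flat_idx = np.tile(indices.reshape(1, -1), (n, 1))
    # tosa.gather semantics: out[b, w, :] = reshaped[b, flat_idx[b, w], :]
    gathered = reshaped[np.arange(n)[:, None], flat_idx]  # [N, W, C]
    out_shape = data.shape[:axis] + indices.shape + data.shape[axis + 1:]
    return gathered.reshape(out_shape)

data = np.arange(24).reshape(2, 3, 4)
idx = np.array([2, 0])
assert np.array_equal(gather_via_tosa_form(data, idx, 1),
                      np.take(data, idx, axis=1))
```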

// Output: [N, W, C] where each [n, w, :] = reshaped_data[n, indices[n, w], :]
SmallVector<int64_t> gatherOutputShape = {N, W, C};
auto gatherOutputType = RankedTensorType::get(gatherOutputShape, elemType);
Value gatherResult = tosa::GatherOp::create(rewriter, loc, gatherOutputType,
Contributor

I understand we are missing tosa.gather -> rock? Are we going to support this for paged attention for now?

Contributor Author

The tosa.gather -> rock support was going to come in a future PR handling the TosaToRock changes. However, in recent conversations about paged attention, it seems we are moving away from the scatter/gather implementation.

Contributor

Do we need this PR, then?

Contributor Author

@pfultz2 Are there any plans for MIGraphX to require rocMLIR support for gather/scatter outside of paged attention? (I think I might have heard something in one of the meetings yesterday.) If not, then I think we can close this.


Fusing gather for gemms and convolutions will be useful as well, especially for resize.

if (axis < 0)
axis += dataRank;

// TOSA scatter requires that indices be constant across the "C" dimension
Contributor

Same here: why not skip TOSA?

Contributor Author

See comment above about CPU lowering.
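To make the constraint above concrete, here is a sketch of tosa.scatter's semantics as I read them (tosa_scatter is a hypothetical helper): it writes whole C-slices per index, so an element-wise scatter_none only maps onto it when the ONNX indices are constant across the trailing C dimension.

```python
import numpy as np

def tosa_scatter(values_in, indices, updates):
    # tosa.scatter semantics: out[n, indices[n, w], :] = updates[n, w, :]
    # Note the whole trailing C-slice is written per (n, w) pair.
    out = values_in.copy()                 # values_in: [N, K, C]
    n_dim, w_dim = indices.shape           # indices:   [N, W]
    for n in range(n_dim):
        for w in range(w_dim):
            out[n, indices[n, w], :] = updates[n, w, :]
    return out

vals = np.zeros((1, 4, 2), dtype=int)
idx = np.array([[3, 1]])
upd = np.array([[[7, 8], [5, 6]]])         # updates: [N, W, C]
res = tosa_scatter(vals, idx, upd)
assert res.tolist() == [[[0, 0], [5, 6], [0, 0], [7, 8]]]
```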

@justinrosner justinrosner force-pushed the 2200-migraphx-paged-attention-ops branch from 37703a4 to a4462ca on January 8, 2026, 18:55