[WIP] Remove redundant casts in LLVMIR #2202
base: develop
Conversation
```mlir
%7 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16>
llvm.store %7, %6 : vector<4xf16>, !llvm.ptr<5>
%8 = llvm.getelementptr %2[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16
%9 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16>
```
Why are there so many repeated `llvm.fptrunc %1` ops? I don't understand this.
Wouldn't it be easier to do this earlier, for example by handling `arith.extf`, etc.?
The `fptrunc`s seem to be the result of loop unrolling; they are all writing to the same buffer. I was doing this earlier, and there are quite a few more difficulties with moving this pass somewhere right after GridwiseGemmToBlockwise: the dominance analysis (used for safety) becomes tricky because it doesn't work well when the trunc/ext ops are in different regions, and having to rewrite the linalg generic makes things more difficult as well.
Both approaches have their pros and cons. We can discuss this more in the team meeting, or elsewhere offline.
```cpp
// - If the narrow buffer has no remaining uses, erase the fptrunc stores
//   - These can only be erased if they are not used by any other
//     operations
// - Erase the narrow alloca if it has no remaining uses
```
Do we have tests for these two cases?
```cpp
// Look for existing parallel wide store
for (Operation *wideUser : wideValue.getUsers()) {
  auto wideStore = dyn_cast<StoreOp>(wideUser);
```
Should we check that it's the source instead of the destination?
```cpp
info.wideStore = nullptr;
```

```cpp
// Look for existing parallel wide store
for (Operation *wideUser : wideValue.getUsers()) {
```
Nit: could this be done outside of this loop, storing the results in a SmallVector?
```cpp
                        << wideStore << "\n");
info.wideBuffer = wideBuffer;
info.wideStore = wideStore;
break;
```
Why do we store only the first one?
Motivation
When processing mixed-precision computations (e.g., attention kernels with f32 intermediate values stored as f16), the generated IR often contains redundant precision conversion patterns:
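For reference, a minimal sketch of the redundant pattern in LLVM-dialect MLIR (SSA names and the address space are illustrative, not taken from a real kernel):

```mlir
// Truncate a wide f32 value to f16, store it, then immediately reload
// it and extend it back to f32.
%t = llvm.fptrunc %wide : vector<4xf32> to vector<4xf16>
llvm.store %t, %buf : vector<4xf16>, !llvm.ptr<5>
%l = llvm.load %buf : !llvm.ptr<5> -> vector<4xf16>
%e = llvm.fpext %l : vector<4xf16> to vector<4xf32>
```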
This pattern causes unnecessary precision loss compared to just keeping the original wide value. This pass eliminates these redundant casts by redirecting loads to read from a parallel wide buffer when possible.
This implements: https://github.com/ROCm/rocMLIR-internal/issues/1932
Technical Details
This PR introduces the `RemoveRedundantCasts` pass, which operates at the LLVMIR dialect level to optimize fptrunc -> store -> load -> fpext patterns.
General Algorithm:
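The algorithm steps are elided in the description. Piecing together the code quoted in the review threads above, a hedged before/after sketch of the intended transformation (buffer and value names are hypothetical):

```mlir
// Before: %wide is truncated and round-tripped through a narrow f16
// buffer, even though a parallel store of the wide value already exists.
%t = llvm.fptrunc %wide : vector<4xf32> to vector<4xf16>
llvm.store %t, %narrowBuf : vector<4xf16>, !llvm.ptr<5>
llvm.store %wide, %wideBuf : vector<4xf32>, !llvm.ptr<5>
%l = llvm.load %narrowBuf : !llvm.ptr<5> -> vector<4xf16>
%e = llvm.fpext %l : vector<4xf16> to vector<4xf32>

// After: the load is redirected to the wide buffer; the fptrunc store
// and the narrow alloca (once unused) can be erased, and no precision
// is lost.
llvm.store %wide, %wideBuf : vector<4xf32>, !llvm.ptr<5>
%e = llvm.load %wideBuf : !llvm.ptr<5> -> vector<4xf32>
```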
Test Plan
Test Result
Submission Checklist