
Conversation

@justinrosner (Contributor) commented Jan 8, 2026

Motivation

When processing mixed-precision computations (e.g., attention kernels with f32 intermediate values stored as f16), the generated IR often contains redundant precision conversion patterns:

%wide = ...                                         // f32 computation result
%narrow = llvm.fptrunc %wide : f32 to f16
llvm.store %narrow, %narrow_buf : f16, !llvm.ptr    // store truncated value
...
%loaded = llvm.load %narrow_buf : !llvm.ptr -> f16  // load truncated value
%extended = llvm.fpext %loaded : f16 to f32         // extend back to f32

This pattern causes unnecessary precision loss compared to simply keeping the original wide value. This pass eliminates these redundant casts by redirecting loads to read from a parallel wide buffer when possible.
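
After the rewrite, the same region reads the wide value directly. A minimal sketch, assuming the pass has materialized a parallel buffer here named %wide_buf (the name is illustrative, not taken from the pass):

%narrow = llvm.fptrunc %wide : f32 to f16
llvm.store %narrow, %narrow_buf : f16, !llvm.ptr    // kept only while other users remain
llvm.store %wide, %wide_buf : f32, !llvm.ptr        // parallel store of the wide value
...
%loaded = llvm.load %wide_buf : !llvm.ptr -> f32    // load redirected to the wide buffer
// former uses of %extended now use %loaded; the fpext and the narrow load are deleted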

This implements: https://github.com/ROCm/rocMLIR-internal/issues/1932

Technical Details

This PR introduces the RemoveRedundantCasts pass that operates at the LLVMIR dialect level to optimize fptrunc -> store -> load -> fpext patterns.

General Algorithm (a worked sketch follows the list):

  1. Find all fptrunc -> store patterns in the function. For each pattern, record whether there's already a parallel store of the wide value to a separate buffer.
  2. Find all load -> fpext patterns where the load is from a buffer that has fptrunc stores.
  3. Verify safety for each load+fpext pattern:
    • All stores to the narrow buffer must be from tracked fptrunc patterns (i.e., no untracked stores that could write different values)
    • All tracked stores must dominate the load
    • The narrow buffer must be an alloca
  4. For safe patterns, create a wide buffer and the corresponding stores if they don't exist. If a parallel store already exists, reuse it:
    • Create a wide alloca right after the narrow alloca
    • For each fptrunc store, insert a store of the wide value to the wide buffer (right after the narrow store, using the same indices)
  5. Apply the transformation:
    • Redirect the load to read from the wide buffer instead
    • Replace uses of the fpext result with the wide load result
    • Delete the fpext (and the old load/GEP if unused)
  6. Clean up unused narrow buffer operations:
    • If the narrow buffer has no remaining uses, erase the fptrunc stores
      • These can only be erased if they are not used by any other operations
    • Erase the narrow alloca if it has no remaining uses
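
Taken together, the end-to-end effect on a minimal example would look roughly like this. This is an illustrative sketch: the SSA names, the scalar element type, and the %c1 size constant are made up for the example, whereas the snippets quoted in the review below use vector stores into address-space-5 buffers.

Before the pass:

%wide = ...                                                  // f32 computation result
%narrow_buf = llvm.alloca %c1 x f16 : (i32) -> !llvm.ptr<5>
%narrow = llvm.fptrunc %wide : f32 to f16
llvm.store %narrow, %narrow_buf : f16, !llvm.ptr<5>
%loaded = llvm.load %narrow_buf : !llvm.ptr<5> -> f16
%extended = llvm.fpext %loaded : f16 to f32

After the pass (the narrow alloca, the fptrunc, and its store become dead and are erased):

%wide = ...                                                  // f32 computation result
%wide_buf = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr<5>   // wide alloca, created next to the narrow one
llvm.store %wide, %wide_buf : f32, !llvm.ptr<5>             // parallel store of the wide value
%loaded_wide = llvm.load %wide_buf : !llvm.ptr<5> -> f32    // load redirected to the wide buffer
// all former uses of %extended now use %loaded_wide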

Test Plan

Test Result

  • Nightly CI

Submission Checklist

%7 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16>
llvm.store %7, %6 : vector<4xf16>, !llvm.ptr<5>
%8 = llvm.getelementptr %2[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16
%9 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16>
Contributor:

Why are there so many repeated llvm.fptrunc %1? I don't understand this.
Wouldn't it be easier to do this earlier, for example by handling arith.extf, etc.?

Contributor Author:

The fptruncs seem to be the result of loop unrolling; they are all writing to the same buffer. I was doing this earlier, and there are quite a few more difficulties with moving this pass to somewhere right after GridwiseGemmToBlockwise. The dominance analysis (used for safety) becomes tricky because it doesn't work well when the trunc/ext ops are in different regions, and having to rewrite the linalg.generic makes things more difficult as well.

Both approaches have their pros and cons. We can discuss this more in the team meeting, or elsewhere offline.

// - If the narrow buffer has no remaining uses, erase the fptrunc stores
// - These can only be erased if they are not used by any other
// operations
// - Erase the narrow alloca if it has no remaining uses
Contributor:

do we have tests for these two cases?


// Look for existing parallel wide store
for (Operation *wideUser : wideValue.getUsers()) {
  auto wideStore = dyn_cast<StoreOp>(wideUser);
Contributor:

should we check it's the source instead of destination?

info.wideStore = nullptr;

// Look for existing parallel wide store
for (Operation *wideUser : wideValue.getUsers()) {
Contributor:

nit: this could be done outside of this loop and store the results in a SmallVector?

<< wideStore << "\n");
info.wideBuffer = wideBuffer;
info.wideStore = wideStore;
break;
Contributor:

why do we store only the first one?
