Triton current scaling: avoid casting amax input#458
Triton current scaling: avoid casting amax input#458matthiasdiener wants to merge 3 commits intodevfrom
Conversation
There was a problem hiding this comment.
Pull request overview
This PR adjusts Triton current-scaling amax reduction to avoid casting the loaded input tile to float32, instead casting only the reduced tile_amax before atomic/store. This targets current-scaling behavior in the Triton cast+transpose path.
Changes:
- Remove
tl.float32cast ontl.load(...)in amax-reduction kernels. - Cast
tile_amaxtotl.float32only at the atomic/store sites.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| a = tl.load(A_ptrs, mask=mask, other=0) | ||
| tile_amax = tl.max(tl.abs(a)) | ||
| # accumulate tile-wise max into global amax | ||
| tl.atomic_max(amax_ptr, tile_amax, sem='relaxed') | ||
| tl.atomic_max(amax_ptr, tile_amax.to(tl.float32), sem='relaxed') |
There was a problem hiding this comment.
PR description still lists placeholder items (“Change A”, “Change B”) and doesn’t describe the actual change (moving float32 cast from load to the reduced amax result). Please update the description to match what this PR does so reviewers/users can understand intent and impact.
There was a problem hiding this comment.
Please update copyright date
wenchenvincent
left a comment
There was a problem hiding this comment.
Please run level 3 CI tests before merging.
I just started a level 3 test here: https://github.com/ROCm/TransformerEngine/actions/runs/22327178479 |
Description
Suggested by @ipanfilo
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: