Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions transformer_engine/pytorch/triton_kernels/cast_transpose.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information

import torch
Expand Down Expand Up @@ -53,10 +53,10 @@ def _amax_reduce_triton(
A_ptrs = A + rm[:, None] * stride_am + rn[None, :] * stride_an
mask = (rm < M)[:, None] & (rn < N)[None, :]

a = tl.load(A_ptrs, mask=mask, other=0).to(tl.float32)
a = tl.load(A_ptrs, mask=mask, other=0)
tile_amax = tl.max(tl.abs(a))
# accumulate tile-wise max into global amax
tl.atomic_max(amax_ptr, tile_amax, sem='relaxed')
tl.atomic_max(amax_ptr, tile_amax.to(tl.float32), sem='relaxed')
Comment on lines +56 to +59
Copy link

Copilot AI Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description still lists placeholder items (“Change A”, “Change B”) and doesn’t describe the actual change (moving float32 cast from load to the reduced amax result). Please update the description to match what this PR does so reviewers/users can understand intent and impact.

Copilot uses AI. Check for mistakes.


@triton.jit
Expand Down Expand Up @@ -229,11 +229,11 @@ def _amax_reduce_triton_stage1(
A_ptrs = A + rm[:, None] * stride_am + rn[None, :] * stride_an
mask = (rm < M)[:, None] & (rn < N)[None, :]

a = tl.load(A_ptrs, mask=mask, other=0).to(tl.float32)
a = tl.load(A_ptrs, mask=mask, other=0)
tile_amax = tl.max(tl.abs(a))

# Store per-program amax in workspace
tl.store(block_amax + pid, tile_amax)
tl.store(block_amax + pid, tile_amax.to(tl.float32))

if pid == 0:
tl.store(num_blocks, tl.num_programs(0))
Expand Down
Loading