Conversation
This reverts commit 86fbbac.
2095d3f to
ebc005f
Compare
d1ab38e to
0b16287
Compare
tests/pytorch/test_numerics.py
Outdated
| ) | ||
| if IS_HIP_EXTENSION: | ||
| from transformer_engine.pytorch.utils import is_mi200, is_mi308 | ||
| from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi300_class |
There was a problem hiding this comment.
The is_mi300_class method is not needed; it is just the gfx 9.4 family.
| @@ -0,0 +1,276 @@ | |||
| /* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */ | |||
There was a problem hiding this comment.
Add proper copyright header
| @@ -0,0 +1,11 @@ | |||
| /* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */ | |||
There was a problem hiding this comment.
Please put the proper copyright header here as well.
| #endif | ||
|
|
||
| const int current_device = transformer_engine::cuda::current_device(); | ||
| const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); |
There was a problem hiding this comment.
These constants are not used on ROCm
| return true; | ||
| }; | ||
|
|
||
| auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool { |
There was a problem hiding this comment.
Right, would you like me to #ifdef this function out (it is coming from upstream)?
| auto A_dt = inputA->data.dtype; | ||
| auto B_dt = inputB->data.dtype; | ||
| auto D_dt = OutputD->data.dtype; | ||
| return (A_dt == B_dt) && (A_dt == D_dt) && |
There was a problem hiding this comment.
Are CK tile constraints the same as CUTLASS?
There was a problem hiding this comment.
In terms of supported data types (which this function handles), yes - only bf16/fp16 are supported.
| } | ||
|
|
||
| // Normalize similar to upstream | ||
| // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 |
There was a problem hiding this comment.
There is no similar code in the referenced upstream file. Also, can you explain why transA_use = transB and vice versa?
There was a problem hiding this comment.
In the referenced upstream code, the same swap is performed. For example, consider the following case in the upstream code:
} else if (!transb && transa) {
grouped_gemm::CutlassGroupedGemm<false, true, T>(B, A, D, workspace, alpha, beta, num_gemms,
stream, device, math_sm_count);
Here, transa==true and transb==false, but they get passed into the template as transa==false and transb==true, and A and B are swapped in the function call itself.
My best understanding regarding why the swap needs to be performed is that it matches the BLAS semantics regarding column-major storage (see e.g. https://rocm.docs.amd.com/projects/rocBLAS/en/latest/conceptual/rocblas-design-notes.html#column-major-storage-and-1-based-indexing).
| size_t workspace_bytes, | ||
| hipStream_t stream) { | ||
|
|
||
| // FIXME: This could be a templated lambda function in C++20. |
There was a problem hiding this comment.
As an alternative, dispatch_grouped could be incorporated into ck_tile_grouped_gemm by using nested TRANSFORMER_ENGINE_SWITCH_CONDITION macros.
Description
See https://github.com/ROCm/frameworks-internal/issues/13792 for context.
Primus-Turbo implementation: https://github.com/AMD-AGI/Primus-Turbo/blob/5bcd13785ef380fec0eec0911b7d6db5e606143e/csrc/kernels/grouped_gemm
TODOs:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: