Grouped GEMM with ck_tile #434

Open
matthiasdiener wants to merge 46 commits into dev from ck-grouped-gemm

Contributor

@matthiasdiener matthiasdiener commented Jan 28, 2026

Description

See https://github.com/ROCm/frameworks-internal/issues/13792 for context.

Primus-Turbo implementation: https://github.com/AMD-AGI/Primus-Turbo/blob/5bcd13785ef380fec0eec0911b7d6db5e606143e/csrc/kernels/grouped_gemm

TODOs:

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Implement ck_tile-based grouped GEMM, similar to CUTLASS

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@matthiasdiener matthiasdiener self-assigned this Jan 28, 2026
@matthiasdiener matthiasdiener changed the title [WIP] proof-of-concept: grouped GEMM with ck_tile [WIP] Grouped GEMM with ck_tile Jan 29, 2026
@matthiasdiener matthiasdiener changed the title [WIP] Grouped GEMM with ck_tile Grouped GEMM with ck_tile Feb 11, 2026
@matthiasdiener matthiasdiener marked this pull request as ready for review February 17, 2026 22:58
 )
 if IS_HIP_EXTENSION:
-    from transformer_engine.pytorch.utils import is_mi200, is_mi308
+    from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi300_class
Collaborator

The is_mi300_class method is not needed; it is just the gfx 9.4 family.

Contributor Author

Removed in 7910038

@@ -0,0 +1,276 @@
/* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */
Collaborator

Add proper copyright header

Contributor Author

Thanks, done in f680d6a

@@ -0,0 +1,11 @@
/* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */
Collaborator

Put proper copyright header

Contributor Author

Thanks, done in f680d6a

#endif

const int current_device = transformer_engine::cuda::current_device();
const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
Collaborator

These constants are not used on ROCm

Contributor Author

I disabled them in e8ebb0e.

return true;
};

auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool {
Collaborator

Unused on ROCm

Contributor Author

Right, would you like me to #ifdef this function out (it is coming from upstream)?

auto A_dt = inputA->data.dtype;
auto B_dt = inputB->data.dtype;
auto D_dt = OutputD->data.dtype;
return (A_dt == B_dt) && (A_dt == D_dt) &&
Collaborator

Are CK tile constraints the same as CUTLASS?

Contributor Author

In terms of supported data types (which this function handles), yes - only bf16/fp16 are supported.
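
As a rough illustration of the check discussed here, the Python sketch below mirrors the visible part of the predicate (A, B, and D share one dtype) and adds the bf16/fp16 restriction stated in the reply. The `DType` enum and the `is_supported_dtype` name are hypothetical stand-ins, not identifiers from Transformer Engine or ck_tile, and the real C++ predicate may test more than this.

```python
from enum import Enum


class DType(Enum):
    # Hypothetical stand-in for the library's dtype enum.
    FLOAT32 = "fp32"
    FLOAT16 = "fp16"
    BFLOAT16 = "bf16"


# Assumption taken from the reply above: only fp16 and bf16 are accepted.
SUPPORTED = {DType.FLOAT16, DType.BFLOAT16}


def is_supported_dtype(a_dt, b_dt, d_dt):
    """True when A, B and D share one dtype and that dtype is fp16 or bf16."""
    return a_dt == b_dt and a_dt == d_dt and a_dt in SUPPORTED
```

With this sketch, three bf16 tensors pass, while mixed fp16/bf16 inputs or all-fp32 inputs are rejected and would presumably fall back to another GEMM path.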

}

// Normalize similar to upstream
// See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68
Collaborator

There is no similar code in the referenced upstream file. Also, can you explain transA_use = transB and vice versa?

Contributor Author

In the referenced upstream code, the same swap is performed. For example, consider the following case in the upstream code:

} else if (!transb && transa) {
      grouped_gemm::CutlassGroupedGemm<false, true, T>(B, A, D, workspace, alpha, beta, num_gemms,
                                                       stream, device, math_sm_count);

Here, transa==true and transb==false, but they get passed into the template as transa==false and transb==true, and A and B are swapped in the function call itself.

My best understanding is that the swap is needed to match the BLAS semantics for column-major storage (see e.g. https://rocm.docs.amd.com/projects/rocBLAS/en/latest/conceptual/rocblas-design-notes.html#column-major-storage-and-1-based-indexing).
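
The identity behind this swap can be checked with a small, self-contained Python sketch. It assumes the usual reading: a column-major (r x c) buffer, read row-major, is the (c x r) transpose, so a column-major request D = op(A) · op(B) can be served by a row-major kernel called with the operands swapped and each transpose flag travelling with its operand. All helper names below are hypothetical and exist only for illustration; none of them is part of Transformer Engine, CUTLASS, or ck_tile.

```python
from itertools import product


def matmul(x, y):
    """Product of two matrices given as nested lists."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*y)] for row in x]


def transpose(m):
    return [list(col) for col in zip(*m)]


def view(buf, rows, cols, col_major):
    """Interpret a flat buffer as a rows x cols matrix in the given layout."""
    if col_major:
        return [[buf[c * rows + r] for c in range(cols)] for r in range(rows)]
    return [[buf[r * cols + c] for c in range(cols)] for r in range(rows)]


def flatten(m, col_major):
    """Write a matrix back to a flat buffer in the given layout."""
    return [v for row in (transpose(m) if col_major else m) for v in row]


def gemm(x, x_shape, y, y_shape, trans_x, trans_y, col_major):
    """Reference D = op(X) @ op(Y) on flat buffers; op transposes when its flag is set."""
    X = view(x, *x_shape, col_major)
    Y = view(y, *y_shape, col_major)
    X = transpose(X) if trans_x else X
    Y = transpose(Y) if trans_y else Y
    return flatten(matmul(X, Y), col_major)


def row_major_call_for(a, a_shape, b, b_shape, transa, transb):
    """Serve a column-major request with a row-major kernel: swap operands and flags."""
    # A column-major (r x c) buffer reads as a row-major (c x r) matrix.
    return gemm(b, b_shape[::-1], a, a_shape[::-1], transb, transa, col_major=False)


def check_all_flag_combinations(m=2, k=3, n=4):
    """Compare both paths for every (transa, transb) pair."""
    for transa, transb in product((False, True), repeat=2):
        a_shape = (k, m) if transa else (m, k)
        b_shape = (n, k) if transb else (k, n)
        a = list(range(1, m * k + 1))
        b = list(range(10, 10 + k * n))
        expected = gemm(a, a_shape, b, b_shape, transa, transb, col_major=True)
        actual = row_major_call_for(a, a_shape, b, b_shape, transa, transb)
        if expected != actual:
            return False
    return True
```

In this sketch, `check_all_flag_combinations()` compares the output buffers of the two paths for all four flag combinations. The case transa == true, transb == false maps to a row-major call with flags (false, true) and operands (B, A), which is the same shape as the upstream call quoted above.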

size_t workspace_bytes,
hipStream_t stream) {

// FIXME: This could be a templated lambda function in C++20.
Collaborator

As an alternative, dispatch_grouped can be incorporated into ck_tile_grouped_gemm by using nested TRANSFORMER_ENGINE_SWITCH_CONDITION macros.

Contributor Author

What do you think of 6d85088?


4 participants