From 3c8c21573d8bd2f865138c3d90fed36098ec7662 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Wed, 17 Dec 2025 19:59:58 +0800
Subject: [PATCH 1/2] Adds Triton GEMM implementation

Introduces a Triton-kernelized GEMM to update matrices with alpha/beta
scaling, complete with autotuning, boundary masking, and torch
integration to pave the way for efficient compute-bound workloads.
---
 kernel_course/triton_ops/gemm.py | 163 +++++++++++++++++++++++++++++++
 1 file changed, 163 insertions(+)
 create mode 100644 kernel_course/triton_ops/gemm.py

diff --git a/kernel_course/triton_ops/gemm.py b/kernel_course/triton_ops/gemm.py
new file mode 100644
index 0000000..724324e
--- /dev/null
+++ b/kernel_course/triton_ops/gemm.py
@@ -0,0 +1,163 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {"BLOCK_M": 64, "BLOCK_K": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2
+        ),
+    ],
+    key=["n_elements_M", "n_elements_K", "n_elements_N"],
+)
+@triton.heuristics(
+    {
+        "EVEN_M": lambda args: args["n_elements_M"] % args["BLOCK_M"] == 0,
+        "EVEN_K": lambda args: args["n_elements_K"] % args["BLOCK_K"] == 0,
+        "EVEN_N": lambda args: args["n_elements_N"] % args["BLOCK_N"] == 0,
+    }
+)
+@triton.jit
+def gemm_kernel(
+    A,
+    B,
+    C,
+    stride_am,
+    stride_ak,
+    stride_bn,
+    stride_bk,
+    stride_cm,
+    stride_cn,
+    alpha,
+    beta,
+    n_elements_M,
+    n_elements_K,
+    n_elements_N,
+    BLOCK_M: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    EVEN_M: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    EVEN_N: tl.constexpr,
+):
+    # There are multiple programs, each processing a different block of data
+    # We identify which program we are in using program_id
+    start_m = tl.program_id(0)
+    start_n = tl.program_id(1)
+    # This program will process inputs that are offset from the initial pointer
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_kb = tl.arange(0, BLOCK_K)
+    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    # Initialize pointers to the start of the blocks
+    a_ptr = A + offs_m[:, None] * stride_am + tl.arange(0, BLOCK_K)[None, :] * stride_ak
+    b_ptr = B + tl.arange(0, BLOCK_K)[:, None] * stride_bk + offs_n[None, :] * stride_bn
+    c_ptr = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+    # Create a mask to guard memory operations against out-of-bounds accesses
+    mask_m = offs_m < n_elements_M
+    mask_n = offs_n < n_elements_N
+    # Initialize the accumulator
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    end_k = n_elements_K
+    # Loop over K dimension
+    for start_k in range(0, end_k, BLOCK_K):
+        # Align start_k to a multiple of BLOCK_K for efficient memory access
+        start_k = tl.multiple_of(start_k, BLOCK_K)
+        # This program will process inputs that are offset from the initial pointer
+        offs_k = start_k + offs_kb
+        # Create a mask to guard memory operations against out-of-bounds accesses
+        mask_k = offs_k < n_elements_K
+        # Load a block of A and B from DRAM, masking out any extra elements in case the input is not a multiple of the block size
+        if EVEN_M & EVEN_K:
+            a = tl.load(a_ptr + start_k * stride_ak)
+        else:
+            a = tl.load(
+                a_ptr + start_k * stride_ak,
+                mask=mask_m[:, None] & mask_k[None, :],
+                other=0.0,
+            )
+        if EVEN_N & EVEN_K:
+            b = tl.load(b_ptr + start_k * stride_bk)
+        else:
+            b = tl.load(
+                b_ptr + start_k * stride_bk,
+                mask=mask_k[:, None] & mask_n[None, :],
+                other=0.0,
+            )
+        # Perform the matrix multiplication for the current block and accumulate the result
+        acc += tl.dot(a, b)
+    # Load C from DRAM, masking out any extra elements in case the input is not a multiple of the block size
+    if EVEN_M & EVEN_N:
+        c = tl.load(c_ptr)
+    else:
+        c = tl.load(
+            c_ptr,
+            mask=mask_m[:, None] & mask_n[None, :],
+            other=0.0,
+        )
+    # Compute C = alpha * A * B + beta * C
+    c = beta * c
+    c += alpha * acc
+    # Store the updated C back to DRAM
+    if EVEN_M & EVEN_N:
+        tl.store(c_ptr, c)
+    else:
+        tl.store(
+            c_ptr,
+            c,
+            mask=mask_m[:, None] & mask_n[None, :],
+        )
+
+
+def gemm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    alpha: float,
+    beta: float,
+) -> torch.Tensor:
+    """
+    Updates tensor `C` in place with the product of matrices `A` and `B`
+    scaled by `alpha`, plus `C` scaled by `beta`, using a Triton kernel.
+
+    Args:
+        A (torch.Tensor): First matrix tensor.
+        B (torch.Tensor): Second matrix tensor to be multiplied with `A`.
+        C (torch.Tensor): Matrix tensor to be updated.
+        alpha (float): Scaling factor for the product of `A` and `B`.
+        beta (float): Scaling factor for `C`.
+
+    Returns:
+        torch.Tensor: The updated tensor `C`.
+    """
+
+    # Read the matrix dimensions (M, K, N) from the input tensors
+    n_elements_M, n_elements_K = A.shape
+    n_elements_K, n_elements_N = B.shape
+
+    # The SPMD grid is a 2D grid where each program computes a BLOCK_M x BLOCK_N block of the output matrix C
+    def grid(meta):
+        return (
+            triton.cdiv(n_elements_M, meta["BLOCK_M"]),
+            triton.cdiv(n_elements_N, meta["BLOCK_N"]),
+        )
+
+    # Launch the Triton kernel
+    gemm_kernel[grid](
+        A,
+        B,
+        C,
+        A.stride(0),
+        A.stride(1),
+        B.stride(1),
+        B.stride(0),
+        C.stride(0),
+        C.stride(1),
+        alpha,
+        beta,
+        n_elements_M,
+        n_elements_K,
+        n_elements_N,
+    )
+
+    return C

From 603d8662a0f4094baa12741b24e265fe9c44f669 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Wed, 17 Dec 2025 20:00:25 +0800
Subject: [PATCH 2/2] Update GEMM entry in README to reflect Triton implementation status

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index f2097bc..b54c4ab 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
 | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) |
 | [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) |
 | [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | [✅](./kernel_course/pytorch_ops/geru.py) | [✅](./kernel_course/triton_ops/geru.py) | ❌ | [✅](./tests/test_geru.py) |
-| [gemm](./docs/gemm.md) | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | [✅](./kernel_course/python_ops/gemm.py) | [✅](./kernel_course/pytorch_ops/gemm.py) | ❌ | ❌ | ❌ |
+| [gemm](./docs/gemm.md) | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | [✅](./kernel_course/python_ops/gemm.py) | [✅](./kernel_course/pytorch_ops/gemm.py) | [✅](./kernel_course/triton_ops/gemm.py) | ❌ | ❌ |
 
 ## Transformer Modules
 
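For reviewers who want to exercise the new op, a minimal smoke test might look like the sketch below. It is not part of the patch: it assumes the patch is applied, Triton and a CUDA device are available, and the repository root is on `PYTHONPATH` so that the module path `kernel_course.triton_ops.gemm` added above is importable. The shapes are deliberately not multiples of the 64-wide blocks so the masked load/store paths are hit, and the reference is formed in float32 to mirror the kernel's float32 accumulator.

```python
# Hypothetical smoke test for the Triton GEMM from PATCH 1/2 (not part of the patch).
# Assumes Triton is installed, a CUDA device is present, and the repo root is importable.
import torch

from kernel_course.triton_ops.gemm import gemm

torch.manual_seed(0)
M, K, N = 300, 130, 200          # not multiples of BLOCK_M/K/N=64, so boundary masking is used
alpha, beta = 1.5, 0.5

A = torch.randn(M, K, device="cuda", dtype=torch.float16)
B = torch.randn(K, N, device="cuda", dtype=torch.float16)
C = torch.randn(M, N, device="cuda", dtype=torch.float16)

# Compute the reference before the call, because gemm updates C in place.
# Accumulate in float32 and cast back, matching the kernel's accumulator dtype.
expected = (alpha * (A.float() @ B.float()) + beta * C.float()).to(C.dtype)

out = gemm(A, B, C, alpha, beta)  # returns the updated C

torch.testing.assert_close(out, expected, atol=1e-2, rtol=1e-2)
print("Triton GEMM matches the PyTorch reference.")
```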