From 4b9a2227a9677f0c98c41301dcc02e5738574a7a Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Fri, 12 Dec 2025 16:22:01 +0800
Subject: [PATCH 1/3] Adds Python GEMM helper

Implements a reusable matrix multiply-and-accumulate routine to unify
`alpha * A @ B + beta * C` updates.
---
 kernel_course/python_ops/gemm.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 kernel_course/python_ops/gemm.py

diff --git a/kernel_course/python_ops/gemm.py b/kernel_course/python_ops/gemm.py
new file mode 100644
index 0000000..d8047bb
--- /dev/null
+++ b/kernel_course/python_ops/gemm.py
@@ -0,0 +1,28 @@
+import torch
+
+
+def gemm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    alpha: float,
+    beta: float,
+) -> torch.Tensor:
+    """
+    Computes `alpha * (A @ B) + beta * C` and returns the result as a
+    new tensor; `C` itself is not modified in place.
+
+    Args:
+        A (torch.Tensor): First matrix tensor.
+        B (torch.Tensor): Second matrix tensor to be multiplied with `A`.
+        C (torch.Tensor): Matrix tensor to be scaled and accumulated.
+        alpha (float): Scaling factor for the product of `A` and `B`.
+        beta (float): Scaling factor for `C`.
+
+    Returns:
+        torch.Tensor: The result `alpha * A @ B + beta * C`.
+    """
+
+    C = alpha * (A @ B) + beta * C
+
+    return C

From 3d10419d18d0dd6ea31689da66083a30b092c884 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Fri, 12 Dec 2025 16:22:08 +0800
Subject: [PATCH 2/3] Add documentation for GEMM kernel implementation

---
 docs/gemm.md | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)
 create mode 100644 docs/gemm.md

diff --git a/docs/gemm.md b/docs/gemm.md
new file mode 100644
index 0000000..dbd17e1
--- /dev/null
+++ b/docs/gemm.md
@@ -0,0 +1,35 @@
+# GEMM Kernel
+
+The `gemm` operator computes a scaled matrix-matrix product and accumulates it into a scaled output matrix.
+
+## Mathematical Definition
+
+Given an $m \times k$ input matrix `A`, a $k \times n$ input matrix `B`, an $m \times n$ output matrix `C`, and scalars `α` and `β`, the kernel evaluates
+
+$$
+C = \alpha A B + \beta C
+$$
+
+The product `A B` is scaled by `α`, the existing matrix `C` is scaled by `β`, and the two scaled terms are summed to produce the updated matrix `C`.
+
+## Kernel Implementations
+
+- [Python Implementation](../kernel_course/python_ops/gemm.py)
+- [PyTorch Implementation](../kernel_course/pytorch_ops/gemm.py)
+- [Triton Implementation](../kernel_course/triton_ops/gemm.py)
+- [CuTe Implementation](../kernel_course/cute_ops/gemm.py)
+
+All backends share the interface:
+
+```python
+def gemm(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
+    ...
+```
+
+## Testing
+
+See the [test suite](../tests/test_gemm.py) for the validation harness that exercises every backend.
+
+```bash
+pytest tests/test_gemm.py -s
+```

From 3a230cb1e806ff7f02b6367c070ca2661747214e Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Fri, 12 Dec 2025 16:22:51 +0800
Subject: [PATCH 3/3] Update GEMM entry in README to include Python
 implementation and link

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 0110510..5eb4487 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
 | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) |
 | [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) |
 | [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | [✅](./kernel_course/pytorch_ops/geru.py) | [✅](./kernel_course/triton_ops/geru.py) | ❌ | [✅](./tests/test_geru.py) |
-| gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| [gemm](./docs/gemm.md) | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | [✅](./kernel_course/python_ops/gemm.py) | ❌ | ❌ | ❌ | ❌ |
 
 ## Transformer Modules
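---

A quick way to sanity-check the new Python backend locally is to compare it against `torch.addmm`, which computes the same `beta * C + alpha * (A @ B)` update. This is a minimal sketch, not part of the patch series; the shapes are illustrative, and the import path assumes the `kernel_course.python_ops.gemm` module added in PATCH 1/3:

```python
import torch

from kernel_course.python_ops.gemm import gemm

# Illustrative sizes; any m, k, n would do.
m, k, n = 64, 32, 48
A = torch.randn(m, k)
B = torch.randn(k, n)
C = torch.randn(m, n)
alpha, beta = 1.5, 0.5

out = gemm(A, B, C, alpha, beta)

# torch.addmm(C, A, B, beta=beta, alpha=alpha) evaluates
# beta * C + alpha * (A @ B), i.e. the same GEMM update.
ref = torch.addmm(C, A, B, beta=beta, alpha=alpha)

torch.testing.assert_close(out, ref)
```

Because `gemm` returns a new tensor rather than writing into `C`, the same `C` can be reused to cross-check the other backends against this reference once they land.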