From ddc9f704bfb57c249b83c6a79f538961a2e20192 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Tue, 16 Dec 2025 11:33:22 +0800
Subject: [PATCH 1/2] Implements PyTorch GEMM helper

Adds a matrix update utility that composes matmul with alpha and beta
scaling to enable reuse of GEMM semantics in PyTorch flows.
---
 kernel_course/pytorch_ops/gemm.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 kernel_course/pytorch_ops/gemm.py

diff --git a/kernel_course/pytorch_ops/gemm.py b/kernel_course/pytorch_ops/gemm.py
new file mode 100644
index 0000000..9d92976
--- /dev/null
+++ b/kernel_course/pytorch_ops/gemm.py
@@ -0,0 +1,28 @@
+import torch
+
+
+def gemm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    alpha: float,
+    beta: float,
+) -> torch.Tensor:
+    """
+    Computes the GEMM update `alpha * (A @ B) + beta * C` with PyTorch
+    operations and returns the result as a new tensor.
+
+    Args:
+        A (torch.Tensor): First matrix tensor.
+        B (torch.Tensor): Second matrix tensor to be multiplied with `A`.
+        C (torch.Tensor): Matrix tensor to be updated.
+        alpha (float): Scaling factor for the product of `A` and `B`.
+        beta (float): Scaling factor for `C`.
+
+    Returns:
+        torch.Tensor: The result of `alpha * (A @ B) + beta * C`.
+    """
+
+    C = alpha * torch.matmul(A, B) + beta * C
+
+    return C

From 87753166292ff7523b3a49ada59da20572bb9320 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Tue, 16 Dec 2025 11:34:25 +0800
Subject: [PATCH 2/2] Update GEMM entry in README to reflect PyTorch implementation status

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 5eb4487..f2097bc 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
 | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) |
 | [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) |
 | [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | [✅](./kernel_course/pytorch_ops/geru.py) | [✅](./kernel_course/triton_ops/geru.py) | ❌ | [✅](./tests/test_geru.py) |
-| [gemm](./docs/gemm.md) | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | [✅](./kernel_course/python_ops/gemm.py) | ❌ | ❌ | ❌ | ❌ |
+| [gemm](./docs/gemm.md) | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | [✅](./kernel_course/python_ops/gemm.py) | [✅](./kernel_course/pytorch_ops/gemm.py) | ❌ | ❌ | ❌ |
 
 ## Transformer Modules
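
A minimal usage sketch for the new helper, not part of the patch series itself: it assumes the module path `kernel_course/pytorch_ops/gemm.py` added above is importable, and the shapes and alpha/beta values are arbitrary. It cross-checks the composed update against PyTorch's built-in `torch.addmm`, which computes the same `beta * C + alpha * (A @ B)` form in a single fused call:

    import torch

    from kernel_course.pytorch_ops.gemm import gemm

    # Arbitrary GEMM-compatible shapes: A is (m, k), B is (k, n), C is (m, n).
    A = torch.randn(4, 3)
    B = torch.randn(3, 5)
    C = torch.randn(4, 5)

    out = gemm(A, B, C, alpha=2.0, beta=0.5)

    # torch.addmm computes beta * C + alpha * (A @ B) as one fused operation.
    ref = torch.addmm(C, A, B, beta=0.5, alpha=2.0)
    assert torch.allclose(out, ref)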