From 39536470af646dab4be3a8c840b789ec9e6a9ba8 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Fri, 19 Dec 2025 18:14:42 +0800
Subject: [PATCH 1/2] Adds GEMM benchmark tests

Introduces parameterized GEMM benchmarks that compare multiple backend
implementations across devices and dtypes to catch performance
regressions early.
---
 tests/test_gemm.py | 95 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 95 insertions(+)
 create mode 100644 tests/test_gemm.py

diff --git a/tests/test_gemm.py b/tests/test_gemm.py
new file mode 100644
index 0000000..908aff1
--- /dev/null
+++ b/tests/test_gemm.py
@@ -0,0 +1,95 @@
+import pytest
+import torch
+
+from kernel_course import testing
+from kernel_course.python_ops import gemm as python_gemm
+
+try:
+    from kernel_course.pytorch_ops import gemm as pytorch_gemm
+
+    HAS_PYTORCH = True
+except Exception:
+    pytorch_gemm = None
+    HAS_PYTORCH = False
+
+try:
+    from kernel_course.triton_ops import gemm as triton_gemm
+
+    HAS_TRITON = True
+except Exception:
+    triton_gemm = None
+    HAS_TRITON = False
+
+try:
+    from kernel_course.cute_ops import gemm as cute_gemm
+
+    HAS_CUTE = True
+except Exception:
+    cute_gemm = None
+    HAS_CUTE = False
+
+
+def factory(
+    MNK: tuple[int, int, int],
+    device: torch.device,
+    dtype: torch.dtype = torch.float32,
+):
+    M, N, K = MNK
+    A = torch.linspace(0.0, 1.0, steps=M * K, device=device, dtype=dtype).view(M, K)
+    B = torch.linspace(0.0, 1.0, steps=K * N, device=device, dtype=dtype).view(K, N)
+    C = torch.linspace(0.0, 1.0, steps=M * N, device=device, dtype=dtype).view(M, N)
+    alpha = 1.14
+    beta = 5.14
+    return (A, B, C, alpha, beta), {}
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        pytest.param(
+            torch.device("cuda"),
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="requires CUDA"
+            ),
+        ),
+        pytest.param(
+            torch.device("mps"),
+            marks=pytest.mark.skipif(
+                not torch.backends.mps.is_available(), reason="requires MPS"
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.float32, torch.float16, torch.bfloat16],
+)
+@pytest.mark.parametrize(
+    "MNK",
+    [
+        (1 << 4, 1 << 4, 1 << 4),
+        (1 << 8, 1 << 8, 1 << 8),
+    ],
+)
+def test_gemm_benchmark(
+    device: torch.device,
+    dtype: torch.dtype,
+    MNK: tuple[int, int, int],
+) -> None:
+    impls = testing.get_impls(
+        python_impl=python_gemm.gemm,
+        pytorch_impl=pytorch_gemm.gemm if HAS_PYTORCH else None,
+        triton_impl=triton_gemm.gemm if HAS_TRITON else None,
+        cute_impl=cute_gemm.gemm if HAS_CUTE else None,
+    )
+
+    # Benchmark each implementation
+    config = testing.BenchmarkConfig(warmup=3, repeat=100)
+    results = testing.run_benchmarks(
+        impls,
+        lambda: factory(MNK, device, dtype),
+        flops=2 * MNK[0] * MNK[1] * MNK[2],
+        config=config,
+    )
+
+    testing.show_benchmarks(results)

From 1c44286f63c59c38cdbd77708c304bff06696f38 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Fri, 19 Dec 2025 18:14:48 +0800
Subject: [PATCH 2/2] Adds test entry for GEMM implementation in README

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b54c4ab..4d0f140 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
 | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) |
 | [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) |
 | [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | [✅](./kernel_course/pytorch_ops/geru.py) | [✅](./kernel_course/triton_ops/geru.py) | ❌ | [✅](./tests/test_geru.py) |
-| [gemm](./docs/gemm.md) | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | [✅](./kernel_course/python_ops/gemm.py) | [✅](./kernel_course/pytorch_ops/gemm.py) | [✅](./kernel_course/triton_ops/gemm.py) | ❌ | ❌ |
+| [gemm](./docs/gemm.md) | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | [✅](./kernel_course/python_ops/gemm.py) | [✅](./kernel_course/pytorch_ops/gemm.py) | [✅](./kernel_course/triton_ops/gemm.py) | ❌ | [✅](./tests/test_gemm.py) |
 
 ## Transformer Modules