From 7b38ae4caeefc83e2e361a81afce00aedae09097 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Thu, 11 Dec 2025 09:17:15 +0800
Subject: [PATCH 1/5] Adds geru benchmark tests

Introduces cross-backend geru benchmarks to compare Python, PyTorch,
Triton, and Cute kernels across devices and dtypes
---
 tests/test_geru.py | 92 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 92 insertions(+)
 create mode 100644 tests/test_geru.py

diff --git a/tests/test_geru.py b/tests/test_geru.py
new file mode 100644
index 0000000..3b76944
--- /dev/null
+++ b/tests/test_geru.py
@@ -0,0 +1,92 @@
+import pytest
+import torch
+
+from kernel_course import testing
+from kernel_course.python_ops import geru as python_geru
+
+try:
+    from kernel_course.pytorch_ops import geru as pytorch_geru
+
+    HAS_PYTORCH = True
+except Exception:
+    pytorch_geru = None
+    HAS_PYTORCH = False
+
+try:
+    from kernel_course.triton_ops import geru as triton_geru
+
+    HAS_TRITON = True
+except Exception:
+    triton_geru = None
+    HAS_TRITON = False
+
+try:
+    from kernel_course.cute_ops import geru as cute_geru
+
+    HAS_CUTE = True
+except Exception:
+    cute_geru = None
+    HAS_CUTE = False
+
+
+def factory(
+    MN: tuple[int, int],
+    device: torch.device,
+    dtype: torch.dtype = torch.float32,
+):
+    M, N = MN
+    A = torch.linspace(0.0, 1.0, steps=M * N, device=device, dtype=dtype).view(M, N)
+    x = torch.linspace(0.0, 1.0, steps=M, device=device, dtype=dtype)
+    y = torch.linspace(0.0, 1.0, steps=N, device=device, dtype=dtype)
+    alpha = 3.14
+    return (A, x, y, alpha), {}
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        pytest.param(
+            torch.device("cuda"),
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="requires CUDA"
+            ),
+        ),
+        pytest.param(
+            torch.device("mps"),
+            marks=pytest.mark.skipif(
+                not torch.backends.mps.is_available(), reason="requires MPS"
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.float32, torch.float16, torch.bfloat16],
+)
+@pytest.mark.parametrize(
+    "numel",
+    [
+        (1 << 4, 1 << 4),
+        (1 << 8, 1 << 8),
+    ],
+)
+def test_geru_benchmark(
+    device: torch.device, dtype: torch.dtype, numel: tuple[int, int]
+) -> None:
+    impls = testing.get_impls(
+        python_impl=python_geru.geru,
+        pytorch_impl=pytorch_geru.geru if HAS_PYTORCH else None,
+        triton_impl=triton_geru.geru if HAS_TRITON else None,
+        cute_impl=cute_geru.geru if HAS_CUTE else None,
+    )
+
+    # Benchmark each implementation
+    config = testing.BenchmarkConfig(warmup=3, repeat=1_000)
+    results = testing.run_benchmarks(
+        impls,
+        lambda: factory(numel, device, dtype),
+        flops=2 * numel[0] * numel[1],
+        config=config,
+    )
+
+    testing.show_benchmarks(results)
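
Note on the factory's operand shapes: geru is the rank-1 update
A = A + alpha * x y^T, so for an M-by-N matrix A, x has M elements and y has
N elements (the README's memory count of 2mn + m + n reflects this; the
original factory had the two lengths swapped, which only passed because the
parametrized sizes are square). A minimal sketch of that convention against a
plain-PyTorch reference, assuming the positional signature geru(A, x, y, alpha)
that the factory relies on; the non-square sizes and CPU default device are
illustrative only:

    import torch
    from kernel_course.python_ops import geru as python_geru

    M, N = 1 << 4, 1 << 8  # non-square on purpose, to exercise the shapes
    A = torch.linspace(0.0, 1.0, steps=M * N).view(M, N)
    x = torch.linspace(0.0, 1.0, steps=M)  # length M: one entry per row of A
    y = torch.linspace(0.0, 1.0, steps=N)  # length N: one entry per column of A

    # Reference rank-1 update via torch.outer; geru should agree with it.
    expected = A + 3.14 * torch.outer(x, y)
    torch.testing.assert_close(python_geru.geru(A, x, y, 3.14), expected)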
From c67b1442ece3f58572cabc11930c258adbf7e0cc Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Thu, 11 Dec 2025 09:21:33 +0800
Subject: [PATCH 2/5] Update geru test status to indicate completion

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index df1a1ce..0110510 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
 
 | [axpby](./docs/axpby.md) | update vector | $y = \alpha x + \beta y$ | $3n$ | $3n$ | [✅](./kernel_course/python_ops/axpby.py) | [✅](./kernel_course/pytorch_ops/axpby.py) | [✅](./kernel_course/triton_ops/axpby.py) | ❌ | [✅](./tests/test_axpby.py) |
 | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) |
 | [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) |
-| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | [✅](./kernel_course/pytorch_ops/geru.py) | [✅](./kernel_course/triton_ops/geru.py) | ❌ | ❌ |
+| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | [✅](./kernel_course/pytorch_ops/geru.py) | [✅](./kernel_course/triton_ops/geru.py) | ❌ | [✅](./tests/test_geru.py) |
 | gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ |
 

From 1dd82a8aa6add6a2c123c171656c24909016b95e Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Thu, 11 Dec 2025 09:23:07 +0800
Subject: [PATCH 3/5] Fix formatting in geru function to clarify outer product
 calculation

---
 kernel_course/python_ops/geru.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/kernel_course/python_ops/geru.py b/kernel_course/python_ops/geru.py
index c2e99b2..1f3e670 100644
--- a/kernel_course/python_ops/geru.py
+++ b/kernel_course/python_ops/geru.py
@@ -20,6 +20,6 @@ def geru(
         torch.Tensor: The updated tensor `A`.
     """
 
-    A = A + alpha * x[:, None] * y[None, :]
+    A = A + alpha * (x[:, None] * y[None, :])
 
     return A

From c5a34a767778c5d10b80271e8563f4d6d3b3764a Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Thu, 11 Dec 2025 09:23:13 +0800
Subject: [PATCH 4/5] Refactor geru function to use torch.outer for outer
 product calculation

---
 kernel_course/pytorch_ops/geru.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/kernel_course/pytorch_ops/geru.py b/kernel_course/pytorch_ops/geru.py
index 8ced16d..04c8933 100644
--- a/kernel_course/pytorch_ops/geru.py
+++ b/kernel_course/pytorch_ops/geru.py
@@ -20,6 +20,6 @@ def geru(
         torch.Tensor: The updated tensor `A`.
     """
 
-    A += torch.mul(torch.ger(x, y), alpha)
+    A = torch.add(A, torch.mul(torch.outer(x, y), alpha))
 
     return A

From ed5ef56b8b03ae04677568ad536c9ff25b7177e5 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Thu, 11 Dec 2025 09:23:20 +0800
Subject: [PATCH 5/5] Rename test_gemv to test_gemv_benchmark for clarity

---
 tests/test_gemv.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_gemv.py b/tests/test_gemv.py
index bca3241..d439baa 100644
--- a/tests/test_gemv.py
+++ b/tests/test_gemv.py
@@ -71,7 +71,7 @@ def factory(
         (1 << 8, 1 << 8),
     ],
 )
-def test_gemv(
+def test_gemv_benchmark(
     device: torch.device,
     dtype: torch.dtype,
     MN: tuple[int, int],
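
Endnote on PATCH 4: besides replacing the deprecated torch.ger with
torch.outer, the rewrite also switches from the in-place `A += ...` to an
out-of-place torch.add, so the caller's A is no longer mutated. The two
outer-product formulations themselves are numerically equivalent; a quick
illustrative check (the sizes and random values are placeholders):

    import torch

    A = torch.randn(5, 7)
    x, y, alpha = torch.randn(5), torch.randn(7), 3.14

    # python_ops form: broadcasted elementwise product.
    broadcast = A + alpha * (x[:, None] * y[None, :])
    # pytorch_ops form: torch.outer (the replacement for deprecated torch.ger).
    outer = torch.add(A, torch.mul(torch.outer(x, y), alpha))
    torch.testing.assert_close(broadcast, outer)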