From 25072d8192dd368985ddff185ed46623d006ae03 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 9 Dec 2025 18:10:35 +0800 Subject: [PATCH 1/3] Adds PyTorch GERU helper Implements a tensor update routine that scales and adds the outer product of two vectors to support geru semantics within the PyTorch ops module. --- kernel_course/pytorch_ops/geru.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 kernel_course/pytorch_ops/geru.py diff --git a/kernel_course/pytorch_ops/geru.py b/kernel_course/pytorch_ops/geru.py new file mode 100644 index 0000000..3cd9867 --- /dev/null +++ b/kernel_course/pytorch_ops/geru.py @@ -0,0 +1,25 @@ +import torch + + +def geru( + A: torch.Tensor, + x: torch.Tensor, + y: torch.Tensor, + alpha: float, +): + """ + Updates tensor `A` in place by adding the outer product of vectors `x` and `y` scaled by `alpha` using PyTorch operations. + + Args: + A (torch.Tensor): Matrix tensor to be updated. + x (torch.Tensor): Vector tensor. + y (torch.Tensor): Vector tensor. + alpha (float): Scaling factor for the outer product of `x` and `y`. + + Returns: + torch.Tensor: The updated tensor `A`. + """ + + A += torch.mul(torch.outer(x, y), alpha) + + return A From da194ef4298d77e99b377a801477382813dd4baf Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 9 Dec 2025 18:10:41 +0800 Subject: [PATCH 2/3] Updates GERU kernel status in README to reflect PyTorch implementation completion --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8f15583..2e89239 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
| [axpby](./docs/axpby.md) | update vector| $y = \alpha x + \beta y$ | $3n$ | $3n$ | [✅](./kernel_course/python_ops/axpby.py) | [✅](./kernel_course/pytorch_ops/axpby.py) | [✅](./kernel_course/triton_ops/axpby.py) | ❌ | [✅](./tests/test_axpby.py) | | [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) | | [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) | -| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | ❌ | ❌ | ❌ | ❌ | +| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | [✅](./kernel_course/pytorch_ops/geru.py) | ❌ | ❌ | ❌ | | gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ | From 6333c9618ac4163272fd56a4b405f58f84230e6b Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 9 Dec 2025 18:12:40 +0800 Subject: [PATCH 3/3] Fix formatting in geru function docstring --- kernel_course/pytorch_ops/geru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernel_course/pytorch_ops/geru.py b/kernel_course/pytorch_ops/geru.py index 3cd9867..8ced16d 100644 --- a/kernel_course/pytorch_ops/geru.py +++ b/kernel_course/pytorch_ops/geru.py @@ -15,7 +15,7 @@ def geru( x (torch.Tensor): Vector tensor. y (torch.Tensor): Vector tensor. alpha (float): Scaling factor for the outer product of `x` and `y`. - + Returns: torch.Tensor: The updated tensor `A`. """