README.md (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
| [axpby](./docs/axpby.md) | update vector| $y = \alpha x + \beta y$ | $3n$ | $3n$ | [✅](./kernel_course/python_ops/axpby.py) | [✅](./kernel_course/pytorch_ops/axpby.py) | [✅](./kernel_course/triton_ops/axpby.py) | ❌ | [✅](./tests/test_axpby.py) |
| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) |
| [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) |
| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | [✅](./kernel_course/pytorch_ops/geru.py) | [✅](./kernel_course/triton_ops/geru.py) | ❌ | |
| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | [✅](./kernel_course/pytorch_ops/geru.py) | [✅](./kernel_course/triton_ops/geru.py) | ❌ | [✅](./tests/test_geru.py) |
| gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ |


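For reference, the `geru` row above is the BLAS rank-1 update $A \leftarrow A + \alpha x y^\top$ for an $m \times n$ matrix: $A$, $x$, and $y$ are read and $A$ is written back, which gives the $2mn$ FLOPs and roughly $2mn + m + n$ elements of memory traffic listed in the table. A minimal PyTorch sketch of the operation itself (illustrative only, not the course kernels):

```python
import torch

# Rank-1 update A <- A + alpha * x y^T for an (m, n) matrix A.
m, n = 4, 3
A = torch.zeros(m, n)
x = torch.arange(m, dtype=torch.float32)   # length-m vector
y = torch.arange(n, dtype=torch.float32)   # length-n vector
alpha = 2.0
A_updated = A + alpha * torch.outer(x, y)  # torch.outer(x, y) has shape (m, n)
print(A_updated)
```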
kernel_course/python_ops/geru.py (2 changes: 1 addition & 1 deletion)
@@ -20,6 +20,6 @@ def geru(
        torch.Tensor: The updated tensor `A`.
    """

    A = A + alpha * x[:, None] * y[None, :]
    A = A + alpha * (x[:, None] * y[None, :])

    return A
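The broadcasting expression in the function body is an outer product: `x[:, None]` has shape `(m, 1)`, `y[None, :]` has shape `(1, n)`, and their elementwise product broadcasts to `(m, n)`. A standalone check of that equivalence (sketch, not part of this file):

```python
import torch

x = torch.randn(5)
y = torch.randn(3)
# (5, 1) * (1, 3) broadcasts to (5, 3), the same matrix as torch.outer(x, y).
torch.testing.assert_close(x[:, None] * y[None, :], torch.outer(x, y))
```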
kernel_course/pytorch_ops/geru.py (2 changes: 1 addition & 1 deletion)
@@ -20,6 +20,6 @@ def geru(
        torch.Tensor: The updated tensor `A`.
    """

    A += torch.mul(torch.ger(x, y), alpha)
    A = torch.add(A, torch.mul(torch.outer(x, y), alpha))

    return A
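`torch.outer(x, y)` is the current name for the outer product; `torch.ger` is a deprecated alias for it. For comparison only, the same scaled rank-1 update can be expressed with `torch.addr`, which folds the scale and accumulate into one call (a sketch, not the implementation used here):

```python
import torch

A = torch.zeros(4, 3)
x = torch.randn(4)
y = torch.randn(3)
alpha = 3.14
# torch.addr computes beta * A + alpha * (x y^T); with beta=1 this is the geru update.
out = torch.addr(A, x, y, beta=1.0, alpha=alpha)
torch.testing.assert_close(out, A + alpha * torch.outer(x, y))
```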
tests/test_gemv.py (2 changes: 1 addition & 1 deletion)
@@ -71,7 +71,7 @@ def factory(
        (1 << 8, 1 << 8),
    ],
)
def test_gemv(
def test_gemv_benchmark(
    device: torch.device,
    dtype: torch.dtype,
    MN: tuple[int, int],
tests/test_geru.py (92 changes: 92 additions & 0 deletions)
@@ -0,0 +1,92 @@
import pytest
import torch

from kernel_course import testing
from kernel_course.python_ops import geru as python_geru

try:
    from kernel_course.pytorch_ops import geru as pytorch_geru

    HAS_PYTORCH = True
except Exception:
    pytorch_geru = None
    HAS_PYTORCH = False

try:
    from kernel_course.triton_ops import geru as triton_geru

    HAS_TRITON = True
except Exception:
    triton_geru = None
    HAS_TRITON = False

try:
    from kernel_course.cute_ops import geru as cute_geru

    HAS_CUTE = True
except Exception:
    cute_geru = None
    HAS_CUTE = False


def factory(
    MN: tuple[int, int],
    device: torch.device,
    dtype: torch.dtype = torch.float32,
):
    M, N = MN
    A = torch.linspace(0.0, 1.0, steps=M * N, device=device, dtype=dtype).view(M, N)
    x = torch.linspace(0.0, 1.0, steps=M, device=device, dtype=dtype)
    y = torch.linspace(0.0, 1.0, steps=N, device=device, dtype=dtype)
    alpha = 3.14
    return (A, x, y, alpha), {}


@pytest.mark.parametrize(
    "device",
    [
        pytest.param(
            torch.device("cuda"),
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="requires CUDA"
            ),
        ),
        pytest.param(
            torch.device("mps"),
            marks=pytest.mark.skipif(
                not torch.backends.mps.is_available(), reason="requires MPS"
            ),
        ),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [torch.float32, torch.float16, torch.bfloat16],
)
@pytest.mark.parametrize(
    "numel",
    [
        (1 << 4, 1 << 4),
        (1 << 8, 1 << 8),
    ],
)
def test_geru_benchmark(
    device: torch.device, dtype: torch.dtype, numel: tuple[int, int]
) -> None:
    impls = testing.get_impls(
        python_impl=python_geru.geru,
        pytorch_impl=pytorch_geru.geru if HAS_PYTORCH else None,
        triton_impl=triton_geru.geru if HAS_TRITON else None,
        cute_impl=cute_geru.geru if HAS_CUTE else None,
    )

    # Benchmark each implementation
    config = testing.BenchmarkConfig(warmup=3, repeat=1_000)
    results = testing.run_benchmarks(
        impls,
        lambda: factory(numel, device, dtype),
        flops=2 * numel[0] * numel[1],
        config=config,
    )

    testing.show_benchmarks(results)
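The test above only benchmarks throughput. A separate correctness check can compare an implementation against a plain torch reference; a minimal sketch, assuming the positional `geru(A, x, y, alpha)` call order implied by `factory`:

```python
import torch

from kernel_course.python_ops import geru as python_geru

# CPU correctness sketch (assumes geru takes (A, x, y, alpha) positionally, as in factory above).
M, N = 8, 5
A = torch.randn(M, N)
x = torch.randn(M)
y = torch.randn(N)
alpha = 3.14
expected = A + alpha * torch.outer(x, y)
torch.testing.assert_close(python_geru.geru(A, x, y, alpha), expected)
```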