Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) |
| [gemv](./docs/gemv.md) | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | [✅](./kernel_course/python_ops/gemv.py) | [✅](./kernel_course/pytorch_ops/gemv.py) | [✅](./kernel_course/triton_ops/gemv.py) | ❌ | [✅](./tests/test_gemv.py) |
| [geru](./docs/geru.md) | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | [✅](./kernel_course/python_ops/geru.py) | [✅](./kernel_course/pytorch_ops/geru.py) | [✅](./kernel_course/triton_ops/geru.py) | ❌ | [✅](./tests/test_geru.py) |
| gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | | ❌ | ❌ | ❌ | ❌ |
| [gemm](./docs/gemm.md) | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | [✅](./kernel_course/python_ops/gemm.py) | ❌ | ❌ | ❌ | ❌ |


## Transformer Modules
Expand Down
35 changes: 35 additions & 0 deletions docs/gemm.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# GEMM Kernel

The `gemm` operator computes the matrix-matrix product of two matrices.

## Mathematical Definition

Given input matrices `A` and `B`, along with an output matrix `C` and scalars `α` and `β`, the kernel evaluates

$$
C = \alpha A B + \beta C
$$

The matrix-matrix product is computed by multiplying the matrix `A` with the matrix `B`, scaling the result by `α`, scaling the matrix `C` by `β`, and then adding the two scaled results together to produce the updated matrix `C`.

## Kernel Implementations

- [Python Implementation](../kernel_course/python_ops/gemm.py)
- [PyTorch Implementation](../kernel_course/pytorch_ops/gemm.py)
- [Triton Implementation](../kernel_course/triton_ops/gemm.py)
- [CuTe Implementation](../kernel_course/cute_ops/gemm.py)

All backends share the interface:

```python
def gemm(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
...
```

## Testing

See the [test suite](../tests/test_gemm.py) for the validation harness that exercises every backend.

```bash
pytest tests/test_gemm.py -s
```
28 changes: 28 additions & 0 deletions kernel_course/python_ops/gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch


def gemm(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
alpha: float,
beta: float,
) -> torch.Tensor:
"""
Updates tensor `C` by adding the product of matrices `A` and `B`
scaled by `alpha`, and `C` scaled by `beta`.

Args:
A (torch.Tensor): First matrix tensor.
B (torch.Tensor): Second matrix tensor to be multiplied with `A`.
C (torch.Tensor): Matrix tensor to be updated.
alpha (float): Scaling factor for the product of `A` and `B`.
beta (float): Scaling factor for `C`.

Returns:
torch.Tensor: The updated tensor `C`.
"""

C = alpha * A @ B + beta * C

return C
Loading