61 changes: 56 additions & 5 deletions README.md
@@ -1,22 +1,27 @@
# KernelAgent — Multi‑Agent GPU Kernel Synthesis and Optimization

KernelAgent turns PyTorch programs into verified Triton kernels and optimizes their performance. It was designed around KernelBench workloads and combines:

- Static problem analysis to decide whether to run a lightweight path or a full pipeline
- LLM‑assisted refactoring that isolates fusable subgraphs
- Parallel Triton kernel generation with strict runtime verification
- End‑to‑end composition that rebuilds the original forward pass using only the synthesized kernels
- Hardware‑guided optimization pipeline that iteratively improves performance

Blog post: [PyTorch KernelFalcon](https://pytorch.org/blog/kernelfalcon-autonomous-gpu-kernel-generation-via-deep-agents/)

Additional docs: coming soon

## Kernel Generation Pipeline Overview

![](./assets/kernelagent2.excalidraw.svg)

Every stage writes artifacts to a run directory under `.fuse/<run_id>/`, including the fused PyTorch code, `subgraphs.json`, individual KernelAgent sessions, and the final `compose_out/composed_kernel.py`.

## KernelAgent Multi-Worker Optimization Pipeline Overview

![](./assets/opt_agent.svg)

Every stage writes artifacts to a run directory under `.optimize/<run_id>/`, including the input Triton kernel, per-round artifacts, individual optimization worker sessions, and the final `output/best_kernel.py`.


## Quickstart

### Requirements
@@ -143,6 +148,7 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`.
- Triton KernelAgent UI: `kernel-agent` or `python scripts/triton_ui.py`
- Fuser orchestration UI: `fuser-ui` or `python scripts/fuser_ui`
- Full pipeline UI: `pipeline-ui` or `python scripts/pipeline_ui`
- Optimization UI: `optimization-ui` or `python scripts/optimization_ui.py`

## Component Details

@@ -158,6 +164,49 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`.

- **Composer (`Fuser/compose_end_to_end.py`)**: stitches the verified kernels back into a single Triton program. The composed file contains one or more `@triton.jit` kernels plus a `kernel_function(...)` wrapper and a self-test that replays the original PyTorch problem. With `--verify`, the test is executed immediately and must succeed.

## Kernel Optimization Pipeline

KernelAgent includes a hardware-guided optimization pipeline that iteratively improves a verified Triton kernel's performance using GPU profiling feedback.

1. **Profile** — NCU collects 28 hardware metrics (compute utilization, memory bandwidth, cache hit rates, occupancy, stall breakdowns)
2. **Roofline Analysis** — Classifies the kernel as memory-bound, compute-bound, or underutilized based on SOL (speed-of-light) percentages
3. **Bottleneck Diagnosis** — An LLM analyzes the NCU metrics + kernel code to identify root causes and recommend specific fixes
4. **Optimization** — An LLM generates an optimized kernel applying the recommended fixes
5. **Verification** — The optimized kernel is tested for numerical correctness against PyTorch reference
6. **Benchmarking** — CUDA event timing measures the new kernel, tracking best-so-far with divergence-based revert

The loop runs for up to N rounds, with early termination when the kernel reaches roofline (≥95% SOL) or when performance converges.
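The roofline-based termination test in step 2 can be sketched as a small classifier. This is an illustrative sketch only: the actual `RooflineAnalyzer` lives in `kernel_perf_agent/kernel_opt/roofline/`, and the `LOW_UTIL_PCT` threshold below is a hypothetical value, not the project's.

```python
# Hypothetical sketch of the SOL-based classification in step 2.
# The 95% roofline cutoff comes from the text above; LOW_UTIL_PCT is assumed.

def classify_kernel(compute_sol_pct: float, memory_sol_pct: float) -> str:
    """Classify a kernel from its speed-of-light (SOL) percentages."""
    ROOFLINE_PCT = 95.0   # >= 95% SOL on either pipe: stop early, at roofline
    LOW_UTIL_PCT = 40.0   # below this on both pipes: underutilized (assumed)

    if max(compute_sol_pct, memory_sol_pct) >= ROOFLINE_PCT:
        return "at-roofline"       # early termination: little left to win
    if compute_sol_pct < LOW_UTIL_PCT and memory_sol_pct < LOW_UTIL_PCT:
        return "underutilized"     # e.g. poor occupancy or launch config
    if memory_sol_pct > compute_sol_pct:
        return "memory-bound"      # DRAM/L2 traffic dominates runtime
    return "compute-bound"         # ALU / tensor-core throughput dominates
```

The classification then steers the diagnosis prompt in step 3 (e.g. a memory-bound verdict points the LLM at coalescing and cache behavior rather than instruction mix).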

### Usage

#### Gradio UI
```bash
python scripts/optimization_ui.py --port 8088
```


### Key Components

| Component | Location | Role |
|---|---|---|
| **OptimizationOrchestrator** | `triton_kernel_agent/opt_worker_component/orchestrator/` | Main optimization loop |
| **KernelProfiler** | `triton_kernel_agent/opt_worker_component/profiling/` | NCU hardware profiling |
| **BottleneckAnalyzer** | `triton_kernel_agent/opt_worker_component/prescribing/` | LLM-based bottleneck diagnosis |
| **RooflineAnalyzer** | `kernel_perf_agent/kernel_opt/roofline/` | SOL classification and early stopping |
| **Benchmark** | `triton_kernel_agent/opt_worker_component/benchmarking/` | CUDA event timing |
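The "best-so-far with divergence-based revert" bookkeeping from step 6 can be sketched as follows. The class name, method names, and the 1.5× divergence factor are assumptions for illustration; they do not mirror the actual `Benchmark` component's API.

```python
# Illustrative sketch of best-so-far tracking with divergence-based revert
# (step 6). All names and the 1.5x threshold are hypothetical.

class BestKernelTracker:
    def __init__(self, divergence_factor: float = 1.5):
        # Revert when a candidate is more than `divergence_factor` times
        # slower than the best latency measured so far.
        self.divergence_factor = divergence_factor
        self.best_latency_ms = float("inf")
        self.best_source = None

    def update(self, source: str, latency_ms: float) -> bool:
        """Record a measured kernel; return True if it became the new best."""
        if latency_ms < self.best_latency_ms:
            self.best_latency_ms = latency_ms
            self.best_source = source
            return True
        return False

    def should_revert(self, latency_ms: float) -> bool:
        """True when a candidate diverged badly, so the next round should
        restart from the best kernel instead of this candidate."""
        return latency_ms > self.divergence_factor * self.best_latency_ms


tracker = BestKernelTracker()
tracker.update("kernel_round_0.py", 2.0)  # baseline becomes the best
tracker.update("kernel_round_1.py", 2.4)  # slower: best is unchanged
```

On `should_revert`, the orchestrator would feed `best_kernel.py` rather than the diverged candidate into the next optimization round.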

### Optimization Artifacts

```
.optimize/workers/<worker_id>/<run_id>/artifacts
kernel_round_0.py # baseline kernel
kernel_round_N.py # kernel after round N
round001_opt_prompt.txt # optimization prompt sent to LLM
round001_opt_reply.txt # LLM response
round001_strategy.json # bottleneck analysis result
...
```

## Platform Support

KernelAgent supports multiple GPU platforms for Triton kernel execution:
@@ -218,6 +267,7 @@ It includes selected L1/L2/L3 problems with:
## Repository Layout

- `triton_kernel_agent/` — KernelAgent core (agent, worker manager, provider adapters, prompt templates)
- `triton_kernel_agent/opt_worker_component/` — optimization pipeline (profiler, benchmarker, bottleneck analyzer, orchestrator)
- `Fuser/` — auto-router, orchestration pipeline, CLIs, Gradio UIs
- `triton_kernel_agent/templates/` — Jinja templates used when prompting TritonKernelAgent
- `examples/` — sample problems and prompt snippets
@@ -234,7 +284,8 @@ It includes selected L1/L2/L3 problems with:

## Documentation & Community

- Optimization pipeline docs: see [Kernel Optimization Pipeline](#kernel-optimization-pipeline) above
- Open-source recommendations: see `docs/open_source_recommendations.md`
- Issues: https://github.com/pytorch-labs/KernelAgent/issues
- Blog post: https://pytorch.org/blog/kernelfalcon-autonomous-gpu-kernel-generation-via-deep-agents/

2 changes: 2 additions & 0 deletions assets/opt_agent.svg
166 changes: 166 additions & 0 deletions examples/optimize_01_matvec/input.py
@@ -0,0 +1,166 @@
# kernel.py
# Matrix-vector multiplication using Triton: C = A @ B
# Implements the exact problem from the test:
# - M = 2048
# - K = 1,048,576
# - A: (M, K), BF16
# - B: (K, 1), BF16
# - C: (M, 1), BF16
#
# Notes on fusion:
# - The entire operation (matrix-vector product) is executed in a single Triton kernel.
# - There is nothing else to fuse (no bias/activation in the test), so no extra kernel stages are required.
# - All math is performed inside the Triton kernel; the Python wrapper only validates/allocates/configures.
#
# Triton programming guidelines followed:
# - Use @triton.jit for kernels.
# - Use tl.constexpr for compile-time constants (BLOCK_M, BLOCK_K).
# - Proper indexing with tl.program_id, tl.arange, and tl.cdiv.
# - Use tl.load/tl.store with masks for OOB protection and coalesced access on contiguous inputs.
# - Accumulate in FP32 for numerical stability and convert to BF16 on store.

import triton
import triton.language as tl
import torch


@triton.jit
def _matvec_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Program id along the M dimension (each program computes a block of rows)
    pid_m = tl.program_id(0)

    # Row indices this program handles
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    m_mask = offs_m < M

    # Help compiler with alignment information
    offs_m = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_M), BLOCK_M)

    # Initialize FP32 accumulator for BLOCK_M rows
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over the K dimension in chunks of BLOCK_K
    # Use tl.range to ensure proper device-side looping
    for k0 in tl.range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)
        k_mask = offs_k < K
        # Also assist the compiler with alignment info for K offsets
        offs_k = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_K), BLOCK_K)

        # Compute pointers:
        # A tile is a [BLOCK_M, BLOCK_K] region starting at (offs_m, offs_k)
        a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
        # B tile is a vector [BLOCK_K] at column 0 (since B is [K, 1])
        b_ptrs = b_ptr + (offs_k * stride_bk + 0 * stride_bn)

        # Load with masking to guard boundaries. Inputs are BF16; cast to FP32 for accumulation
        a = tl.load(a_ptrs, mask=(m_mask[:, None] & k_mask[None, :]), other=0).to(
            tl.float32
        )
        b = tl.load(b_ptrs, mask=k_mask, other=0).to(tl.float32)

        # Fused multiply-accumulate for the rows in this tile:
        # sum over the K tile dimension for each row
        acc += tl.sum(a * b[None, :], axis=1)

    # Convert accumulator to BF16 and store to C[:, 0]
    out = acc.to(tl.bfloat16)
    c_ptrs = c_ptr + (offs_m * stride_cm + 0 * stride_cn)
    tl.store(c_ptrs, out, mask=m_mask)


def kernel_function(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Compute C = A @ B using a single Triton kernel.

    What is fused:
    - The entire matrix-vector multiplication is done in one pass inside the kernel.
    - No additional ops (e.g., bias/activation) are required by the test, so none are fused.

    Runtime constraints honored:
    - The wrapper only validates arguments, allocates the output, and launches the Triton kernel.
    - All math is inside the Triton kernel; no torch.nn or torch.nn.functional usage.

    Args:
        A: [M, K] BF16 CUDA tensor
        B: [K, 1] BF16 CUDA tensor (also accepts shape [K]; it will be viewed as [K, 1])

    Returns:
        C: [M, 1] BF16 CUDA tensor
    """
    # Validate device and dtype
    if not A.is_cuda or not B.is_cuda:
        raise ValueError("A and B must be CUDA tensors.")
    if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
        raise ValueError("A and B must be torch.bfloat16 tensors.")

    if A.ndim != 2:
        raise ValueError("A must be 2D [M, K].")
    M, K = A.shape

    # Accept B as [K] or [K, 1]
    if B.ndim == 1:
        if B.shape[0] != K:
            raise ValueError(
                f"When B is 1D, expected shape [K]={K}, but got {tuple(B.shape)}"
            )
        Bv = B.view(K, 1)
    elif B.ndim == 2:
        if B.shape[0] != K or B.shape[1] != 1:
            raise ValueError(
                f"B must be [K, 1], got {tuple(B.shape)} (K must match A.shape[1])"
            )
        Bv = B
    else:
        raise ValueError("B must be 1D [K] or 2D [K, 1].")

    # Allocate output C [M, 1]
    C = torch.empty((M, 1), device=A.device, dtype=A.dtype)

    # Extract strides (in elements, not bytes)
    stride_am, stride_ak = A.stride()
    stride_bk, stride_bn = Bv.stride()
    stride_cm, stride_cn = C.stride()

    # Kernel launch configuration
    # Choose modest tile sizes to balance register usage and loop count over K.
    # For the huge K in the test, BLOCK_K=256 works well without excessive register pressure.
    BLOCK_M = 128
    BLOCK_K = 256

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    _matvec_kernel[grid](
        A,
        Bv,
        C,
        M,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
        num_warps=4,
        num_stages=2,
    )

    return C
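The blocking scheme used by `_matvec_kernel` above can be mirrored in plain Python: each "program" owns `BLOCK_M` rows and sweeps K in `BLOCK_K` chunks, accumulating in higher precision before the final store. This is a purely illustrative sketch with no Triton or CUDA dependency; the tiny default block sizes are chosen for readability, not performance.

```python
# Pure-Python mirror of the tiling/accumulation scheme in _matvec_kernel.
# Illustrative only: real performance comes from the Triton version.

def blocked_matvec(A, B, BLOCK_M=2, BLOCK_K=4):
    M, K = len(A), len(A[0])
    C = [0.0] * M
    for m0 in range(0, M, BLOCK_M):            # one "program" per row block
        for m in range(m0, min(m0 + BLOCK_M, M)):
            acc = 0.0                           # FP32-style accumulator
            for k0 in range(0, K, BLOCK_K):     # sweep K in BLOCK_K chunks
                for k in range(k0, min(k0 + BLOCK_K, K)):
                    acc += A[m][k] * B[k]       # fused multiply-accumulate
            C[m] = acc                          # single store per row
    return C


C = blocked_matvec([[1.0, 2.0], [3.0, 4.0]], [5.0, 6.0])  # -> [17.0, 39.0]
```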
38 changes: 38 additions & 0 deletions examples/optimize_01_matvec/problem.py
@@ -0,0 +1,38 @@
import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Simple model that performs matrix-vector multiplication (C = A * B).
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs matrix-vector multiplication.

        Args:
            A: Input matrix of shape (M, K).
            B: Input vector of shape (K, 1).

        Returns:
            Output vector of shape (M, 1).
        """
        return torch.matmul(A, B)


M = 256 * 8  # 2048
K = 131072 * 8  # 1048576


def get_inputs():
    A = torch.rand(M, K)
    B = torch.rand(K, 1)
    return [A, B]


def get_init_inputs():
    return []  # No special initialization inputs needed