diff --git a/README.md b/README.md
index 18792af0..1bd2a2a6 100644
--- a/README.md
+++ b/README.md
@@ -1,22 +1,27 @@
-# KernelAgent — Multi‑Agent GPU Kernel Synthesis
+# KernelAgent — Multi‑Agent GPU Kernel Synthesis and Optimization
-KernelAgent turns PyTorch programs into verified Triton kernels. It was designed around KernelBench workloads and combines:
+KernelAgent turns PyTorch programs into verified Triton kernels and optimizes their performance. It was designed around KernelBench workloads and combines:
- Static problem analysis to decide whether to run a lightweight path or a full pipeline
- LLM‑assisted refactoring that isolates fusable subgraphs
- Parallel Triton kernel generation with strict runtime verification
- End‑to‑end composition that rebuilds the original forward pass using only the synthesized kernels
+- Hardware‑guided optimization pipeline that iteratively improves performance
Blog post: [PyTorch KernelFalcon](https://pytorch.org/blog/kernelfalcon-autonomous-gpu-kernel-generation-via-deep-agents/)
-Additional docs: coming soon
-## Pipeline Overview
+## Kernel Generation Pipeline Overview
![](./assets/kernelagent2.excalidraw.svg)
Every stage writes artifacts to a run directory under `.fuse//`, including the fused PyTorch code, `subgraphs.json`, individual KernelAgent sessions, and the final `compose_out/composed_kernel.py`.
+## KernelAgent Multi-Worker Optimization Pipeline Overview
+![](./assets/opt_agent.svg)
+Every stage writes artifacts to a run directory under `.optimize//`, including the input Triton kernel, intermediate artifacts, individual optimization worker sessions, and the final `output/best_kernel.py`.
+
+
## Quickstart
### Requirements
@@ -143,6 +148,7 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`.
- Triton KernelAgent UI: `kernel-agent` or `python scripts/triton_ui.py` - Fuser orchestration UI: `fuser-ui` or `python scripts/fuser_ui` - Full pipeline UI: `pipeline-ui` or `python scripts/pipeline_ui` + - Optimization UI: `optimization-ui` or `python scripts/optimization_ui.py` ## Component Details @@ -158,6 +164,49 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`. - **Composer (`Fuser/compose_end_to_end.py`)**: stitches the verified kernels back into a single Triton program. The composed file contains one or more `@triton.jit` kernels plus a `kernel_function(...)` wrapper and a self-test that replays the original PyTorch problem. With `--verify`, the test is executed immediately and must succeed. +## Kernel Optimization Pipeline + +KernelAgent includes a hardware-guided optimization pipeline that iteratively improves a verified Triton kernel's performance using GPU profiling feedback. + +1. **Profile** — NCU collects 28 hardware metrics (compute utilization, memory bandwidth, cache hit rates, occupancy, stall breakdowns) +2. **Roofline Analysis** — Classifies the kernel as memory-bound, compute-bound, or underutilized based on SOL (speed-of-light) percentages +3. **Bottleneck Diagnosis** — An LLM analyzes the NCU metrics + kernel code to identify root causes and recommend specific fixes +4. **Optimization** — An LLM generates an optimized kernel applying the recommended fixes +5. **Verification** — The optimized kernel is tested for numerical correctness against PyTorch reference +6. **Benchmarking** — CUDA event timing measures the new kernel, tracking best-so-far with divergence-based revert + +The loop runs for up to N rounds, with early termination when the kernel reaches roofline (≥95% SOL) or when performance converges. 
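+The loop above can be sketched in a few lines of Python. This is an illustrative stand-in, not the actual orchestrator API: the function names (`profile`, `diagnose`, `rewrite`, `verify`, `benchmark`) and the `optimize_kernel` entry point are hypothetical; the real loop lives in `triton_kernel_agent/opt_worker_component/orchestrator/`.

```python
# Hypothetical sketch of the profile -> diagnose -> optimize -> verify ->
# benchmark loop. All stage callables are stand-ins for the real components.

def optimize_kernel(kernel, profile, diagnose, rewrite, verify, benchmark,
                    max_rounds=5, sol_target=0.95):
    """Iteratively improve `kernel`, keeping the best verified variant."""
    best_kernel, best_time = kernel, benchmark(kernel)
    for _ in range(max_rounds):
        metrics = profile(best_kernel)              # NCU hardware counters
        if metrics["sol"] >= sol_target:            # at roofline: stop early
            break
        strategy = diagnose(best_kernel, metrics)   # LLM bottleneck diagnosis
        candidate = rewrite(best_kernel, strategy)  # LLM-generated rewrite
        if not verify(candidate):                   # must match PyTorch ref
            continue                                # revert: keep best-so-far
        t = benchmark(candidate)                    # CUDA-event timing
        if t < best_time:                           # track best-so-far
            best_kernel, best_time = candidate, t
    return best_kernel, best_time
```

+Divergence-based revert falls out of the structure: a candidate that fails verification or regresses performance never replaces `best_kernel`, so each round always starts from the fastest verified kernel seen so far.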
+
+### Usage
+
+#### Gradio UI
+```bash
+optimization-ui --port 8088
+```
+
+
+### Key Components
+
+| Component | Location | Role |
+|---|---|---|
+| **OptimizationOrchestrator** | `triton_kernel_agent/opt_worker_component/orchestrator/` | Main optimization loop |
+| **KernelProfiler** | `triton_kernel_agent/opt_worker_component/profiling/` | NCU hardware profiling |
+| **BottleneckAnalyzer** | `triton_kernel_agent/opt_worker_component/prescribing/` | LLM-based bottleneck diagnosis |
+| **RooflineAnalyzer** | `kernel_perf_agent/kernel_opt/roofline/` | SOL classification and early stopping |
+| **Benchmark** | `triton_kernel_agent/opt_worker_component/benchmarking/` | CUDA event timing |
+
+### Optimization Artifacts
+
+```
+.optimize/workers///artifacts
+    kernel_round_0.py          # baseline kernel
+    kernel_round_N.py          # kernel after round N
+    round001_opt_prompt.txt    # optimization prompt sent to LLM
+    round001_opt_reply.txt     # LLM response
+    round001_strategy.json     # bottleneck analysis result
+    ...
+``` + ## Platform Support KernelAgent supports multiple GPU platforms for Triton kernel execution: @@ -218,6 +267,7 @@ It includes selected L1/L2/L3 problems with: ## Repository Layout - `triton_kernel_agent/` — KernelAgent core (agent, worker manager, provider adapters, prompt templates) +- `triton_kernel_agent/opt_worker_component/` — optimization pipeline (profiler, benchmarker, bottleneck analyzer, orchestrator) - `Fuser/` — auto-router, orchestration pipeline, CLIs, Gradio UIs - `triton_kernel_agent/templates/` — Jinja templates used when prompting TritonKernelAgent - `examples/` — sample problems and prompt snippets @@ -234,7 +284,8 @@ It includes selected L1/L2/L3 problems with: ## Documentation & Community -- Architecture and deep-dive docs: `Coming Soon` +- Optimization pipeline docs: see [Kernel Optimization Pipeline](#kernel-optimization-pipeline) above +- Open-source recommendations: see `docs/open_source_recommendations.md` - Issues: https://github.com/pytorch-labs/KernelAgent/issues - Blog post: https://pytorch.org/blog/kernelfalcon-autonomous-gpu-kernel-generation-via-deep-agents/ diff --git a/assets/opt_agent.svg b/assets/opt_agent.svg new file mode 100644 index 00000000..406a52d3 --- /dev/null +++ b/assets/opt_agent.svg @@ -0,0 +1,2 @@ +(2) Judge Agent (1) Profiler Agent Opt-Agent 1Opt-Agent 3Opt-Agent 2(5) Optimization ManagerKernel 1Kernel 2QueueTop candidates...Input Kernel...HistoryRAGKnowledgeReflextionDiagnose Prescribe SynthesizeExplore MeasureCollect(6) Benchmarking Agent(3) Analyzer Agent \ No newline at end of file diff --git a/examples/optimize_01_matvec/input.py b/examples/optimize_01_matvec/input.py new file mode 100644 index 00000000..0597a968 --- /dev/null +++ b/examples/optimize_01_matvec/input.py @@ -0,0 +1,166 @@ +# kernel.py +# Matrix-vector multiplication using Triton: C = A @ B +# Implements the exact problem from the test: +# - M = 2048 +# - K = 1,048,576 +# - A: (M, K), BF16 +# - B: (K, 1), BF16 +# - C: (M, 1), BF16 +# +# 
Notes on fusion: +# - The entire operation (matrix-vector product) is executed in a single Triton kernel. +# - There is nothing else to fuse (no bias/activation in the test), so no extra kernel stages are required. +# - All math is performed inside the Triton kernel; the Python wrapper only validates/allocates/configures. +# +# Triton programming guidelines followed: +# - Use @triton.jit for kernels. +# - Use tl.constexpr for compile-time constants (BLOCK_M, BLOCK_K). +# - Proper indexing with tl.program_id, tl.arange, and tl.cdiv. +# - Use tl.load/tl.store with masks for OOB protection and coalesced access on contiguous inputs. +# - Accumulate in FP32 for numerical stability and convert to BF16 on store. + +import triton +import triton.language as tl +import torch + + +@triton.jit +def _matvec_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, +): + # Program id along the M dimension (each program computes a block of rows) + pid_m = tl.program_id(0) + + # Row indices this program handles + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + m_mask = offs_m < M + + # Help compiler with alignment information + offs_m = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_M), BLOCK_M) + + # Initialize FP32 accumulator for BLOCK_M rows + acc = tl.zeros((BLOCK_M,), dtype=tl.float32) + + # Iterate over K dimension in chunks of BLOCK_K + # Use tl.range to ensure proper device-side looping + for k0 in tl.range(0, K, BLOCK_K): + offs_k = k0 + tl.arange(0, BLOCK_K) + k_mask = offs_k < K + # Also assist compiler with alignment info for K offsets + offs_k = tl.max_contiguous(tl.multiple_of(offs_k, BLOCK_K), BLOCK_K) + + # Compute pointers: + # A tile is [BLOCK_M, BLOCK_K] region starting at (offs_m, offs_k) + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + # B tile is a vector [BLOCK_K] at column 0 (since B is [K, 1]) + b_ptrs = 
b_ptr + (offs_k * stride_bk + 0 * stride_bn) + + # Load with masking to guard boundaries. Inputs are BF16; cast to FP32 for accumulation + a = tl.load(a_ptrs, mask=(m_mask[:, None] & k_mask[None, :]), other=0).to( + tl.float32 + ) + b = tl.load(b_ptrs, mask=k_mask, other=0).to(tl.float32) + + # Fused multiply-accumulate for rows in this tile: + # sum over K tile dimension for each row + acc += tl.sum(a * b[None, :], axis=1) + + # Convert accumulator to BF16 and store to C[:, 0] + out = acc.to(tl.bfloat16) + c_ptrs = c_ptr + (offs_m * stride_cm + 0 * stride_cn) + tl.store(c_ptrs, out, mask=m_mask) + + +def kernel_function(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + """ + Compute C = A @ B using a single Triton kernel. + + What is fused: + - Entire matrix-vector multiplication is done in one pass inside the kernel. + - No additional ops (e.g., bias/activation) are required by the test, so none are fused. + + Runtime constraints honored: + - Wrapper only validates arguments, allocates output, and launches the Triton kernel. + - All math is inside the Triton kernel; no torch.nn or torch.nn.functional usage. 
+ + Args: + A: [M, K] BF16 CUDA tensor + B: [K, 1] BF16 CUDA tensor (also accepts shape [K], it will be viewed as [K, 1]) + + Returns: + C: [M, 1] BF16 CUDA tensor + """ + # Validate device and dtype + if not A.is_cuda or not B.is_cuda: + raise ValueError("A and B must be CUDA tensors.") + if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16: + raise ValueError("A and B must be torch.bfloat16 tensors.") + + if A.ndim != 2: + raise ValueError("A must be 2D [M, K].") + M, K = A.shape + + # Accept B as [K] or [K, 1] + if B.ndim == 1: + if B.shape[0] != K: + raise ValueError( + f"When B is 1D, expected shape [K]={K}, but got {tuple(B.shape)}" + ) + Bv = B.view(K, 1) + elif B.ndim == 2: + if B.shape[0] != K or B.shape[1] != 1: + raise ValueError( + f"B must be [K, 1], got {tuple(B.shape)} (K must match A.shape[1])" + ) + Bv = B + else: + raise ValueError("B must be 1D [K] or 2D [K, 1].") + + # Allocate output C [M, 1] + C = torch.empty((M, 1), device=A.device, dtype=A.dtype) + + # Extract strides (in elements, not bytes) + stride_am, stride_ak = A.stride() + stride_bk, stride_bn = Bv.stride() + stride_cm, stride_cn = C.stride() + + # Kernel launch configuration + # Choose modest tile sizes to balance register usage and loop count over K. + # For the huge K in the test, BLOCK_K=256 works well without excessive register pressure. 
+ BLOCK_M = 128 + BLOCK_K = 256 + + def grid(meta): + return (triton.cdiv(M, meta["BLOCK_M"]),) + + _matvec_kernel[grid]( + A, + Bv, + C, + M, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M=BLOCK_M, + BLOCK_K=BLOCK_K, + num_warps=4, + num_stages=2, + ) + + return C diff --git a/examples/optimize_01_matvec/problem.py b/examples/optimize_01_matvec/problem.py new file mode 100644 index 00000000..35a8a417 --- /dev/null +++ b/examples/optimize_01_matvec/problem.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn + + +class Model(nn.Module): + """ + Simple model that performs matrix-vector multiplication (C = A * B). + """ + + def __init__(self): + super(Model, self).__init__() + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + """ + Performs matrix-vector multiplication. + + Args: + A: Input matrix of shape (M, K). + B: Input vector of shape (K, 1). + + Returns: + Output vector of shape (M, 1). + """ + return torch.matmul(A, B) + + +M = 256 * 8 # 2048 +K = 131072 * 8 # 1048576 + + +def get_inputs(): + A = torch.rand(M, K) + B = torch.rand(K, 1) + return [A, B] + + +def get_init_inputs(): + return [] # No special initialization inputs needed diff --git a/examples/optimize_01_matvec/test.py b/examples/optimize_01_matvec/test.py new file mode 100644 index 00000000..06a9d8bf --- /dev/null +++ b/examples/optimize_01_matvec/test.py @@ -0,0 +1,272 @@ +"""Correctness test for matrix-vector multiplication kernel.""" + +import inspect +import sys + +import torch +from kernel import kernel_function +from problem import get_init_inputs, get_inputs, Model + +_CONV_TYPES = ( + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, +) +_NORM_TYPES = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.LayerNorm, + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + 
torch.nn.InstanceNorm3d, +) +_POOL_TYPES = ( + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AdaptiveMaxPool1d, + torch.nn.AdaptiveMaxPool2d, + torch.nn.AdaptiveMaxPool3d, +) + + +def _extract_model_params(model): + """Extract learnable parameters and layer config from a PyTorch model.""" + params = {} + + for _, module in model.named_modules(): + if isinstance(module, (*_CONV_TYPES, torch.nn.Linear)): + if hasattr(module, "weight") and module.weight is not None: + params.setdefault("weight", module.weight) + params.setdefault("w", module.weight) + if getattr(module, "bias", None) is not None: + params.setdefault("conv_bias", module.bias) + params.setdefault("bias", module.bias) + for attr in ("stride", "padding", "dilation", "output_padding"): + val = getattr(module, attr, None) + if val is not None: + params.setdefault(attr, val) + if hasattr(module, "groups"): + params.setdefault("groups", module.groups) + + elif isinstance(module, _NORM_TYPES): + if getattr(module, "weight", None) is not None: + params.setdefault("weight", module.weight) + params.setdefault("w", module.weight) + if getattr(module, "bias", None) is not None: + params.setdefault("bias", module.bias) + if hasattr(module, "eps"): + params["eps"] = module.eps + if hasattr(module, "num_groups"): + params["num_groups"] = module.num_groups + if hasattr(module, "normalized_shape"): + params["normalized_shape"] = module.normalized_shape + + elif isinstance(module, _POOL_TYPES): + for attr in ("kernel_size", "stride", "padding", "dilation"): + val = getattr(module, attr, None) + if val is not None: + params.setdefault(attr, val) + + if hasattr(model, "bias") and isinstance( + model.bias, (torch.Tensor, torch.nn.Parameter) + ): + params["add_bias"] = model.bias + params.setdefault("bias", model.bias) + + # Extract simple 
scalar attributes stored by Model.__init__ + # (catches dim, negative_slope, min_val, max_val, etc.) + _INIT_SCALAR_NAMES = { + "dim", + "negative_slope", + "min_val", + "max_val", + "beta", + "threshold", + "alpha", + "lambd", + "upper", + "lower", + "p", + } + for attr_name in _INIT_SCALAR_NAMES: + if hasattr(model, attr_name) and not isinstance( + getattr(model, attr_name), (torch.Tensor, torch.nn.Module) + ): + params.setdefault(attr_name, getattr(model, attr_name)) + + return params + + +def test_kernel(): + device = "cuda" + dtype = torch.bfloat16 + + # Setup reference model + model = Model(*get_init_inputs()).to(device).to(dtype) + inputs = [ + ( + x.to(device).to(dtype) + if isinstance(x, torch.Tensor) and x.is_floating_point() + else (x.to(device) if isinstance(x, torch.Tensor) else x) + ) + for x in get_inputs() + ] + + # Get reference output + with torch.no_grad(): + ref_output = model(*inputs) + + # Smart parameter binding: detect if kernel needs model params + sig = inspect.signature(kernel_function) + kernel_params = list(sig.parameters.keys()) + param_kinds = [p.kind for p in sig.parameters.values()] + has_var_positional = any(k == inspect.Parameter.VAR_POSITIONAL for k in param_kinds) + has_var_keyword = any(k == inspect.Parameter.VAR_KEYWORD for k in param_kinds) + _MODEL_PARAM_NAMES = { + "weight", + "w", + "kernel_size", + "stride", + "padding", + "dilation", + "output_padding", + "groups", + "bias", + "conv_bias", + "eps", + "num_groups", + "normalized_shape", + "dim", + "negative_slope", + "min_val", + "max_val", + "beta", + "threshold", + "alpha", + "lambd", + "upper", + "lower", + "p", + } + needs_model = bool(_MODEL_PARAM_NAMES & set(kernel_params)) + # If kernel uses *args/**kwargs, inspect its source for weight-related hints + if not needs_model and (has_var_positional or has_var_keyword): + try: + src = inspect.getsource(kernel_function) + needs_model = any( + kw in src + for kw in ( + "weight", + "is_weight", + "w.shape", + "w.ndim", + 
"kernel_size", + "dilation", + ) + ) + except (OSError, TypeError): + pass + + if needs_model: + model_params = _extract_model_params(model) + has_weight = "weight" in model_params or "w" in model_params + if has_var_positional and has_weight: + # *args kernel with weight: pass (input, weight1, weight2, ...) positionally + pos_args = list(inputs) + # Collect ALL conv/linear weights from model + for _, mod in model.named_modules(): + if isinstance(mod, (*_CONV_TYPES, torch.nn.Linear)): + if hasattr(mod, "weight") and mod.weight is not None: + pos_args.append(mod.weight) + # Pass config params as kwargs + config_kwargs = {} + for k, v in model_params.items(): + if k not in ("weight", "w", "bias", "conv_bias", "add_bias"): + # Convert uniform tuples to scalar int for compatibility + if ( + isinstance(v, (tuple, list)) + and len(v) >= 1 + and all(e == v[0] for e in v) + ): + v = v[0] + config_kwargs[k] = v + kernel_output = kernel_function(*pos_args, **config_kwargs) + else: + # Bind keyword args, adapting tuple/int form to match defaults + call_args = {} + pos_idx = 0 + for pname in kernel_params: + p = sig.parameters[pname] + if ( + p.kind == inspect.Parameter.VAR_POSITIONAL + or p.kind == inspect.Parameter.VAR_KEYWORD + ): + continue + if pname in model_params: + val = model_params[pname] + # Convert tuple/list to scalar when kernel expects int + if isinstance(val, (tuple, list)): + if p.default is not inspect.Parameter.empty and isinstance( + p.default, int + ): + val = val[0] + elif len(val) == 1: + val = val[0] + call_args[pname] = val + elif pos_idx < len(inputs): + call_args[pname] = inputs[pos_idx] + pos_idx += 1 + kernel_output = kernel_function(**call_args) + else: + kernel_output = kernel_function(*inputs) + + # Compare + # Handle in-place kernels that return None + if kernel_output is None: + # Assume in-place modification of first input + kernel_output = inputs[0] + # Handle shape mismatch: kernel may return per-sample loss vs reference scalar mean + if 
ref_output.dim() == 0 and kernel_output.dim() >= 1: + kernel_output = kernel_output.mean() + elif kernel_output.dim() == 0 and ref_output.dim() >= 1: + ref_output = ref_output.mean() + # Align dtypes for comparison + if ref_output.dtype != kernel_output.dtype: + # If kernel outputs higher precision, recompute reference at that precision + # using the SAME inputs to ensure fair comparison + if kernel_output.dtype == torch.float32 and ref_output.dtype in ( + torch.bfloat16, + torch.float16, + ): + model_f32 = Model(*get_init_inputs()).to(device).to(torch.float32) + inputs_f32 = [ + x.to(torch.float32) if x.is_floating_point() else x for x in inputs + ] + with torch.no_grad(): + ref_output = model_f32(*inputs_f32) + else: + kernel_output = kernel_output.to(ref_output.dtype) + if torch.allclose(ref_output, kernel_output, rtol=1e-2, atol=1e-2): + print("PASS") + return True + else: + max_diff = (ref_output - kernel_output).abs().max().item() + print(f"FAIL: max difference = {max_diff}") + return False + + +if __name__ == "__main__": + success = test_kernel() + sys.exit(0 if success else 1) diff --git a/examples/optimize_02_rmsnorm/input.py b/examples/optimize_02_rmsnorm/input.py new file mode 100644 index 00000000..e5e7744d --- /dev/null +++ b/examples/optimize_02_rmsnorm/input.py @@ -0,0 +1,259 @@ +import torch +import triton +import triton.language as tl + + +""" +RMS Normalization over the channel/feature dimension (dim=1) for NCHW tensors using Triton. + +Fusion and design notes: +- We implement the whole RMSNorm in a single Triton kernel: reduction of sum-of-squares across channels + and the normalization write-back are fused into one kernel launch. Within the kernel, we do two passes + over the input per tile: first to accumulate the sum of squares along the feature dimension, second to + apply the normalization scale and store. This avoids allocating any intermediate tensors while keeping + the Python wrapper free of compute. 
+- The kernel is tiled along the contiguous W dimension for coalesced loads/stores, and iterates over C + (features) to perform the reduction and normalization. Masking is used for boundary conditions. +- The wrapper supports both in-place and out-of-place operation. If no output tensor is provided, we + default to in-place to reduce peak memory consumption for the large test tensor. + +Runtime constraints: +- All math is inside the Triton kernel (tl.load/tl.store/tl.math.rsqrt, etc.). +- The Python wrapper only validates arguments, allocates output (if requested), and launches the kernel. +- No torch.nn / torch.nn.functional usage anywhere in the execution path. +""" + + +@triton.jit +def _rmsnorm_nchw_kernel( + x_ptr, + y_ptr, + N, + C, + H, + W, + stride_nx, + stride_cx, + stride_hx, + stride_wx, + stride_ny, + stride_cy, + stride_hy, + stride_wy, + eps, + BLOCK_W: tl.constexpr, +): + # 2D launch: + # - axis 0 tiles along W + # - axis 1 enumerates all (N*H) rows + pid_w = tl.program_id(axis=0) + pid_nh = tl.program_id(axis=1) + + # Which n and h row are we processing? 
+ n = pid_nh // H + h = pid_nh - n * H # equivalent to pid_nh % H + + # Offsets along W for this tile + start_w = pid_w * BLOCK_W + offs_w = start_w + tl.arange(0, BLOCK_W) + mask_w = offs_w < W + + # Cast strides and indices to int64 for address arithmetic safety + stride_nx = tl.full([], stride_nx, tl.int64) + stride_cx = tl.full([], stride_cx, tl.int64) + stride_hx = tl.full([], stride_hx, tl.int64) + stride_wx = tl.full([], stride_wx, tl.int64) + stride_ny = tl.full([], stride_ny, tl.int64) + stride_cy = tl.full([], stride_cy, tl.int64) + stride_hy = tl.full([], stride_hy, tl.int64) + stride_wy = tl.full([], stride_wy, tl.int64) + + n = n.to(tl.int64) + h = h.to(tl.int64) + offs_w_i64 = offs_w.to(tl.int64) + + # Base offsets for given (n, h) + base_nh_x = n * stride_nx + h * stride_hx + base_nh_y = n * stride_ny + h * stride_hy + + # Accumulator for sum of squares across channels (compute in float32) + acc = tl.zeros([BLOCK_W], dtype=tl.float32) + + # First pass: accumulate sum of squares along C + # Use a dynamic loop since C is provided at runtime. 
+ for c in tl.range(0, C): + c_i64 = c.to(tl.int64) + x_offsets = base_nh_x + c_i64 * stride_cx + offs_w_i64 * stride_wx + x_vals = tl.load(x_ptr + x_offsets, mask=mask_w, other=0.0) + x_f32 = x_vals.to(tl.float32) + acc += x_f32 * x_f32 + + # Compute inverse RMS: inv_rms = 1 / sqrt(mean(x^2) + eps) + # mean is acc / C + c_f32 = tl.full([1], C, dtype=tl.float32) + mean = acc / c_f32 + inv_rms = tl.math.rsqrt(mean + eps) + + # Second pass: normalize and store + for c in tl.range(0, C): + c_i64 = c.to(tl.int64) + x_offsets = base_nh_x + c_i64 * stride_cx + offs_w_i64 * stride_wx + y_offsets = base_nh_y + c_i64 * stride_cy + offs_w_i64 * stride_wy + x_vals = tl.load(x_ptr + x_offsets, mask=mask_w, other=0.0) + x_f32 = x_vals.to(tl.float32) + y_f32 = x_f32 * inv_rms + y_vals = y_f32.to(x_vals.dtype) + tl.store(y_ptr + y_offsets, y_vals, mask=mask_w) + + +def _parse_kernel_args(x, args, kwargs): + """ + Parse flexible arguments from the test harness. + Returns: + eps (float), num_features (int or None), out_tensor (Tensor or None) + """ + # Defaults + eps = kwargs.pop("eps", None) + num_features = kwargs.pop("num_features", None) + # Some tests may pass `features=...` + if "features" in kwargs and num_features is None: + num_features = kwargs.pop("features") + # Accept multiple possible output keywords + out = kwargs.pop("out", None) + if out is None: + out = kwargs.pop("output", None) + if out is None: + out = kwargs.pop("y", None) + if out is None: + out = kwargs.pop("dst", None) + + # Handle positional args: could be (eps), (features), or (eps, features) + if len(args) == 1: + a0 = args[0] + if isinstance(a0, (int,)) and num_features is None: + num_features = int(a0) + else: + # assume eps + if eps is None: + eps = float(a0) + elif len(args) == 2: + a0, a1 = args + # try to identify by types + if isinstance(a0, (float,)) or not isinstance(a0, (int,)): + # assume eps first, features second + if eps is None: + eps = float(a0) + if num_features is None and 
isinstance(a1, (int,)): + num_features = int(a1) + else: + # assume features first, eps second + if num_features is None: + num_features = int(a0) + if eps is None: + eps = float(a1) + + # Finalize defaults + if eps is None: + eps = 1e-5 + # num_features can be None; we will infer from x.shape[1] + return eps, num_features, out + + +def kernel_function(x, *args, **kwargs): + """ + RMS Normalization over feature/channel dim (dim=1) for NCHW tensors on CUDA. + + Behavior: + - Normalizes each (n, h, w) vector across channels c in [0, C), computing: + rms = sqrt(mean(x[n, :, h, w]^2) + eps) + y[n, c, h, w] = x[n, c, h, w] / rms + - Uses a single fused Triton kernel launch with a two-pass streaming strategy: + 1) Reduce sum of squares across C + 2) Apply scale and write normalized values + This avoids Python-side compute and keeps memory usage low (no large intermediates). + - If an output tensor is provided via out/output/y/dst, writes there. Otherwise, performs in-place + normalization on x to minimize peak memory. + + Accepted call patterns (examples): + - kernel_function(x) + - kernel_function(x, eps) + - kernel_function(x, features) + - kernel_function(x, eps, features) + - kernel_function(x, num_features=..., eps=...) + - kernel_function(x, out=prealloc), kernel_function(x, output=...), y=..., dst=... + + Args: + x: CUDA tensor with shape (N, C, H, W). Dtype: float16 or bfloat16 recommended. + eps: small epsilon for numerical stability (default 1e-5) + num_features: expected C; if provided, validated against x.shape[1] + out/output/y/dst: optional output tensor. If omitted, operation runs in-place on x. + + Returns: + The normalized tensor (same shape/type/device as x). If run in-place and returning None is + acceptable to the caller, you may still return x for convenience. 
+ """ + if not isinstance(x, torch.Tensor): + raise TypeError("x must be a torch.Tensor") + if x.device.type != "cuda": + raise ValueError("x must be on CUDA device") + if x.ndim != 4: + raise ValueError(f"Expected 4D NCHW tensor, got shape {tuple(x.shape)}") + if x.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError( + f"Unsupported dtype: {x.dtype}. Use float16, bfloat16, or float32." + ) + + eps, num_features, out = _parse_kernel_args(x, args, kwargs) + + N, C, H, W = x.shape + if num_features is not None and int(num_features) != C: + raise ValueError( + f"num_features ({num_features}) does not match input channels ({C})." + ) + + # Setup output tensor. If not provided, do in-place to save memory (huge tensors in the test). + if out is None: + # In-place: write results directly to x + y = x + else: + if not isinstance(out, torch.Tensor): + raise TypeError("Provided output must be a torch.Tensor") + if out.shape != x.shape or out.device != x.device or out.dtype != x.dtype: + raise ValueError( + "Output tensor must match input in shape, device, and dtype." + ) + y = out + + # Strides in elements + sx0, sx1, sx2, sx3 = x.stride() + sy0, sy1, sy2, sy3 = y.stride() + + # Kernel launch configuration + # Tile along W for coalesced access + BLOCK_W = 256 + grid = (triton.cdiv(W, BLOCK_W), N * H) + + # Launch kernel + _rmsnorm_nchw_kernel[grid]( + x, + y, + N, + C, + H, + W, + sx0, + sx1, + sx2, + sx3, + sy0, + sy1, + sy2, + sy3, + float(eps), + BLOCK_W=BLOCK_W, + num_warps=4, + num_stages=2, + ) + + # Return result tensor. If in-place, return x to satisfy callers expecting a Tensor. + return y diff --git a/examples/optimize_02_rmsnorm/problem.py b/examples/optimize_02_rmsnorm/problem.py new file mode 100644 index 00000000..708c7001 --- /dev/null +++ b/examples/optimize_02_rmsnorm/problem.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn + + +class Model(nn.Module): + """ + Simple model that performs RMS Normalization. 
+ """ + + def __init__(self, num_features: int, eps: float = 1e-5): + """ + Initializes the RMSNorm layer. + + Args: + num_features (int): Number of features in the input tensor. + eps (float, optional): A small value added to the denominator to avoid division by zero. Defaults to 1e-5. + """ + super(Model, self).__init__() + self.num_features = num_features + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Applies RMS Normalization to the input tensor. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, num_features, *). + + Returns: + torch.Tensor: Output tensor with RMS Normalization applied, same shape as input. + """ + # Calculate the RMS along the feature dimension + rms = torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + self.eps) + + # Normalize the input by dividing by the RMS + return x / rms + + +batch_size = 112 +features = 64 +dim1 = 512 +dim2 = 512 + + +def get_inputs(): + x = torch.rand(batch_size, features, dim1, dim2) + return [x] + + +def get_init_inputs(): + return [features] diff --git a/examples/optimize_02_rmsnorm/test.py b/examples/optimize_02_rmsnorm/test.py new file mode 100644 index 00000000..b489c477 --- /dev/null +++ b/examples/optimize_02_rmsnorm/test.py @@ -0,0 +1,270 @@ +"""Correctness test for RMSNorm kernel.""" + +import inspect +import sys +import torch + +from problem import Model, get_inputs, get_init_inputs +from kernel import kernel_function + +_CONV_TYPES = ( + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, +) +_NORM_TYPES = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.LayerNorm, + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, +) +_POOL_TYPES = ( + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + 
torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AdaptiveMaxPool1d, + torch.nn.AdaptiveMaxPool2d, + torch.nn.AdaptiveMaxPool3d, +) + + +def _extract_model_params(model): + """Extract learnable parameters and layer config from a PyTorch model.""" + params = {} + + for _, module in model.named_modules(): + if isinstance(module, (*_CONV_TYPES, torch.nn.Linear)): + if hasattr(module, "weight") and module.weight is not None: + params.setdefault("weight", module.weight) + params.setdefault("w", module.weight) + if getattr(module, "bias", None) is not None: + params.setdefault("conv_bias", module.bias) + params.setdefault("bias", module.bias) + for attr in ("stride", "padding", "dilation", "output_padding"): + val = getattr(module, attr, None) + if val is not None: + params.setdefault(attr, val) + if hasattr(module, "groups"): + params.setdefault("groups", module.groups) + + elif isinstance(module, _NORM_TYPES): + if getattr(module, "weight", None) is not None: + params.setdefault("weight", module.weight) + params.setdefault("w", module.weight) + if getattr(module, "bias", None) is not None: + params.setdefault("bias", module.bias) + if hasattr(module, "eps"): + params["eps"] = module.eps + if hasattr(module, "num_groups"): + params["num_groups"] = module.num_groups + if hasattr(module, "normalized_shape"): + params["normalized_shape"] = module.normalized_shape + + elif isinstance(module, _POOL_TYPES): + for attr in ("kernel_size", "stride", "padding", "dilation"): + val = getattr(module, attr, None) + if val is not None: + params.setdefault(attr, val) + + if hasattr(model, "bias") and isinstance( + model.bias, (torch.Tensor, torch.nn.Parameter) + ): + params["add_bias"] = model.bias + params.setdefault("bias", model.bias) + + # Extract simple scalar attributes stored by Model.__init__ + # (catches dim, negative_slope, min_val, max_val, etc.) 
+ _INIT_SCALAR_NAMES = { + "dim", + "negative_slope", + "min_val", + "max_val", + "beta", + "threshold", + "alpha", + "lambd", + "upper", + "lower", + "p", + } + for attr_name in _INIT_SCALAR_NAMES: + if hasattr(model, attr_name) and not isinstance( + getattr(model, attr_name), (torch.Tensor, torch.nn.Module) + ): + params.setdefault(attr_name, getattr(model, attr_name)) + + return params + + +def test_kernel(): + device = "cuda" + dtype = torch.bfloat16 + + # Setup reference model + model = Model(*get_init_inputs()).to(device).to(dtype) + inputs = [ + x.to(device).to(dtype) + if isinstance(x, torch.Tensor) and x.is_floating_point() + else (x.to(device) if isinstance(x, torch.Tensor) else x) + for x in get_inputs() + ] + + # Get reference output + with torch.no_grad(): + ref_output = model(*inputs) + + # Smart parameter binding: detect if kernel needs model params + sig = inspect.signature(kernel_function) + kernel_params = list(sig.parameters.keys()) + param_kinds = [p.kind for p in sig.parameters.values()] + has_var_positional = any(k == inspect.Parameter.VAR_POSITIONAL for k in param_kinds) + has_var_keyword = any(k == inspect.Parameter.VAR_KEYWORD for k in param_kinds) + _MODEL_PARAM_NAMES = { + "weight", + "w", + "kernel_size", + "stride", + "padding", + "dilation", + "output_padding", + "groups", + "bias", + "conv_bias", + "eps", + "num_groups", + "normalized_shape", + "dim", + "negative_slope", + "min_val", + "max_val", + "beta", + "threshold", + "alpha", + "lambd", + "upper", + "lower", + "p", + } + needs_model = bool(_MODEL_PARAM_NAMES & set(kernel_params)) + # If kernel uses *args/**kwargs, inspect its source for weight-related hints + if not needs_model and (has_var_positional or has_var_keyword): + try: + src = inspect.getsource(kernel_function) + needs_model = any( + kw in src + for kw in ( + "weight", + "is_weight", + "w.shape", + "w.ndim", + "kernel_size", + "dilation", + ) + ) + except (OSError, TypeError): + pass + + if needs_model: + model_params 
= _extract_model_params(model) + has_weight = "weight" in model_params or "w" in model_params + if has_var_positional and has_weight: + # *args kernel with weight: pass (input, weight1, weight2, ...) positionally + pos_args = list(inputs) + # Collect ALL conv/linear weights from model + for _, mod in model.named_modules(): + if isinstance(mod, (*_CONV_TYPES, torch.nn.Linear)): + if hasattr(mod, "weight") and mod.weight is not None: + pos_args.append(mod.weight) + # Pass config params as kwargs + config_kwargs = {} + for k, v in model_params.items(): + if k not in ("weight", "w", "bias", "conv_bias", "add_bias"): + # Convert uniform tuples to scalar int for compatibility + if ( + isinstance(v, (tuple, list)) + and len(v) >= 1 + and all(e == v[0] for e in v) + ): + v = v[0] + config_kwargs[k] = v + kernel_output = kernel_function(*pos_args, **config_kwargs) + else: + # Bind keyword args, adapting tuple/int form to match defaults + call_args = {} + pos_idx = 0 + for pname in kernel_params: + p = sig.parameters[pname] + if ( + p.kind == inspect.Parameter.VAR_POSITIONAL + or p.kind == inspect.Parameter.VAR_KEYWORD + ): + continue + if pname in model_params: + val = model_params[pname] + # Convert tuple/list to scalar when kernel expects int + if isinstance(val, (tuple, list)): + if p.default is not inspect.Parameter.empty and isinstance( + p.default, int + ): + val = val[0] + elif len(val) == 1: + val = val[0] + call_args[pname] = val + elif pos_idx < len(inputs): + call_args[pname] = inputs[pos_idx] + pos_idx += 1 + kernel_output = kernel_function(**call_args) + else: + kernel_output = kernel_function(*inputs) + + # Compare + # Handle in-place kernels that return None + if kernel_output is None: + # Assume in-place modification of first input + kernel_output = inputs[0] + # Handle shape mismatch: kernel may return per-sample loss vs reference scalar mean + if ref_output.dim() == 0 and kernel_output.dim() >= 1: + kernel_output = kernel_output.mean() + elif 
kernel_output.dim() == 0 and ref_output.dim() >= 1: + ref_output = ref_output.mean() + # Align dtypes for comparison + if ref_output.dtype != kernel_output.dtype: + # If kernel outputs higher precision, recompute reference at that precision + # using the SAME inputs to ensure fair comparison + if kernel_output.dtype == torch.float32 and ref_output.dtype in ( + torch.bfloat16, + torch.float16, + ): + model_f32 = Model(*get_init_inputs()).to(device).to(torch.float32) + inputs_f32 = [ + x.to(torch.float32) if x.is_floating_point() else x for x in inputs + ] + with torch.no_grad(): + ref_output = model_f32(*inputs_f32) + else: + kernel_output = kernel_output.to(ref_output.dtype) + if torch.allclose(ref_output, kernel_output, rtol=1e-2, atol=1e-2): + print("PASS") + return True + else: + max_diff = (ref_output - kernel_output).abs().max().item() + print(f"FAIL: max difference = {max_diff}") + return False + + +if __name__ == "__main__": + success = test_kernel() + sys.exit(0 if success else 1) diff --git a/examples/optimize_03_max_pooling/input.py b/examples/optimize_03_max_pooling/input.py new file mode 100644 index 00000000..8281ed6c --- /dev/null +++ b/examples/optimize_03_max_pooling/input.py @@ -0,0 +1,224 @@ +# kernel.py +# Triton-based MaxPool3d implementation specialized to the test configuration: +# - Input tensor shape: (N=16, C=32, D=128, H=128, W=128) +# - Pooling params typically called by the test: kernel_size=3, stride=2, padding=1, dilation=3 +# +# Notes on fusion: +# - MaxPool3d is a standalone reduction operator. There is no natural upstream/downstream op specified +# in the test to fuse with (e.g., bias, activation), so this kernel focuses on an efficient single-pass +# pooling implementation. If a pipeline included additional pointwise ops on the pooled output, those +# could be fused into the epilogue to reduce memory traffic. 
+ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _maxpool3d_kernel( + x_ptr, # *const T + y_ptr, # *T + N, + C, + D, + H, + W, # input sizes + OD, + OH, + OW, # output sizes + strideN, + strideC, + strideD, + strideH, + strideW, # input strides (in elements) + ostrideN, + ostrideC, + ostrideD, + ostrideH, + ostrideW, # output strides (in elements) + KERNEL_SIZE: tl.constexpr, # pool kernel size (assumed cubic here) + STRIDE: tl.constexpr, # pool stride (assumed same across dims) + PADDING: tl.constexpr, # pool padding (assumed same across dims) + DILATION: tl.constexpr, # dilation (assumed same across dims) + BLOCK_W: tl.constexpr, # vectorized span of OW per program +): + # Program ids: + # axis 0: blocks along OW + # axis 1: specific OD index + # axis 2: flattened (N*C*OH) + pid_w = tl.program_id(axis=0) + pid_d = tl.program_id(axis=1) + pid_z = tl.program_id(axis=2) + + # Decode pid_z into (n, c, oh) + oh = pid_z % OH + nc = pid_z // OH + c = nc % C + n = nc // C + + # Compute OW offsets this program handles + ow_start = pid_w * BLOCK_W + ow_offsets = ow_start + tl.arange(0, BLOCK_W) + ow_mask = ow_offsets < OW + + # Output indices along D and H (scalars per program) + od = pid_d + + # Base starts in input space for the pooling window + d_base = od * STRIDE - PADDING + h_base = oh * STRIDE - PADDING + # Vector of base W inputs for this block + w_base = ow_offsets * STRIDE - PADDING + + # Accumulator in fp32 for numerical robustness (stores max across the 3x3x3 window) + acc = tl.full([BLOCK_W], -float("inf"), dtype=tl.float32) + + # Precompute base strides for (n, c) + base_nc = n * strideN + c * strideC + + # Iterate over the pooling window (kd, kh, kw) with compile-time unrolling + for kd in tl.static_range(0, KERNEL_SIZE): + d_idx = d_base + kd * DILATION + valid_d = (d_idx >= 0) & (d_idx < D) + # Safe index to keep addresses in-bounds for masked loads + d_idx_safe = tl.where(valid_d, d_idx, 0) + + for kh in 
tl.static_range(0, KERNEL_SIZE): + h_idx = h_base + kh * DILATION + valid_h = (h_idx >= 0) & (h_idx < H) + valid_dh = valid_d & valid_h + h_idx_safe = tl.where(valid_h, h_idx, 0) + + # Base pointer for current (n, c, d_idx, h_idx) + base_dh = base_nc + d_idx_safe * strideD + h_idx_safe * strideH + + for kw in tl.static_range(0, KERNEL_SIZE): + w_idx = w_base + kw * DILATION + # Check bounds per-lane; combine with ow_mask and valid_dh + w_valid = ow_mask & (w_idx >= 0) & (w_idx < W) & valid_dh + w_idx_safe = tl.where(w_valid, w_idx, 0) + + # Element pointers for this (kd, kh, kw) and OW lanes + ptrs = x_ptr + base_dh + w_idx_safe * strideW + + # Load with mask; out-of-bounds lanes use -inf so they don't affect max + vals = tl.load(ptrs, mask=w_valid, other=-float("inf")) + vals_f32 = vals.to(tl.float32) + acc = tl.maximum(acc, vals_f32) + + # Store result + out_base = y_ptr + n * ostrideN + c * ostrideC + od * ostrideD + oh * ostrideH + out_ptrs = out_base + ow_offsets * ostrideW + tl.store(out_ptrs, acc, mask=ow_mask) + + +def _compute_out_dim( + L_in: int, kernel: int, stride: int, padding: int, dilation: int +) -> int: + # PyTorch formula: floor((L_in + 2*padding - dilation*(kernel - 1) - 1) / stride + 1) + return (L_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1 + + +def kernel_function( + x: torch.Tensor, kernel_size: int, stride: int, padding: int, dilation: int +): + """ + Triton-backed 3D Max Pooling (no indices), compatible with the test's call signature. + + Args: + x: Input tensor of shape (N, C, D, H, W), CUDA device. + kernel_size: int, pooling kernel size (assumed cubic) + stride: int, pooling stride (assumed same for D/H/W) + padding: int, zero-padding applied on each side (assumed same for D/H/W) + dilation: int, dilation factor (assumed same for D/H/W) + + Returns: + y: Output tensor of shape (N, C, OD, OH, OW) with the same dtype/device as x. 
+ + Design and fusion notes: + - This is a single-pass, fused pooling reduction: it computes the maximum over the 3D dilated window + directly from global memory and writes the result, with masking to handle padding/boundaries. + - No additional post-processing stages are specified in the test; thus, there are no further ops to fuse. + If a follow-up pointwise op were known, it could be integrated into the epilogue to reduce memory traffic. + + Runtime policy: + - The wrapper only validates arguments, computes output shape, allocates the output tensor, and launches + the Triton kernel. All math (window traversal and reduction) happens inside the Triton kernel. + """ + # Basic checks + if not x.is_cuda: + raise ValueError("Input must be a CUDA tensor.") + if x.ndim != 5: + raise ValueError( + f"Expected 5D input (N, C, D, H, W), got shape {tuple(x.shape)}" + ) + if ( + not isinstance(kernel_size, int) + or not isinstance(stride, int) + or not isinstance(padding, int) + or not isinstance(dilation, int) + ): + raise TypeError("kernel_size, stride, padding, dilation must be ints.") + + N, C, D, H, W = x.shape + K = kernel_size + S = stride + P = padding + Di = dilation + + # Compute output shape + OD = _compute_out_dim(D, K, S, P, Di) + OH = _compute_out_dim(H, K, S, P, Di) + OW = _compute_out_dim(W, K, S, P, Di) + if OD <= 0 or OH <= 0 or OW <= 0: + raise ValueError( + "Computed non-positive output dimension(s). Check pooling parameters." + ) + + # Allocate output + y = torch.empty((N, C, OD, OH, OW), device=x.device, dtype=x.dtype) + + # Get strides in "element" units (PyTorch strides are already in elements, not bytes) + strideN, strideC, strideD, strideH, strideW = x.stride() + ostrideN, ostrideC, ostrideD, ostrideH, ostrideW = y.stride() + + # Configure launch + # We tile along OW dimension with BLOCK_W elements per program. + # OW is 62 in the test, so BLOCK_W=64 covers each row in one program; remaining lanes are masked. 
+ BLOCK_W = 64 + + def grid(meta): + return (triton.cdiv(OW, meta["BLOCK_W"]), OD, N * C * OH) + + # Launch kernel + _maxpool3d_kernel[grid]( + x, + y, + N, + C, + D, + H, + W, + OD, + OH, + OW, + strideN, + strideC, + strideD, + strideH, + strideW, + ostrideN, + ostrideC, + ostrideD, + ostrideH, + ostrideW, + KERNEL_SIZE=K, + STRIDE=S, + PADDING=P, + DILATION=Di, + BLOCK_W=BLOCK_W, + num_warps=4, # Reasonable default for this memory-bound reduction + num_stages=2, + ) + + return y diff --git a/examples/optimize_03_max_pooling/problem.py b/examples/optimize_03_max_pooling/problem.py new file mode 100644 index 00000000..3859ba83 --- /dev/null +++ b/examples/optimize_03_max_pooling/problem.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn + + +class Model(nn.Module): + """ + Simple model that performs Max Pooling 3D. + """ + + def __init__( + self, + kernel_size: int, + stride: int = None, + padding: int = 0, + dilation: int = 1, + return_indices: bool = False, + ceil_mode: bool = False, + ): + """ + Initializes the Max Pooling 3D layer. + + Args: + kernel_size (int): Size of the kernel for the max pooling operation. + stride (int, optional): Stride of the pooling operation. Defaults to None, which means stride is equal to kernel_size. + padding (int, optional): Padding applied to the input tensor. Defaults to 0. + dilation (int, optional): Spacing between kernel elements. Defaults to 1. + return_indices (bool, optional): Whether to return indices of the maximum values. Defaults to False. + ceil_mode (bool, optional): When True, the output size is ceil(input_size / stride) instead of floor. Defaults to False. + """ + super(Model, self).__init__() + self.maxpool = nn.MaxPool3d( + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Applies Max Pooling 3D to the input tensor. 
+ + Args: + x (torch.Tensor): Input tensor of shape (batch_size, channels, dim1, dim2, dim3). + + Returns: + torch.Tensor: Output tensor with Max Pooling 3D applied. + """ + return self.maxpool(x) + + +batch_size = 16 +channels = 32 +dim1 = 128 +dim2 = 128 +dim3 = 128 +kernel_size = 3 +stride = 2 +padding = 1 +dilation = 3 + + +def get_inputs(): + x = torch.rand(batch_size, channels, dim1, dim2, dim3) + return [x] + + +def get_init_inputs(): + return [kernel_size, stride, padding, dilation] diff --git a/examples/optimize_03_max_pooling/test.py b/examples/optimize_03_max_pooling/test.py new file mode 100644 index 00000000..575c5a9e --- /dev/null +++ b/examples/optimize_03_max_pooling/test.py @@ -0,0 +1,272 @@ +"""Correctness test for 3D max pooling kernel.""" + +import inspect +import sys + +import torch +from kernel import kernel_function +from problem import get_init_inputs, get_inputs, Model + +_CONV_TYPES = ( + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, +) +_NORM_TYPES = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.LayerNorm, + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, +) +_POOL_TYPES = ( + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AdaptiveMaxPool1d, + torch.nn.AdaptiveMaxPool2d, + torch.nn.AdaptiveMaxPool3d, +) + + +def _extract_model_params(model): + """Extract learnable parameters and layer config from a PyTorch model.""" + params = {} + + for _, module in model.named_modules(): + if isinstance(module, (*_CONV_TYPES, torch.nn.Linear)): + if hasattr(module, "weight") and module.weight is not None: + params.setdefault("weight", module.weight) + params.setdefault("w", module.weight) + 
if getattr(module, "bias", None) is not None: + params.setdefault("conv_bias", module.bias) + params.setdefault("bias", module.bias) + for attr in ("stride", "padding", "dilation", "output_padding"): + val = getattr(module, attr, None) + if val is not None: + params.setdefault(attr, val) + if hasattr(module, "groups"): + params.setdefault("groups", module.groups) + + elif isinstance(module, _NORM_TYPES): + if getattr(module, "weight", None) is not None: + params.setdefault("weight", module.weight) + params.setdefault("w", module.weight) + if getattr(module, "bias", None) is not None: + params.setdefault("bias", module.bias) + if hasattr(module, "eps"): + params["eps"] = module.eps + if hasattr(module, "num_groups"): + params["num_groups"] = module.num_groups + if hasattr(module, "normalized_shape"): + params["normalized_shape"] = module.normalized_shape + + elif isinstance(module, _POOL_TYPES): + for attr in ("kernel_size", "stride", "padding", "dilation"): + val = getattr(module, attr, None) + if val is not None: + params.setdefault(attr, val) + + if hasattr(model, "bias") and isinstance( + model.bias, (torch.Tensor, torch.nn.Parameter) + ): + params["add_bias"] = model.bias + params.setdefault("bias", model.bias) + + # Extract simple scalar attributes stored by Model.__init__ + # (catches dim, negative_slope, min_val, max_val, etc.) 
+ _INIT_SCALAR_NAMES = { + "dim", + "negative_slope", + "min_val", + "max_val", + "beta", + "threshold", + "alpha", + "lambd", + "upper", + "lower", + "p", + } + for attr_name in _INIT_SCALAR_NAMES: + if hasattr(model, attr_name) and not isinstance( + getattr(model, attr_name), (torch.Tensor, torch.nn.Module) + ): + params.setdefault(attr_name, getattr(model, attr_name)) + + return params + + +def test_kernel(): + device = "cuda" + dtype = torch.bfloat16 + + # Setup reference model + model = Model(*get_init_inputs()).to(device).to(dtype) + inputs = [ + ( + x.to(device).to(dtype) + if isinstance(x, torch.Tensor) and x.is_floating_point() + else (x.to(device) if isinstance(x, torch.Tensor) else x) + ) + for x in get_inputs() + ] + + # Get reference output + with torch.no_grad(): + ref_output = model(*inputs) + + # Smart parameter binding: detect if kernel needs model params + sig = inspect.signature(kernel_function) + kernel_params = list(sig.parameters.keys()) + param_kinds = [p.kind for p in sig.parameters.values()] + has_var_positional = any(k == inspect.Parameter.VAR_POSITIONAL for k in param_kinds) + has_var_keyword = any(k == inspect.Parameter.VAR_KEYWORD for k in param_kinds) + _MODEL_PARAM_NAMES = { + "weight", + "w", + "kernel_size", + "stride", + "padding", + "dilation", + "output_padding", + "groups", + "bias", + "conv_bias", + "eps", + "num_groups", + "normalized_shape", + "dim", + "negative_slope", + "min_val", + "max_val", + "beta", + "threshold", + "alpha", + "lambd", + "upper", + "lower", + "p", + } + needs_model = bool(_MODEL_PARAM_NAMES & set(kernel_params)) + # If kernel uses *args/**kwargs, inspect its source for weight-related hints + if not needs_model and (has_var_positional or has_var_keyword): + try: + src = inspect.getsource(kernel_function) + needs_model = any( + kw in src + for kw in ( + "weight", + "is_weight", + "w.shape", + "w.ndim", + "kernel_size", + "dilation", + ) + ) + except (OSError, TypeError): + pass + + if needs_model: + 
model_params = _extract_model_params(model) + has_weight = "weight" in model_params or "w" in model_params + if has_var_positional and has_weight: + # *args kernel with weight: pass (input, weight1, weight2, ...) positionally + pos_args = list(inputs) + # Collect ALL conv/linear weights from model + for _, mod in model.named_modules(): + if isinstance(mod, (*_CONV_TYPES, torch.nn.Linear)): + if hasattr(mod, "weight") and mod.weight is not None: + pos_args.append(mod.weight) + # Pass config params as kwargs + config_kwargs = {} + for k, v in model_params.items(): + if k not in ("weight", "w", "bias", "conv_bias", "add_bias"): + # Convert uniform tuples to scalar int for compatibility + if ( + isinstance(v, (tuple, list)) + and len(v) >= 1 + and all(e == v[0] for e in v) + ): + v = v[0] + config_kwargs[k] = v + kernel_output = kernel_function(*pos_args, **config_kwargs) + else: + # Bind keyword args, adapting tuple/int form to match defaults + call_args = {} + pos_idx = 0 + for pname in kernel_params: + p = sig.parameters[pname] + if ( + p.kind == inspect.Parameter.VAR_POSITIONAL + or p.kind == inspect.Parameter.VAR_KEYWORD + ): + continue + if pname in model_params: + val = model_params[pname] + # Convert tuple/list to scalar when kernel expects int + if isinstance(val, (tuple, list)): + if p.default is not inspect.Parameter.empty and isinstance( + p.default, int + ): + val = val[0] + elif len(val) == 1: + val = val[0] + call_args[pname] = val + elif pos_idx < len(inputs): + call_args[pname] = inputs[pos_idx] + pos_idx += 1 + kernel_output = kernel_function(**call_args) + else: + kernel_output = kernel_function(*inputs) + + # Compare + # Handle in-place kernels that return None + if kernel_output is None: + # Assume in-place modification of first input + kernel_output = inputs[0] + # Handle shape mismatch: kernel may return per-sample loss vs reference scalar mean + if ref_output.dim() == 0 and kernel_output.dim() >= 1: + kernel_output = kernel_output.mean() + elif 
kernel_output.dim() == 0 and ref_output.dim() >= 1: + ref_output = ref_output.mean() + # Align dtypes for comparison + if ref_output.dtype != kernel_output.dtype: + # If kernel outputs higher precision, recompute reference at that precision + # using the SAME inputs to ensure fair comparison + if kernel_output.dtype == torch.float32 and ref_output.dtype in ( + torch.bfloat16, + torch.float16, + ): + model_f32 = Model(*get_init_inputs()).to(device).to(torch.float32) + inputs_f32 = [ + x.to(torch.float32) if x.is_floating_point() else x for x in inputs + ] + with torch.no_grad(): + ref_output = model_f32(*inputs_f32) + else: + kernel_output = kernel_output.to(ref_output.dtype) + if torch.allclose(ref_output, kernel_output, rtol=1e-2, atol=1e-2): + print("PASS") + return True + else: + max_diff = (ref_output - kernel_output).abs().max().item() + print(f"FAIL: max difference = {max_diff}") + return False + + +if __name__ == "__main__": + success = test_kernel() + sys.exit(0 if success else 1) diff --git a/pyproject.toml b/pyproject.toml index b34a3097..833b6a45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ fuser-ui = "scripts.fuser_ui:main" kernel-agent = "scripts.triton_ui:main" list-models = "scripts.list_models:main" +optimization-ui = "scripts.optimization_ui:main" pipeline-ui = "scripts.pipeline_ui:main" [project.urls] diff --git a/scripts/optimization_ui.py b/scripts/optimization_ui.py new file mode 100644 index 00000000..8e956d9a --- /dev/null +++ b/scripts/optimization_ui.py @@ -0,0 +1,1013 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gradio UI for the hardware-guided kernel optimization pipeline.""" + +from __future__ import annotations + +import argparse +import logging +import os +import re +import sys +import threading +import time +import traceback +from pathlib import Path + +import gradio as gr +from dotenv import load_dotenv + +# Ensure project root is importable when run as a script. +_PROJECT_ROOT = Path(__file__).resolve().parent.parent +if str(_PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(_PROJECT_ROOT)) + + +def _list_kernelbench_problems(base: Path) -> list[tuple[str, str]]: + """Return list of (label, absolute_path) pairs for KernelBench problems.""" + problems: list[tuple[str, str]] = [] + if not base.exists(): + return problems + for level_dir in sorted(base.glob("level*")): + if not level_dir.is_dir(): + continue + if level_dir.name.lower() == "level4": + continue + for problem in sorted(level_dir.glob("*.py")): + label = f"{level_dir.name}/{problem.name}" + problems.append((label, str(problem.resolve()))) + return problems + + +def _discover_problems() -> list[tuple[str, str]]: + """Find KernelBench problems from common locations.""" + candidate_roots = [ + Path.cwd() / "external" / "KernelBench" / "KernelBench", + Path.cwd() / "KernelBench" / "KernelBench", + Path.cwd().parent / "KernelBench" / "KernelBench", + Path.cwd().parent.parent / "KernelBench" / "KernelBench", + ] + seen: set[str] = set() + problems: list[tuple[str, str]] = [] + for root in candidate_roots: + for label, path in _list_kernelbench_problems(root): + if path not in seen: + 
seen.add(path) + problems.append((label, path)) + return problems + + +_EXAMPLES_DIR = Path(__file__).resolve().parent.parent / "examples" + +_CUSTOM_OPTION = "-- Custom (paste below) --" + + +def _discover_examples() -> list[tuple[str, str]]: + """Find optimization examples from the examples/ directory. + + Returns list of (label, directory_path) for dirs matching ``optimize_*`` + that contain ``input.py`` and ``test.py``. + """ + examples: list[tuple[str, str]] = [] + if not _EXAMPLES_DIR.is_dir(): + return examples + for d in sorted(_EXAMPLES_DIR.glob("optimize_*")): + if not d.is_dir(): + continue + if (d / "input.py").exists() and (d / "test.py").exists(): + # Turn "optimize_01_matvec" into "MatVec" + label = d.name.split("_", 2)[-1].replace("_", " ").title() + examples.append((label, str(d))) + return examples + + +def _build_input_choices() -> list[str]: + """Build the dropdown choices: examples + custom.""" + choices: list[str] = [] + for label, _ in _discover_examples(): + choices.append(f"Example: {label}") + choices.append(_CUSTOM_OPTION) + return choices + + +def _get_gpu_choices() -> list[str]: + """Return GPU names from the specs database.""" + from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs_database import ( + GPU_SPECS_DATABASE, + ) + + return sorted(GPU_SPECS_DATABASE.keys()) + + +def _env_var_for_model(model_name: str) -> str: + """Determine which API key env var a model needs.""" + if "claude" in model_name.lower() or "anthropic" in model_name.lower(): + return "ANTHROPIC_API_KEY" + return "OPENAI_API_KEY" + + +def _load_sibling_file(problem_path: str, filename: str) -> str: + """Load a sibling file (input.py, test.py) next to a problem file.""" + if not problem_path: + return "" + parent = Path(problem_path).parent + candidate = parent / filename + if candidate.exists(): + try: + return candidate.read_text(encoding="utf-8") + except OSError: + pass + return "" + + +def run_optimization( + problem_label: str, + kernel_code: str, + 
test_code: str, + model_name: str, + gpu_name: str, + max_rounds: int, + high_reasoning: bool, + platform: str, + api_key: str | None, + strategy: str = "greedy", + num_workers: int = 1, + strategy_config: dict | None = None, + problem_file_override: str | None = None, + log_capture: _LogCapture | None = None, +) -> tuple[str, str, str, str | None, str]: + """Run the optimization pipeline and return (status_md, best_kernel, log, download_path, per_round_html).""" + from triton_kernel_agent.opt_manager import OptimizationManager + + if not kernel_code or not kernel_code.strip(): + return "**Error:** No kernel code provided.", "", "", None, "" + if not test_code or not test_code.strip(): + return "**Error:** No test code provided.", "", "", None, "" + + # Resolve API key + env_var = _env_var_for_model(model_name) + user_key = api_key.strip() if api_key else None + original_env_key = os.environ.get(env_var) + temp_key_set = False + if user_key: + os.environ[env_var] = user_key + temp_key_set = True + + try: + # Set up run directory + ts = int(time.time()) + run_dir = Path.cwd() / ".optimize" / f"optimization_{ts}" + output_dir = run_dir / "output" + output_dir.mkdir(parents=True, exist_ok=True) + + # Resolve problem file: explicit override > KB label lookup > stub + if problem_file_override and Path(problem_file_override).exists(): + problem_file = Path(problem_file_override) + else: + problem_mapping = {label: path for label, path in _discover_problems()} + source_problem = problem_mapping.get(problem_label, "") + if source_problem and Path(source_problem).exists(): + problem_file = Path(source_problem) + else: + # Write a stub problem file from kernel code context + problem_file = run_dir / "problem.py" + problem_file.parent.mkdir(parents=True, exist_ok=True) + problem_file.write_text( + "# Auto-generated problem stub\n" + "import torch\nimport torch.nn as nn\n\n" + "class Model(nn.Module):\n" + " def __init__(self):\n" + " super().__init__()\n" + " def 
forward(self, x):\n" + " return x\n", + encoding="utf-8", + ) + + # Set up log capture on the OptimizationManager logger + if log_capture is None: + log_capture = _LogCapture() + log_capture.metadata["log_dir"] = str(run_dir) + + mgr_logger = logging.getLogger("OptimizationManager") + stream_handler = logging.StreamHandler(log_capture) + stream_handler.setFormatter( + logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + ) + mgr_logger.addHandler(stream_handler) + + try: + manager = OptimizationManager( + strategy=strategy, + num_workers=num_workers, + max_rounds=max_rounds, + log_dir=run_dir, + strategy_config=strategy_config, + openai_model=model_name, + high_reasoning_effort=high_reasoning, + gpu_name=gpu_name, + target_platform=platform, + ) + + result = manager.run_optimization( + initial_kernel=kernel_code, + problem_file=problem_file, + test_code=test_code, + ) + finally: + mgr_logger.removeHandler(stream_handler) + + # Build status markdown + status_md = _build_status_markdown(result, strategy, num_workers) + + # Build per-round best data from database + per_round_best: dict[int, dict] = {} + try: + all_entries = manager.database.get_all() + for entry in all_entries: + gen = entry.generation + if gen is None or gen < 1: + continue + time_ms = entry.metrics.time_ms + if ( + gen not in per_round_best + or time_ms < per_round_best[gen]["time_ms"] + ): + per_round_best[gen] = { + "time_ms": time_ms, + "program_id": entry.program_id, + "kernel_code": entry.kernel_code, + } + except Exception: + pass + round_html = _build_per_round_html(per_round_best) + + # Save best kernel for download + best_kernel = result.get("kernel_code") or "" + download_path = None + if best_kernel: + best_file = output_dir / "best_kernel.py" + best_file.write_text(best_kernel, encoding="utf-8") + download_path = str(best_file) + + log_text = log_capture.getvalue() + return status_md, best_kernel, log_text, download_path, round_html + + except Exception as e: + tb = 
traceback.format_exc() + return f"## Error\n\n```\n{e}\n```\n\n```\n{tb}\n```", "", "", None, "" + finally: + if temp_key_set: + if original_env_key is not None: + os.environ[env_var] = original_env_key + else: + os.environ.pop(env_var, None) + + +def _build_status_markdown(result: dict, strategy: str, num_workers: int) -> str: + """Build the final status markdown from an OptimizationManager result dict.""" + if not result.get("success"): + return "## Optimization Failed\n\nNo improvement found." + + best_time = result.get("best_time_ms", 0) + total_rounds = result.get("total_rounds", 0) + top_kernels = result.get("top_kernels", []) + initial_kernel_time = result.get("initial_kernel_time_ms", float("inf")) + pytorch_baseline = result.get("pytorch_baseline_ms", float("inf")) + pytorch_compile = result.get("pytorch_compile_ms", float("inf")) + + strategy_label = ( + f"Beam Search ({num_workers} workers)" + if strategy == "beam_search" + else f"Greedy ({num_workers} worker)" + ) + + status_md = "## Optimization Complete\n\n" + status_md += "| Metric | Value |\n|---|---|\n" + status_md += f"| Best Time | {best_time:.4f} ms |\n" + if initial_kernel_time != float("inf"): + status_md += f"| Initial Kernel | {initial_kernel_time:.4f} ms |\n" + if pytorch_baseline != float("inf"): + status_md += f"| PyTorch Eager | {pytorch_baseline:.4f} ms |\n" + if pytorch_compile != float("inf"): + status_md += f"| PyTorch Compile | {pytorch_compile:.4f} ms |\n" + if initial_kernel_time != float("inf") and best_time > 0: + speedup = initial_kernel_time / best_time + status_md += f"| Speedup vs Initial | {speedup:.2f}x |\n" + status_md += f"| Rounds | {total_rounds} |\n" + status_md += f"| Strategy | {strategy_label} |\n" + + if len(top_kernels) > 1: + status_md += f"| Top Kernels | {len(top_kernels)} found |\n" + status_md += "\n### Top Kernels\n" + status_md += "| # | Time (ms) | Generation |\n|---|---|---|\n" + for i, k in enumerate(top_kernels, 1): + status_md += f"| {i} | 
{k['time_ms']:.4f} | {k.get('generation', '-')} |\n" + + return status_md + + +def _build_per_round_html(per_round_best: dict[int, dict]) -> str: + """Render per-round best results as collapsible HTML sections. + + Args: + per_round_best: Mapping of round number to best entry dict + with keys: time_ms, program_id, kernel_code (snippet). + + Returns: + HTML string with
<details>/<summary> sections, last round open. + """ + if not per_round_best: + return "" + parts = ["<h3>Per-Round Results</h3>"] + max_round = max(per_round_best) + for rnd in sorted(per_round_best): + entry = per_round_best[rnd] + time_ms = entry.get("time_ms", float("inf")) + prog_id = entry.get("program_id", "?") + # Last round is open by default + open_attr = " open" if rnd == max_round else "" + parts.append(f"<details{open_attr}>") + parts.append( + f"<summary>Round {rnd}: {time_ms:.4f} ms (ID: {prog_id})</summary>" + ) + code = entry.get("kernel_code", "") + if code: + # Show first 30 lines as preview + lines = code.splitlines()[:30] + preview = "\n".join(lines) + if len(code.splitlines()) > 30: + preview += "\n# ... (truncated)" + parts.append(f"<pre><code>{preview}</code></pre>") + parts.append("</details>
") + return "\n".join(parts) + + +class _LogCapture: + """Thread-safe stream-like object that captures log messages.""" + + def __init__(self) -> None: + self._parts: list[str] = [] + self._lock = threading.Lock() + self._read_index: int = 0 + self.metadata: dict = {} + + def write(self, msg: str) -> None: + with self._lock: + self._parts.append(msg) + + def flush(self) -> None: + pass + + def getvalue(self) -> str: + with self._lock: + return "".join(self._parts) + + def get_new_lines(self) -> str: + """Return log content appended since the last call.""" + with self._lock: + new = self._parts[self._read_index :] + self._read_index = len(self._parts) + return "".join(new) + + +# Patterns matched against the *message* portion of each log line (after the +# ``asctime - LEVEL - `` prefix). Order matters: first match wins per line. +_LOG_PATTERNS: list[tuple[re.Pattern[str], str]] = [ + # Round boundary + (re.compile(r"ROUND\s+(\d+)/(\d+)"), "round"), + # Orchestrator phase transitions (exact prefixes to avoid duplicates) + (re.compile(r"\[\d+\] Profiling current kernel with NCU"), "phase_profile"), + (re.compile(r"\[\d+\] Analyzing bottleneck"), "phase_analyze"), + (re.compile(r"\[\d+\] Using pre-computed bottleneck"), "phase_analyze"), + (re.compile(r"\[\d+\] Generating optimized kernel"), "phase_generate"), + (re.compile(r"\[\d+\] Verifying correctness"), "phase_verify"), + # Verification result + (re.compile(r"\[\d+\].*Correctness check passed"), "verify_pass"), + (re.compile(r"\[\d+\].*Correctness check failed"), "verify_fail"), + # Performance results + ( + re.compile(r"NEW BEST RUNTIME.*?(\d+\.?\d*)\s*ms.*?speedup:\s*(\d+\.?\d*)x"), + "new_best", + ), + ( + re.compile( + r"\[\d+\] No improvement:\s*(\d+\.?\d*)\s*ms.*?best\s+(\d+\.?\d*)\s*ms" + ), + "no_improve", + ), + # Manager-level baselines (must precede worker-level "Baseline time:") + (re.compile(r"PyTorch baseline:\s*(\d+\.?\d*)ms"), "pytorch_eager"), + (re.compile(r"PyTorch compile 
baseline:\s*(\d+\.?\d*)ms"), "pytorch_compile"), + (re.compile(r"Initial kernel time:\s*(\d+\.?\d*)ms"), "initial_kernel_time"), + (re.compile(r"Speedup vs initial kernel:\s*(\d+\.?\d*)x"), "final_speedup_initial"), + (re.compile(r"Speedup vs PyTorch eager:\s*(\d+\.?\d*)x"), "final_speedup_pytorch"), + # Worker-level baseline + (re.compile(r"Baseline time:\s*(\d+\.?\d*)\s*ms"), "baseline"), + (re.compile(r"Using known kernel time:\s*(\d+\.?\d*)\s*ms"), "baseline"), + # Roofline / SOL (orchestrator-level, has "-bound" context) + (re.compile(r"Baseline SOL:\s*(\d+\.?\d*)%.*?(\w+)-bound"), "baseline_sol"), + (re.compile(r"\[\d+\] Roofline.*?(\w+)-bound.*?(\d+\.?\d*)% SOL"), "roofline"), + # Per-round best from manager + (re.compile(r"Round (\d+) best: worker (\d+) at (\d+\.?\d*) ms"), "round_best"), + (re.compile(r"Round (\d+): no successful workers"), "round_no_success"), + # Early termination + (re.compile(r"\[\d+\].*Early termination:\s*(.+)"), "early_stop"), + # Final summary + (re.compile(r"OPTIMIZATION COMPLETE"), "done"), + (re.compile(r"Speedup vs baseline:\s*(\d+\.?\d*)x"), "final_speedup"), + # Errors + (re.compile(r"timeout|timed?\s*out", re.IGNORECASE), "error"), + (re.compile(r"LLM.*?failed", re.IGNORECASE), "error"), +] + +_WORKER_DIR_RE = re.compile(r"/w(\d+)/") + + +def _tail_worker_logs(log_dir: str, offsets: dict[str, int]) -> dict[int, str]: + """Read new content from worker log files since last poll. + + Args: + log_dir: Root log directory (same as run_dir). + offsets: Mutable dict mapping log file path -> last read position. + + Returns: + Dict mapping worker_id (int) to new log content for that worker. 
+ """ + per_worker: dict[int, list[str]] = {} + workers_dir = Path(log_dir) / "workers" + if not workers_dir.exists(): + return {} + for log_file in sorted(workers_dir.glob("w*/r*/logs/*.log")): + path_str = str(log_file) + prev = offsets.get(path_str, 0) + wid_match = _WORKER_DIR_RE.search(path_str) + wid = int(wid_match.group(1)) if wid_match else 0 + try: + with open(log_file, encoding="utf-8", errors="replace") as f: + f.seek(prev) + chunk = f.read() + if chunk: + offsets[path_str] = prev + len(chunk) + per_worker.setdefault(wid, []).append(chunk) + except OSError: + pass + return {wid: "".join(parts) for wid, parts in per_worker.items()} + + +_TIMESTAMP_RE = re.compile(r"(\d{2}:\d{2}:\d{2})") + + +def _parse_log_for_status(raw_lines: str, manager_round: str = "") -> str: + """Extract curated status lines from raw log output, prefixed with timestamps. + + Args: + raw_lines: Raw log text. + manager_round: If set (e.g. "3/5"), worker-level "Round 1/1" lines + are rewritten to show the real manager round instead. 
+ """ + if not raw_lines: + return "" + curated: list[str] = [] + for line in raw_lines.splitlines(): + # Extract HH:MM:SS from the log prefix + ts_match = _TIMESTAMP_RE.search(line) + ts = ts_match.group(1) if ts_match else "" + + for pattern, kind in _LOG_PATTERNS: + m = pattern.search(line) + if not m: + continue + prefix = f"[{ts}] " if ts else "" + if kind == "round": + round_label = ( + manager_round if manager_round else f"{m.group(1)}/{m.group(2)}" + ) + curated.append(f"\n{prefix}=== Round {round_label} ===") + elif kind == "phase_profile": + curated.append(f"{prefix} Profiling kernel (NCU)...") + elif kind == "phase_analyze": + curated.append(f"{prefix} Analyzing bottleneck...") + elif kind == "phase_generate": + curated.append(f"{prefix} Generating optimized kernel...") + elif kind == "phase_verify": + curated.append(f"{prefix} Verifying correctness...") + elif kind == "verify_pass": + curated.append(f"{prefix} Correctness: PASSED") + elif kind == "verify_fail": + curated.append(f"{prefix} Correctness: FAILED") + elif kind == "new_best": + time_val = m.group(1) + speedup_val = float(m.group(2)) + if speedup_val > 1.0: + curated.append( + f"{prefix} \U0001f389 SPEEDUP {speedup_val:.2f}x \u2014 NEW BEST: {time_val} ms" + ) + else: + curated.append( + f"{prefix} NEW BEST: {time_val} ms (speedup {m.group(2)}x)" + ) + elif kind == "no_improve": + curated.append( + f"{prefix} No improvement ({m.group(1)} ms, best {m.group(2)} ms)" + ) + elif kind == "pytorch_eager": + curated.append(f"{prefix}PyTorch eager baseline: {m.group(1)} ms") + elif kind == "pytorch_compile": + curated.append(f"{prefix}PyTorch compile baseline: {m.group(1)} ms") + elif kind == "initial_kernel_time": + curated.append(f"{prefix}Initial kernel: {m.group(1)} ms") + elif kind == "final_speedup_initial": + curated.append(f"{prefix} Speedup vs initial kernel: {m.group(1)}x") + elif kind == "final_speedup_pytorch": + curated.append(f"{prefix} Speedup vs PyTorch eager: {m.group(1)}x") + elif 
kind == "baseline": + curated.append(f"{prefix} Worker baseline: {m.group(1)} ms") + elif kind == "baseline_sol": + curated.append( + f"{prefix}Baseline SOL: {m.group(1)}% ({m.group(2)}-bound)" + ) + elif kind == "roofline": + curated.append( + f"{prefix} Roofline: {m.group(1)}-bound, {m.group(2)}% SOL" + ) + elif kind == "round_best": + curated.append( + f"{prefix} Round {m.group(1)} winner: worker {m.group(2)} at {m.group(3)} ms" + ) + elif kind == "round_no_success": + curated.append(f"{prefix} Round {m.group(1)}: no successful workers") + elif kind == "early_stop": + curated.append(f"{prefix} Early stop: {m.group(1).strip()}") + elif kind == "done": + curated.append(f"\n{prefix}OPTIMIZATION COMPLETE") + elif kind == "final_speedup": + curated.append(f"{prefix} Final speedup: {m.group(1)}x") + elif kind == "error": + curated.append(f"{prefix} [ERROR] {m.group(0)}") + break # first matching pattern per line + return "\n".join(curated) + + +def build_interface() -> gr.Blocks: + from utils.providers.models import _get_model_name_to_config + + # Build dropdown: examples + custom + input_choices = _build_input_choices() + default_input = input_choices[0] if input_choices else _CUSTOM_OPTION + + # Pre-load default example content so fields aren't empty on launch + _examples = _discover_examples() + _example_map_init: dict[str, str] = { + f"Example: {label}": dirpath for label, dirpath in _examples + } + default_kernel = "" + default_test = "" + if default_input in _example_map_init: + _d = Path(_example_map_init[default_input]) + try: + default_kernel = (_d / "input.py").read_text(encoding="utf-8") + except OSError: + pass + try: + default_test = (_d / "test.py").read_text(encoding="utf-8") + except OSError: + pass + + model_names = sorted(_get_model_name_to_config().keys()) or ["gpt-5"] + default_model = "gpt-5" if "gpt-5" in model_names else model_names[0] + + gpu_choices = _get_gpu_choices() + default_gpu = gpu_choices[0] if gpu_choices else "" + + with gr.Blocks( 
+ title="KernelAgent — Optimization UI", + theme=gr.themes.Soft(), + css=".worker-log textarea { background-color: #f5f5f5 !important; }", + ) as app: + gr.Markdown( + "# KernelAgent — Kernel Optimization\n\n" + "Hardware-guided optimization: NCU profiling, roofline analysis, " + "LLM bottleneck diagnosis, and iterative refinement.\n\n" + "We have prepared **three examples** to get started — pick one " + "from the dropdown, or paste your own kernel and test code.\n\n" + "**Note:** 5 rounds of optimization can take about 30 minutes." + ) + + with gr.Row(): + # Left column: configuration + with gr.Column(scale=1): + gr.Markdown("## Configuration") + + api_key_input = gr.Textbox( + label="API Key (optional)", + placeholder="sk-... or sk-ant-...", + type="password", + value="", + info="Session-only. Uses env var from .env if empty.", + ) + + input_dropdown = gr.Dropdown( + choices=input_choices, + label="Input Source", + value=default_input, + interactive=True, + info="Pick an example to get started, or select Custom to paste your own.", + ) + + kernel_input = gr.Textbox( + label="Kernel Code", + placeholder="Paste a verified Triton kernel here...", + lines=12, + max_lines=30, + value=default_kernel, + ) + + test_input = gr.Textbox( + label="Test Code", + placeholder="Paste test code here...", + lines=8, + max_lines=20, + value=default_test, + ) + + model_dropdown = gr.Dropdown( + choices=model_names, + label="Model", + value=default_model, + interactive=True, + ) + + strategy_radio = gr.Radio( + choices=["Greedy (1 worker)", "Beam Search (4 workers)"], + value="Greedy (1 worker)", + label="Search Strategy", + ) + + gpu_dropdown = gr.Dropdown( + choices=gpu_choices, + label="GPU", + value=default_gpu, + interactive=True, + info="Select the GPU on your machine.", + ) + + max_rounds_slider = gr.Slider( + 1, 10, value=5, step=1, label="Max Optimization Rounds" + ) + + high_reasoning_cb = gr.Checkbox( + label="High Reasoning Effort", + value=True, + info="Use high 
reasoning for better quality (o4-mini/o3 series).", + ) + + optimize_button = gr.Button("Optimize Kernel", variant="primary") + + # Right column: results with tabs + with gr.Column(scale=2): + gr.Markdown("## Results") + + status_output = gr.Markdown( + value="*Ready — select a problem and paste a kernel to optimize.*" + ) + + with gr.Tab("Log"): + manager_log_output = gr.Textbox( + label="Manager", + interactive=False, + lines=8, + max_lines=20, + ) + with gr.Row(): + with gr.Column() as w0_col: + w0_log = gr.Textbox( + label="Worker 0", + interactive=False, + lines=18, + max_lines=40, + elem_classes=["worker-log"], + ) + with gr.Column(visible=False) as w1_col: + w1_log = gr.Textbox( + label="Worker 1", + interactive=False, + lines=18, + max_lines=40, + elem_classes=["worker-log"], + ) + with gr.Column(visible=False) as w2_col: + w2_log = gr.Textbox( + label="Worker 2", + interactive=False, + lines=18, + max_lines=40, + elem_classes=["worker-log"], + ) + with gr.Column(visible=False) as w3_col: + w3_log = gr.Textbox( + label="Worker 3", + interactive=False, + lines=18, + max_lines=40, + elem_classes=["worker-log"], + ) + + with gr.Tab("Best Kernel"): + kernel_output = gr.Code( + label="Optimized Kernel", + language="python", + interactive=False, + lines=25, + ) + per_round_html = gr.HTML( + value="", + label="Per-Round Results", + ) + + with gr.Tab("Download"): + download_output = gr.File( + label="Download best kernel", + interactive=False, + ) + + # Wire input dropdown to auto-load kernel and test code + _example_map = _example_map_init + + def _read_file(path: Path) -> str: + try: + return path.read_text(encoding="utf-8") + except OSError: + return "" + + def on_input_selected(label: str) -> tuple[str, str]: + if label == _CUSTOM_OPTION or not label: + return "", "" + if label in _example_map: + d = Path(_example_map[label]) + return _read_file(d / "input.py"), _read_file(d / "test.py") + return "", "" + + input_dropdown.change( + fn=on_input_selected, + 
inputs=input_dropdown, + outputs=[kernel_input, test_input], + ) + + # Toggle worker column visibility based on strategy + def on_strategy_change(choice: str): + is_beam = choice == "Beam Search (4 workers)" + return ( + gr.update(visible=True), + gr.update(visible=is_beam), + gr.update(visible=is_beam), + gr.update(visible=is_beam), + ) + + strategy_radio.change( + fn=on_strategy_change, + inputs=strategy_radio, + outputs=[w0_col, w1_col, w2_col, w3_col], + ) + + # Wire optimize button + def _parse_strategy(choice: str) -> tuple[str, int, dict]: + """Map strategy radio label to (strategy, num_workers, strategy_config).""" + if choice == "Beam Search (4 workers)": + return "beam_search", 4, {"num_top_kernels": 2, "num_bottlenecks": 2} + return "greedy", 1, {"max_no_improvement": 3} + + def on_optimize( + input_label: str, + kernel_code: str, + test_code: str, + model_name: str, + strategy_choice: str, + gpu_name: str, + max_rounds: int, + high_reasoning: bool, + api_key: str | None, + ): + strategy, num_workers, strategy_config = _parse_strategy(strategy_choice) + + # Resolve problem_label and problem_file_override from input source + problem_label = "" + problem_file_override = None + if input_label.startswith("KB: "): + problem_label = input_label[4:] + elif input_label in _example_map: + problem_file_override = str( + Path(_example_map[input_label]) / "problem.py" + ) + + log_capture = _LogCapture() + # run_optimization yields a 5-tuple: (status_md, best_kernel, log_text, download_path, round_html) + result: list[tuple[str, str, str, str | None, str]] = [] + error: list[BaseException] = [] + + def _worker() -> None: + try: + result.append( + run_optimization( + problem_label=problem_label, + kernel_code=kernel_code, + test_code=test_code, + model_name=model_name, + gpu_name=gpu_name, + max_rounds=int(max_rounds), + high_reasoning=high_reasoning, + platform="cuda", + api_key=api_key, + strategy=strategy, + num_workers=num_workers, + strategy_config=strategy_config, + problem_file_override=problem_file_override, + log_capture=log_capture, + ) + ) + except 
BaseException as exc: + error.append(exc) + + thread = threading.Thread(target=_worker, daemon=True) + thread.start() + + # Accumulated curated logs: manager + per-worker + mgr_curated = "" + worker_curated: dict[int, str] = {i: "" for i in range(4)} + # Track live status from log lines + current_round = "" + current_phase = "" + best_info = "" + _round_re = re.compile(r"Round (\d+/\d+)") + _best_re = re.compile(r"NEW BEST: (.+)") + worker_log_offsets: dict[str, int] = {} + + def _poll_logs() -> None: + nonlocal mgr_curated, current_round, current_phase, best_info + # Manager-level log + mgr_new = log_capture.get_new_lines() + if mgr_new: + parsed = _parse_log_for_status(mgr_new) + if parsed: + mgr_curated += parsed + "\n" + for cline in (parsed or "").splitlines(): + rm = _round_re.search(cline) + if rm: + current_round = rm.group(1) + current_phase = "" + for kw in ("Profiling", "Analyzing", "Generating", "Verifying"): + if kw in cline: + current_phase = kw.lower() + bm = _best_re.search(cline) + if bm: + best_info = bm.group(1) + # Worker-level logs (per-worker) + log_dir = log_capture.metadata.get("log_dir", "") + if log_dir: + per_worker = _tail_worker_logs(log_dir, worker_log_offsets) + for wid, raw in per_worker.items(): + parsed = _parse_log_for_status(raw, manager_round=current_round) + if parsed: + worker_curated[wid] = ( + worker_curated.get(wid, "") + parsed + "\n" + ) + + round_html_val: list[str] = [] + + def _make_yield(status, kernel_code, download): + return ( + status, + kernel_code, + mgr_curated.rstrip(), + worker_curated.get(0, "").rstrip(), + worker_curated.get(1, "").rstrip(), + worker_curated.get(2, "").rstrip(), + worker_curated.get(3, "").rstrip(), + download, + round_html_val[-1] if round_html_val else "", + ) + + # Poll with a hard timeout so the generator always terminates. + # 30 min per round × max_rounds + extra margin for baselines. 
+ poll_deadline = time.time() + int(max_rounds) * 1800 + 600 + while thread.is_alive(): + thread.join(timeout=2) + _poll_logs() + status_parts = ["**Optimizing…**"] + if current_round: + status_parts.append(f"Round {current_round}") + if current_phase: + status_parts.append(f"({current_phase})") + if best_info: + status_parts.append(f"| Best so far: {best_info}") + yield _make_yield(" ".join(status_parts), "", None) + if time.time() > poll_deadline: + error.append( + TimeoutError("Optimization exceeded maximum wall time") + ) + break + + # Drain remaining logs + _poll_logs() + + if error: + tb = "".join(traceback.format_exception(error[0])) + yield _make_yield( + f"## Error\n\n```\n{error[0]}\n```\n\n```\n{tb}\n```", + "", + None, + ) + elif result: + status, best_kernel, raw_log, download_path, rh = result[0] + round_html_val.append(rh) + # If no curated manager log, fall back to raw + if not mgr_curated.strip(): + mgr_curated = raw_log + yield _make_yield(status, best_kernel, download_path) + else: + yield _make_yield( + "## Error\n\nOptimization thread finished without result.", + "", + None, + ) + + optimize_button.click( + fn=on_optimize, + inputs=[ + input_dropdown, + kernel_input, + test_input, + model_dropdown, + strategy_radio, + gpu_dropdown, + max_rounds_slider, + high_reasoning_cb, + api_key_input, + ], + outputs=[ + status_output, + kernel_output, + manager_log_output, + w0_log, + w1_log, + w2_log, + w3_log, + download_output, + per_round_html, + ], + show_progress="hidden", + ) + + return app + + +def main() -> None: + parser = argparse.ArgumentParser(description="Optimization UI") + parser.add_argument("--port", type=int, default=8088) + parser.add_argument("--host", type=str, default="localhost") + args = parser.parse_args() + + load_dotenv() + app = build_interface() + + print("Starting Optimization UI...") + + meta_keyfile = Path("/var/facebook/x509_identities/server.pem") + is_meta_devserver = meta_keyfile.exists() + + if is_meta_devserver: + 
server_name = os.uname()[1] + print(f"Meta devserver detected. Visit https://{server_name}:{args.port}/") + app.launch( + share=False, + show_error=True, + server_name=server_name, + server_port=args.port, + ssl_keyfile=str(meta_keyfile), + ssl_certfile=str(meta_keyfile), + ssl_verify=False, + inbrowser=False, + ) + else: + print(f"Visit http://{args.host}:{args.port}/") + app.launch( + share=False, + show_error=True, + server_name=args.host, + server_port=args.port, + inbrowser=True, + ) + + +if __name__ == "__main__": + main() diff --git a/triton_kernel_agent/opt_manager.py b/triton_kernel_agent/opt_manager.py index 09bf6957..ac0d358a 100644 --- a/triton_kernel_agent/opt_manager.py +++ b/triton_kernel_agent/opt_manager.py @@ -242,6 +242,14 @@ def run_optimization( # Benchmark PyTorch baseline once (before spawning workers) pytorch_baseline = self._benchmark_pytorch_baseline(problem_file) + # Benchmark torch.compile baseline + pytorch_compile_time = self._benchmark_pytorch_compile(problem_file) + + # Benchmark the initial kernel + initial_kernel_time = self._benchmark_initial_kernel( + initial_kernel, problem_file + ) + # Round loop round_num = 0 for round_num in range(1, max_rounds + 1): @@ -262,6 +270,16 @@ def run_optimization( # 3. Update strategy with results self.strategy.update_with_results(results, round_num) + # Log per-round winner summary + successful = [r for r in results if r.get("success")] + if successful: + best = min(successful, key=lambda r: r.get("time_ms", float("inf"))) + self.logger.info( + f"Round {round_num} best: worker {best['worker_id']} at {best['time_ms']:.4f} ms" + ) + else: + self.logger.info(f"Round {round_num}: no successful workers") + # 4. 
Check termination if self.strategy.should_terminate(round_num, max_rounds): self.logger.info("Strategy signaled termination") @@ -277,12 +295,21 @@ def run_optimization( if best: self.logger.info(f"Best time: {best.metrics.time_ms:.4f}ms") + if initial_kernel_time != float("inf") and best.metrics.time_ms > 0: + speedup = initial_kernel_time / best.metrics.time_ms + self.logger.info(f"Speedup vs initial kernel: {speedup:.2f}x") + if pytorch_baseline != float("inf") and best.metrics.time_ms > 0: + speedup_pt = pytorch_baseline / best.metrics.time_ms + self.logger.info(f"Speedup vs PyTorch eager: {speedup_pt:.2f}x") return { "success": best is not None and best.metrics.time_ms != float("inf"), "kernel_code": best.kernel_code if best else None, "best_time_ms": best.metrics.time_ms if best else float("inf"), "total_rounds": round_num, + "pytorch_baseline_ms": pytorch_baseline, + "pytorch_compile_ms": pytorch_compile_time, + "initial_kernel_time_ms": initial_kernel_time, "top_kernels": [ { "kernel_code": p.kernel_code, @@ -325,6 +352,75 @@ def _benchmark_pytorch_baseline(self, problem_file: Path) -> float: return pytorch_time + def _benchmark_initial_kernel( + self, initial_kernel: str, problem_file: Path + ) -> float: + """Benchmark the initial kernel before optimization begins. 
+ + Args: + initial_kernel: Kernel source code + problem_file: Path to problem.py + + Returns: + Initial kernel time in ms + """ + from triton_kernel_agent.opt_worker_component.benchmarking.benchmark import ( + Benchmark, + ) + + artifacts_dir = self.log_dir / "artifacts" + artifacts_dir.mkdir(parents=True, exist_ok=True) + + # Write kernel to a temp file + kernel_file = artifacts_dir / "initial_kernel.py" + kernel_file.write_text(initial_kernel, encoding="utf-8") + + benchmarker = Benchmark( + logger=self.logger, + artifacts_dir=artifacts_dir, + benchmark_lock=self.benchmark_lock, + worker_id=-1, + ) + + result = benchmarker.benchmark_kernel(kernel_file, problem_file) + kernel_time = result.get("time_ms", float("inf")) + + if kernel_time != float("inf"): + self.logger.info(f"Initial kernel time: {kernel_time:.4f}ms") + + return kernel_time + + def _benchmark_pytorch_compile(self, problem_file: Path) -> float: + """Benchmark torch.compile'd PyTorch baseline. + + Args: + problem_file: Path to problem.py + + Returns: + torch.compile baseline time in ms + """ + from triton_kernel_agent.opt_worker_component.benchmarking.benchmark import ( + Benchmark, + ) + + artifacts_dir = self.log_dir / "artifacts" + artifacts_dir.mkdir(parents=True, exist_ok=True) + + benchmarker = Benchmark( + logger=self.logger, + artifacts_dir=artifacts_dir, + benchmark_lock=self.benchmark_lock, + worker_id=-1, + ) + + result = benchmarker.benchmark_pytorch_compile(problem_file) + compile_time = result.get("time_ms", float("inf")) + + if compile_time != float("inf"): + self.logger.info(f"PyTorch compile baseline: {compile_time:.4f}ms") + + return compile_time + def _run_workers( self, candidates: list[dict[str, Any]], @@ -400,6 +496,11 @@ def _run_workers( self.logger.warning(f"Worker {w.pid} timed out, terminating") w.terminate() w.join(timeout=5) + if w.is_alive(): + self.logger.warning(f"Worker {w.pid} still alive, killing") + w.kill() + w.join(timeout=2) + w.close() # Collect results results 
= [] @@ -409,6 +510,10 @@ def _run_workers( except Exception: break + # Clean up queue resources to prevent thread hangs during GC + result_queue.close() + result_queue.join_thread() + successful = sum(1 for r in results if r.get("success")) self.logger.info( f"Round {round_num}: {successful}/{len(candidates)} workers succeeded " diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py index 8ee39e46..9f8314ac 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py @@ -234,3 +234,68 @@ def benchmark_pytorch( self.logger.error(f"PyTorch baseline benchmark failed: {e}") self.logger.error(traceback.format_exc()) return {"time_ms": float("inf")} + + def benchmark_pytorch_compile( + self, + problem_file: Path, + dtype: Optional[torch.dtype] = None, + ) -> dict[str, Any]: + """Benchmark torch.compile'd PyTorch baseline using direct in-process timing. + + Mirrors benchmark_pytorch() but wraps the model with torch.compile() + and uses extended warmup (3 forward calls) before timing to allow + compilation and warm caches. + + Args: + problem_file: Path to problem file (must define Model class and get_inputs()) + dtype: Data type to use (default: auto-detect based on model parameters) + + Returns: + Dictionary with benchmark results: + - time_ms: Mean time in ms + - stats: Full timing statistics (mean, std, min, max, all_times, etc.) 
+ """ + try: + with self.lock_manager: + model, inputs = prepare_pytorch_model( + problem_file=problem_file, + device="cuda", + dtype=dtype, + ) + + model = torch.compile(model) + + # Extended warmup: 3 forward calls to trigger compilation + for _ in range(3): + model(*inputs) + torch.cuda.synchronize() + + if self.timing_method == "do_bench": + times = time_with_triton_do_bench( + lambda: model(*inputs), + [], + warmup=self.warmup, + rep=self.repeat, + verbose=False, + ) + else: # cuda_event + times = time_with_cuda_events( + lambda: model(*inputs), + [], + num_warmup=self.warmup, + num_trials=self.repeat, + clear_cache=True, + verbose=False, + ) + + stats = compute_timing_stats(times) + + return { + "time_ms": stats["mean"], + "stats": stats, + } + + except Exception as e: + self.logger.error(f"PyTorch compile benchmark failed: {e}") + self.logger.error(traceback.format_exc()) + return {"time_ms": float("inf")} diff --git a/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py b/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py index d4832f49..2d05ce02 100644 --- a/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py +++ b/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py @@ -29,11 +29,7 @@ from kernel_perf_agent.kernel_opt.roofline.ncu_roofline import RooflineAnalyzer from triton_kernel_agent.prompt_manager import PromptManager from triton_kernel_agent.worker import VerificationWorker -from triton_kernel_agent.worker_util import ( - _call_llm, - _extract_code_from_response, - _write_kernel_file, -) +from triton_kernel_agent.worker_util import _write_kernel_file from utils.providers.base import BaseProvider @@ -854,12 +850,8 @@ def _generate_optimized_kernel(self, opt_prompt: str, round_num: int) -> str | N self.logger.info(f"[{round_num}] Generating optimized kernel...") try: messages = [{"role": "user", "content": opt_prompt}] - 
response_text = _call_llm( - provider=self.provider, - model=self.model, - messages=messages, - high_reasoning_effort=self.high_reasoning_effort, - logger=self.logger, + response_text = self.verification_worker._call_llm( + messages, max_tokens=24576, ) @@ -869,9 +861,8 @@ def _generate_optimized_kernel(self, opt_prompt: str, round_num: int) -> str | N f.write(response_text) # Extract code - optimized_kernel = _extract_code_from_response( - response_text=response_text, - logger=self.logger, + optimized_kernel = self.verification_worker._extract_code_from_response( + response_text, ) if not optimized_kernel or len(optimized_kernel) < 100: @@ -949,14 +940,14 @@ def _generate_reflexion(self, attempt: OptimizationAttempt) -> Reflexion | None: reflexion_prompt = self.prompt_manager.render_reflexion_prompt(attempt) messages = [{"role": "user", "content": reflexion_prompt}] - response_text = _call_llm( - provider=self.provider, - model=self.model, - messages=messages, - high_reasoning_effort=False, # Use standard reasoning for reflexion - logger=self.logger, + # Use provider directly with high_reasoning_effort=False + # (worker._call_llm would force high_reasoning=True if worker was configured that way) + response = self.provider.get_response( + self.model, + messages, max_tokens=2048, ) + response_text = response.content # Save reflexion response reflexion_file = ( diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 index 29a7d413..48866810 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 @@ -15,6 +15,7 @@ limitations under the License. 
#} """NCU profiling wrapper.""" +import importlib import sys import torch import inspect @@ -22,14 +23,15 @@ sys.path.insert(0, str({{ kernel_file_parent }})) sys.path.insert(0, str({{ problem_file_parent }})) from {{ kernel_module }} import kernel_function -from {{ problem_module }} import get_inputs, get_init_inputs -# Try to import Model if it exists (for Conv, Linear, etc.) -try: - from {{ problem_module }} import Model - has_model = True -except ImportError: - has_model = False +_problem_mod = importlib.import_module({{ problem_module | tojson }}) +get_inputs = _problem_mod.get_inputs +get_init_inputs = _problem_mod.get_init_inputs + +# Try to get Model if it exists (for Conv, Linear, etc.) +has_model = hasattr(_problem_mod, 'Model') +if has_model: + Model = _problem_mod.Model # Get inputs inputs = get_inputs() diff --git a/triton_kernel_agent/worker_util.py b/triton_kernel_agent/worker_util.py index ee5ef955..2e113d44 100644 --- a/triton_kernel_agent/worker_util.py +++ b/triton_kernel_agent/worker_util.py @@ -25,6 +25,18 @@ # ------------------------ +def _call_llm( + provider, + model: str, + messages: list, + logger: Logger | None = None, + **kwargs, +) -> str: + """Call an LLM provider and return the response text.""" + response = provider.get_response(model, messages, **kwargs) + return response.content + + def _extract_history_usage_from_response( response_text: str, logger: Logger | None = None, diff --git a/utils/providers/relay_provider.py b/utils/providers/relay_provider.py index 6edc4d7e..3c35781a 100644 --- a/utils/providers/relay_provider.py +++ b/utils/providers/relay_provider.py @@ -93,7 +93,7 @@ def _handle_request( self.server_url, json=request_data, headers={"Content-Type": "application/json"}, - timeout=int(os.environ.get("LLM_RELAY_TIMEOUT_S", 120)), + timeout=int(os.environ.get("LLM_RELAY_TIMEOUT_S", 600)), ) if response.status_code != 200:
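The last hunk above raises the relay provider's default request timeout from 120 s to 600 s while keeping the `LLM_RELAY_TIMEOUT_S` environment override intact. A minimal sketch of that lookup behavior, assuming only the variable name taken from the diff (the helper function name is hypothetical):

```python
import os


def relay_timeout_s(default: int = 600) -> int:
    """Resolve the relay HTTP timeout as the diff does.

    The LLM_RELAY_TIMEOUT_S env var wins when set; otherwise the new
    600 s default applies. (Helper name is illustrative only.)
    """
    # int() raises ValueError on a non-numeric override, matching the
    # inline int(os.environ.get(...)) call in relay_provider.py.
    return int(os.environ.get("LLM_RELAY_TIMEOUT_S", default))


os.environ.pop("LLM_RELAY_TIMEOUT_S", None)
assert relay_timeout_s() == 600  # new default after the change

os.environ["LLM_RELAY_TIMEOUT_S"] = "120"
assert relay_timeout_s() == 120  # old behavior, restored via env var
os.environ.pop("LLM_RELAY_TIMEOUT_S", None)
```

Deployments that relied on the previous 120 s limit can pin it back by exporting `LLM_RELAY_TIMEOUT_S=120`.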