KernelAgent-Oink: Add SM100 CuTeDSL RMSNorm custom ops plugin for vLLM#69

Open
Laurawly wants to merge 12 commits into main from feat/oink
Conversation

@Laurawly (Contributor) commented Jan 6, 2026

Add the kernelagent-oink vLLM plugin that registers Blackwell (SM100) RMSNorm
custom ops via torch.library.custom_op under the oink:: namespace:

  • oink::rmsnorm(x, weight, eps) -> Tensor
  • oink::fused_add_rms_norm(x!, residual!, weight, eps) -> () (in-place, vLLM semantics)

The SM100 CuTeDSL implementation is layout-aware and preserves padded-row
strides (stride(1)==1, stride(0)>=N) so torch.compile/CUDA-graph capture sees a
stable stride contract. Includes small-M latency tuning for DSv3-like N=7168
while maintaining high-M bandwidth, with correctness-first fallbacks on non-SM100 hardware.
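The padded-row stride contract described above is easy to state as a plain predicate. A minimal sketch (hypothetical helper, not part of the plugin) of what the ops preserve for a 2-D (M, N) input:

```python
def check_padded_row_contract(shape, strides):
    """Return True if a 2-D (M, N) tensor's strides satisfy the contract
    the plugin preserves: innermost dim contiguous (stride(1) == 1) and
    row stride at least N (stride(0) >= N), i.e. rows may be padded but
    never interleaved. This is the stable layout torch.compile and
    CUDA-graph capture rely on."""
    _, n = shape
    s0, s1 = strides
    return s1 == 1 and s0 >= n
```

A padded row such as (M, N) = (4, 7168) with strides (8192, 1) passes, while a transposed layout fails.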

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Jan 6, 2026
@Laurawly changed the title from "kernelagent-oink: add SM100 CuTeDSL RMSNorm custom ops plugin for vLLM" to "KernelAgent-Oink: Add SM100 CuTeDSL RMSNorm custom ops plugin for vLLM" on Jan 6, 2026
@Jack-Khuu (Contributor) left a comment


Initial comments

Need to go through rmsnorm.py

import subprocess
import sys
import threading
from typing import Optional, Tuple

If we want to keep it for <= Python 3.9 support, that's fine. If not, let's use | None and tuple for 3.10+.

@Laurawly requested review from drisspg and v0i0 on January 12, 2026
@@ -0,0 +1,2927 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
@msaroufim (Member) commented Jan 15, 2026


Sorry for the late review. It's nice that we have an extension system now to try things in vLLM, so I mostly want to spend time reviewing the kernel itself and what would make it easier for vendors like vLLM to actually merge this in. Mostly reiterating points I made here: https://x.com/marksaroufim/status/2009096176789016600?s=20

A lot of it stems from the file being too long, but I don't think it will be too hard to clean up:

  1. We don't need the cache to work across multiple CuTe DSL versions. Presumably they're making breaking changes fairly frequently, so let's just pin the latest version and update as needed.
  2. The code almost looks like a splatted autotune run because it tries to handle many cases and choose between different optimizations. I think we should ship only the one specific config that is fast on the specific shapes, on the specific model, that the vLLM team cares about on B200. Otherwise they'll have trouble reviewing this code even if it's faster, and I'd rather we generalize the code progressively as the need arises.
  3. A lot of the pointer-marshalling code can be deleted in favor of tvm-ffi; a good chunk of the file is doing this by hand, and that will be error-prone.
  4. Point 2 also has unexpected side effects: a pile of fallbacks makes it unpredictable for an end user precisely which kernel configuration will run, which is something all of our numerics-sensitive customers really care about. A user would often like to state explicitly whether they want an op to be in place or not. I'd argue that instead of environment variables gating specific optimizations we should have function arguments, or separate functions. Going further, PyTorch now has an intra-kernel dispatcher where we can guarantee which specific kernel will be called for a specific shape.
  5. Finally, while I think an e2e test in vLLM works great, we probably also want some smaller unit tests right here comparing numerics against vanilla PyTorch and Quack.
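To make point 4 concrete, here is a minimal sketch of a static shape-to-config dispatch in place of environment-variable gating (names and numbers are hypothetical, not taken from this PR):

```python
# Hypothetical static dispatch table: hidden size N -> kernel config.
# A user (or reviewer) can read off exactly which configuration runs
# for a given shape; there is no silent fallback chain.
KERNEL_CONFIGS = {
    7168: {"threads": 256, "vec_width": 8},  # DSv3-like hidden size on B200
}

def select_config(n: int) -> dict:
    """Fail loudly instead of falling back to an unpredictable kernel."""
    try:
        return KERNEL_CONFIGS[n]
    except KeyError:
        raise ValueError(f"no SM100 RMSNorm config registered for N={n}") from None
```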


For #5, are we thinking numerics unit tests (i.e. feed in a specific tensor, compare output to some gold reference to check that we're not off beyond some threshold)? Or are we really thinking correctness test (which ends up checking for equality with some epsilon because of fp math)? For the latter we can probably cook up some input generation where we don't have to scratch our heads too much about how to choose the epsilon for various input dtypes (maybe even do bit-identical tests).
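Either way, the gold reference itself is small. A sketch of a row-wise RMSNorm reference in plain Python (illustrative only; a real unit test would compare the CuTeDSL kernel against a float64 PyTorch reference under a dtype-dependent tolerance):

```python
import math

def rmsnorm_ref(x, weight, eps=1e-6):
    """Reference RMSNorm over one row: x * w / sqrt(mean(x^2) + eps),
    computed entirely in Python floats (float64) so it can serve as a
    gold reference for lower-precision kernel outputs."""
    n = len(x)
    mean_sq = sum(v * v for v in x) / n
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]
```

With unit weights and a constant row, the output is approximately all ones, which makes a convenient sanity check before worrying about dtype-specific epsilons.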

Laurawly and others added 10 commits on February 26, 2026
  oink::fused_add_rms_norm backed by an SM100 CuTeDSL RMSNorm kernel.

  The ops are torch.compile-friendly (stride-preserving for padded-row inputs)
  and the fused op matches vLLM's in-place residual-add RMSNorm semantics.
  - Switch correctness gate to PyTorch ref + record err stats
  - Tighten Softmax/LayerNorm tolerances (Quack-like)
  - Quack-style benchmark suite layout + SVG plots
  - Packaging/README polish for publishability
