Conversation
Jack-Khuu left a comment
Initial comments
Need to go through rmsnorm.py
import subprocess
import sys
import threading
from typing import Optional, Tuple
If we want to keep Optional/Tuple for <= Python 3.9 support, that's fine. If not, let's use | None and tuple for 3.10+.
@@ -0,0 +1,2927 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Sorry for the late review. It's nice that we now have an extension system to try things in vLLM, so I mostly want to spend time reviewing the kernel itself and what would make it easier for vendors like vLLM to actually merge this in. Mostly reiterating points I made here: https://x.com/marksaroufim/status/2009096176789016600?s=20
A lot of it stems from the fact that this file is too long, but I think it shouldn't be too hard to clean up:
- We don't need the cache to work across multiple CuTe DSL versions; presumably they're making breaking changes fairly frequently, so let's just pin the latest version and update as needed.
- The code almost looks like a splatted autotune run because it's trying to handle many cases and choose between different optimizations. I think we should just ship the one specific config that is fast on the specific shapes, for the specific model the vLLM team cares about, on B200. Otherwise they'll have trouble reviewing this code even if it's faster, and I'd rather we generalize the code progressively as the need arises.
- A lot of the pointer-marshalling code can be deleted in favor of using tvm-ffi; a good chunk of the file is doing this, and it will be error prone.
- Point 2 will also have unexpected side effects: tons of fallbacks make it unpredictable for an end user precisely which kernel configuration will run, which is something all of our numerics-sensitive customers will really care about. A user would often like to explicitly state whether they want an op to be in place or not. I'd argue that instead of environment variables gating specific optimizations we should have arguments to a function, or separate functions. Going further, PyTorch now has an intra-kernel dispatcher where we can make guarantees about which specific kernel will be called for a specific shape.
- Finally, while I think an e2e test in vLLM works great, we probably also want some smaller unit tests right here comparing numerics against vanilla PyTorch code and Quack.
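As a concrete sketch of the "explicit argument instead of environment variable" suggestion, the in-place choice could be a keyword argument the caller must see (the signature and names here are illustrative, not the PR's actual API; computation is done in fp32 for stability):

```python
import torch


def fused_add_rms_norm(x, residual, weight, eps=1e-6, *, inplace: bool = True):
    """Hypothetical API sketch: the caller states explicitly whether the op
    mutates its inputs, instead of an env var silently deciding."""
    if inplace:
        residual.add_(x)          # residual <- x + residual, in place
        hidden = residual
    else:
        hidden = x + residual     # fresh tensor, inputs untouched
    h32 = hidden.float()
    inv_rms = torch.rsqrt(h32.pow(2).mean(-1, keepdim=True) + eps)
    out = (h32 * inv_rms).to(hidden.dtype) * weight
    if inplace:
        x.copy_(out)              # overwrite x with the normalized result
        return x, residual
    return out, hidden
```

Separate `fused_add_rms_norm` / `fused_add_rms_norm_` functions would work equally well; the point is that the behavior is visible at the call site.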
For #5, are we thinking numerics unit tests (i.e., feed in a specific tensor and compare the output to some gold reference to check that we're not off beyond some threshold)? Or are we really thinking a correctness test (which ends up checking for equality with some epsilon because of fp math)? For the latter we can probably cook up some input generation where we don't have to scratch our heads too much about how to choose the epsilon for various input dtypes (maybe even do bit-identical tests).
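The numerics-unit-test flavor could look roughly like this: compare the kernel against a plain-PyTorch fp32 reference with a dtype-aware threshold. This is a sketch; `check_numerics`, the shapes, and the tolerance table are all illustrative, not part of the PR:

```python
import torch


def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """fp32 gold reference: normalize by root-mean-square over the last dim."""
    x32 = x.float()
    inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * inv_rms * weight.float()).to(x.dtype)


def check_numerics(kernel_fn, M=64, N=7168, dtype=torch.bfloat16, seed=0):
    """Feed a fixed tensor through kernel_fn and the reference; fail if the
    max abs error exceeds a per-dtype threshold (values here are guesses)."""
    torch.manual_seed(seed)
    x = torch.randn(M, N, dtype=dtype)
    w = torch.randn(N, dtype=dtype)
    ref = rms_norm_ref(x, w)
    out = kernel_fn(x, w)
    tol = {torch.bfloat16: 1.6e-2, torch.float16: 2e-3, torch.float32: 1e-5}[dtype]
    err = (out.float() - ref.float()).abs().max().item()
    assert err <= tol, f"max abs err {err} > {tol}"
```

Fixing the seed makes the test deterministic, which also keeps the door open for the bit-identical variant mentioned above.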
Add the kernelagent-oink vLLM plugin that registers Blackwell (SM100) RMSNorm
custom ops via torch.library.custom_op under the oink:: namespace:
oink::fused_add_rms_norm, backed by an SM100 CuTeDSL RMSNorm kernel. The ops
are torch.compile-friendly (stride-preserving for padded-row inputs), and the
fused op matches vLLM's in-place residual-add RMSNorm semantics.

The SM100 CuTeDSL implementation is layout-aware and preserves padded-row
strides (stride(1) == 1, stride(0) >= N), so torch.compile/CUDA-graph capture
sees a stable stride contract. It includes small-M latency tuning for
DSv3-like N=7168 and maintains high-M bandwidth, with correctness-first
fallbacks on non-SM100.

- Switch correctness gate to PyTorch ref + record err stats
- Tighten Softmax/LayerNorm tolerances (Quack-like)
- Quack-style benchmark suite layout + SVG plots
- Packaging/README polish for publishability
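The described semantics and stride contract can be expressed as a plain-PyTorch sketch (this mirrors the description above, not the actual SM100 kernel; the helper name is illustrative):

```python
import torch


def fused_add_rms_norm_ref(x, residual, weight, eps=1e-6):
    """Sketch of the fused op's described semantics: residual is updated in
    place to x + residual, then x is overwritten with RMSNorm(residual) * w."""
    residual.add_(x)
    r32 = residual.float()
    inv_rms = torch.rsqrt(r32.pow(2).mean(dim=-1, keepdim=True) + eps)
    x.copy_((r32 * inv_rms * weight.float()).to(x.dtype))
    return x, residual


# Padded-row input: a view with stride(1) == 1 and stride(0) >= N, matching
# the stride contract above (here N=7168 inside rows padded to 8192).
buf = torch.zeros(4, 8192)
x = buf[:, :7168]
residual = torch.randn(4, 7168)
w = torch.ones(7168)
out, res = fused_add_rms_norm_ref(x, residual, w)
assert out.stride() == (8192, 1)  # in-place write preserves the padded strides
```

Because the op writes through the existing views, the padded-row strides the caller passed in are exactly the strides it gets back, which is what torch.compile/CUDA-graph capture relies on.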