201 changes: 186 additions & 15 deletions dion/newton_schulz_triton.py
@@ -302,6 +302,164 @@ def ns_line_2(A: Tensor, alpha: float, beta: float, *, out: Tensor = None):
return out


def _get_gemm_configs():
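    """Candidate tile/stage/warp configurations for @triton.autotune to search."""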
return [
triton.Config(
{
"BLOCK_SIZE_M": bm,
"BLOCK_SIZE_N": bn,
"BLOCK_SIZE_K": bk,
"GROUP_SIZE_M": 8,
},
num_stages=st,
num_warps=wp,
)
for bm in (64, 128)
for bn in (64, 128, 256)
for bk in (32, 64, 128)
for st, wp in ((3, 4), (4, 4), (3, 8))
        if max(bm, bn) <= 2 * min(bm, bn)  # skip tile shapes with aspect ratio beyond 2:1
]


@triton.jit
def _pid_to_block_ns3(
pid,
M,
N,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Same helper as in your earlier kernels, extended with N."""
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

batch = pid // (num_pid_m * num_pid_n)
pid = pid % (num_pid_m * num_pid_n)

pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
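    # Grouped (swizzled) launch order keeps consecutive programs on nearby
    # block-rows of B, improving L2 cache reuse across programs.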

return batch, pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N


@triton.autotune(
configs=_get_gemm_configs(),
key=["M", "N", "b_stride_r", "b_stride_c", "x_stride_r", "x_stride_c"],
)
@triton.jit
def ns_line_3_kernel(
B_ptr, # [B, M, M] symmetric
X_ptr, # [B, M, N]
C_ptr, # [B, M, N]
M,
N, # rows(X)=M, cols(X)=N
b_stride_b,
b_stride_r,
b_stride_c,
x_stride_b,
x_stride_r,
x_stride_c,
c_stride_b,
c_stride_r,
c_stride_c,
alpha, # scalar a (scale of X)
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
batch, m_start, n_start = _pid_to_block_ns3(
pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
)

# Offset base pointers to this batch
B_ptr += batch * b_stride_b
X_ptr += batch * x_stride_b
C_ptr += batch * c_stride_b

# Create index ranges for the tile
offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)

# Pointers to B and X tiles
b_ptrs = B_ptr + offs_m[:, None] * b_stride_r + offs_k[None, :] * b_stride_c
x_ptrs = X_ptr + offs_k[:, None] * x_stride_r + offs_n[None, :] * x_stride_c

# Accumulator, initialized with bias * alpha
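    # (loading alpha * X directly into the accumulator fuses the elementwise add
    # into the GEMM, so no separate epilogue pass over C is needed)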
x_bias_ptrs = X_ptr + offs_m[:, None] * x_stride_r + offs_n[None, :] * x_stride_c
acc = (
tl.load(
x_bias_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0
)
* alpha
).to(tl.float32)

# GEMM main loop
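    # The reduction runs over B's columns / X's rows, so its length is M (B is M x M)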
for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
        b_mask = (offs_m[:, None] < M) & (offs_k[None, :] < M - k * BLOCK_SIZE_K)
        x_mask = (offs_k[:, None] < M - k * BLOCK_SIZE_K) & (offs_n[None, :] < N)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)  # mask M edge and K slab
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)  # mask K slab and N edge
acc = tl.dot(b, x, acc)
b_ptrs += BLOCK_SIZE_K * b_stride_c
x_ptrs += BLOCK_SIZE_K * x_stride_r

out_dtype = C_ptr.dtype.element_ty
acc = acc.to(out_dtype)

# Store result
c_ptrs = C_ptr + offs_m[:, None] * c_stride_r + offs_n[None, :] * c_stride_c
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, acc, mask=mask)


def ns_line_3(B: Tensor, X: Tensor, a: float, *, out: Tensor = None) -> Tensor:
"""
Fused implementation of C = a * X + B @ X
    B must be square and symmetric; X must have the same leading dimension and
    may have any number of columns.
"""
if B.shape[-2] != B.shape[-1]:
raise ValueError("B must be square")

if B.shape[-2] != X.shape[-2]:
raise ValueError("B and X must have the same number of rows")

    # Broadcast & batch handling (supports 2-D or 3-D inputs)
M, N = X.shape[-2:]
batch = X.shape[0] if X.ndim == 3 else 1
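    # 2-D inputs are treated as a single batch; a 2-D B gets batch stride 0
    # below, so one B is broadcast across all batches of a 3-D X.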

if out is None:
out = torch.empty_like(X)

grid = lambda meta: (
batch
* triton.cdiv(M, meta["BLOCK_SIZE_M"])
* triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
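    # One program instance per (batch, M-tile, N-tile); the autotuner supplies
    # the block sizes via `meta`.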

ns_line_3_kernel[grid](
B_ptr=B,
X_ptr=X,
C_ptr=out,
M=M,
N=N,
b_stride_b=B.stride(0) if B.ndim == 3 else 0,
b_stride_r=B.stride(-2),
b_stride_c=B.stride(-1),
x_stride_b=X.stride(0) if X.ndim == 3 else 0,
x_stride_r=X.stride(-2),
x_stride_c=X.stride(-1),
c_stride_b=out.stride(0) if out.ndim == 3 else 0,
c_stride_r=out.stride(-2),
c_stride_c=out.stride(-1),
alpha=a,
)
return out
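
# A minimal sanity check for ns_line_3 (hypothetical usage, assuming a CUDA
# device; tolerances are loose to accommodate bf16 rounding differences):
#
#     B = torch.randn(4, 256, 256, device="cuda", dtype=torch.bfloat16)
#     B = B + B.mT  # make symmetric, as the NS iteration guarantees
#     X = torch.randn(4, 256, 192, device="cuda", dtype=torch.bfloat16)
#     ref = 0.5 * X + B @ X
#     torch.testing.assert_close(ns_line_3(B, X, 0.5), ref, rtol=2e-2, atol=2e-2)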


@torch.compile(dynamic=False, fullgraph=True)
def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7):
"""
@@ -340,33 +498,46 @@ def newton_schulz_triton(G: Tensor, epsilon: float = 1e-7):
"""
# Newton-Schulz constants
ns_consts = [
-        (4.0848, -6.8946, 2.9270),
-        (3.9505, -6.3029, 2.6377),
-        (3.7418, -5.5913, 2.3037),
-        (2.8769, -3.1427, 1.2046),
-        (2.8366, -3.0525, 1.2012),
+        [4.6051, -9.6552, 5.6769],
+        [4.7505, -6.0861, 2.1790],
+        [2.7763, -2.3190, 0.5523],
+        [2.4231, -2.2861, 0.8193],
]

X = G.to(dtype=torch.bfloat16)
if G.size(-2) > G.size(-1):
X = X.mT

-    # Ensure spectral norm is at most 1
-    X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon)

# Allocate buffers
X = X.contiguous()
A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
B = torch.empty_like(A)
C = torch.empty_like(X)

-    ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm

-    # Perform the NS iterations
-    for a, b, c in ns_consts:
+    # Ensure the spectral norm is at most 1 via AOL rescaling, replacing the
+    # previous Frobenius-norm division (see https://arxiv.org/pdf/2208.03160).
+    # AOL computes the Gram matrix W @ W^T and the per-row scaling vector
+    # rsqrt(sum(abs(W @ W^T), axis=-1)). Since ns_line_1 already computes the
+    # Gram matrix, the rescaling fuses with the first Newton-Schulz iteration,
+    # and the rescaled matrix starts closer to orthogonal, which saves one
+    # Newton-Schulz iteration.
+    ns_line_1(X, out=A)  # Gram matrix A = X @ X.mT
+    s = torch.rsqrt(torch.clamp_min(
+        A.abs().sum(dim=-1, keepdim=False), min=epsilon
+    ))  # AOL rescaling vector
+    X = X * s.unsqueeze(-1)  # rescale X, bringing it closer to orthogonal
+    # First NS iteration, reusing A: the Gram matrix of the rescaled X is
+    # diag(s) @ A @ diag(s), so it need not be recomputed
+    a, b, c = ns_consts[0]
+    A = A * s.unsqueeze(-1) * s.unsqueeze(-2)
+    ns_line_2(A, alpha=c, beta=b, out=B)
+    ns_line_3(B, X, a, out=C)
+    X, C = C, X

+    # Perform the remaining NS iterations
+    for a, b, c in ns_consts[1:]:
ns_line_1(X, out=A) # A = X @ X.mT
ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A
-        ns_line_3(X, B, X, beta=a, out=C)
+        ns_line_3(B, X, a, out=C)  # C = a * X + B @ X
X, C = C, X # Swap references to avoid unnecessary copies

if G.size(-2) > G.size(-1):