def _get_gemm_configs():
    """Build the autotuning search space for ``ns_line_3_kernel``.

    Returns:
        A list of ``triton.Config`` objects, one for every combination of
        tile shape (``BLOCK_SIZE_M`` x ``BLOCK_SIZE_N`` x ``BLOCK_SIZE_K``)
        and ``(num_stages, num_warps)`` pair that passes the aspect-ratio
        filter. ``GROUP_SIZE_M`` is fixed at 8 for every config.
    """
    block_m_choices = (64, 128)
    block_n_choices = (64, 128, 256)
    block_k_choices = (32, 64, 128)
    stage_warp_choices = ((3, 4), (4, 4), (3, 8))

    return [
        triton.Config(
            {
                "BLOCK_SIZE_M": bm,
                "BLOCK_SIZE_N": bn,
                "BLOCK_SIZE_K": bk,
                "GROUP_SIZE_M": 8,
            },
            num_stages=st,
            num_warps=wp,
        )
        for bm in block_m_choices
        for bn in block_n_choices
        for bk in block_k_choices
        for st, wp in stage_warp_choices
        # Skip strongly rectangular tiles: with the choices above this
        # rejects only the 64 x 256 tile (256 // 64 == 4).
        if bm // bn <= 2 and bn // bm <= 2
    ]


@triton.jit
def _pid_to_block_ns3(
    pid,
    M,
    N,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Map a flat program id to ``(batch, row_start, col_start)`` of a tile.

    The launch grid is one-dimensional and enumerates, for every batch
    entry, all ``cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)`` output tiles
    of an ``M x N`` matrix. The in-batch tile index is split into a
    (row, column) tile coordinate and then reordered with ``tl.swizzle2d``
    in groups of ``GROUP_SIZE_M`` tile rows.

    Returns:
        The batch index and the first row / first column (in elements, not
        tiles) of the output tile handled by this program.
    """
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    # Peel off the batch index, leaving the tile index within one matrix.
    batch = pid // (num_pid_m * num_pid_n)
    pid = pid % (num_pid_m * num_pid_n)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    return batch, pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N


@triton.autotune(
    configs=_get_gemm_configs(),
    key=["M", "N", "b_stride_r", "b_stride_c", "x_stride_r", "x_stride_c"],
)
@triton.jit
def ns_line_3_kernel(
    B_ptr,  # [B, M, M] (described as symmetric by the caller; not relied on here)
    X_ptr,  # [B, M, N]
    C_ptr,  # [B, M, N]
    M,
    N,  # rows(X)=M, cols(X)=N
    b_stride_b,
    b_stride_r,
    b_stride_c,
    x_stride_b,
    x_stride_r,
    x_stride_c,
    c_stride_b,
    c_stride_r,
    c_stride_c,
    alpha,  # scalar a (scale of X)
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Compute one output tile of ``C = alpha * X + B @ X``.

    Each program handles a ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile of one batch
    entry. The accumulator is seeded with ``alpha * X[tile]`` (converted to
    float32) and the product ``B @ X`` is accumulated on top of it with
    ``tl.dot``; the reduction dimension has length ``M`` because ``B`` is
    ``M x M``. The result is cast to the element type of ``C_ptr`` on store.

    Every load and store is masked against the matrix bounds in *both* tile
    dimensions, so ``M`` and ``N`` do not have to be multiples of the block
    sizes and no element outside the ``M x M`` / ``M x N`` matrices is read.
    """
    pid = tl.program_id(axis=0)
    batch, m_start, n_start = _pid_to_block_ns3(
        pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
    )

    # Offset base pointers to this batch
    B_ptr += batch * b_stride_b
    X_ptr += batch * x_stride_b
    C_ptr += batch * c_stride_b

    # Create index ranges for the tile
    offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
    offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Bounds of the output tile; reused by every load and by the final store.
    row_in_bounds = offs_m[:, None] < M
    col_in_bounds = offs_n[None, :] < N
    tile_mask = row_in_bounds & col_in_bounds

    # Pointers to the first B and X tiles along the reduction dimension
    b_ptrs = B_ptr + offs_m[:, None] * b_stride_r + offs_k[None, :] * b_stride_c
    x_ptrs = X_ptr + offs_k[:, None] * x_stride_r + offs_n[None, :] * x_stride_c

    # Accumulator, initialized with the "bias" term alpha * X[tile]
    x_bias_ptrs = X_ptr + offs_m[:, None] * x_stride_r + offs_n[None, :] * x_stride_c
    acc = (tl.load(x_bias_ptrs, mask=tile_mask, other=0.0) * alpha).to(tl.float32)

    # GEMM main loop over the reduction dimension (length M)
    for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
        k_remaining = M - k * BLOCK_SIZE_K
        # Mask rows of B / columns of X that fall outside the matrix as well
        # as the K tail. Out-of-range rows/columns only feed output elements
        # that are masked at the store, so the result is unchanged, but the
        # unmasked form read past the end of the buffers.
        b = tl.load(
            b_ptrs,
            mask=row_in_bounds & (offs_k[None, :] < k_remaining),
            other=0.0,
        )
        x = tl.load(
            x_ptrs,
            mask=(offs_k[:, None] < k_remaining) & col_in_bounds,
            other=0.0,
        )
        acc = tl.dot(b, x, acc)
        b_ptrs += BLOCK_SIZE_K * b_stride_c
        x_ptrs += BLOCK_SIZE_K * x_stride_r

    out_dtype = C_ptr.dtype.element_ty
    acc = acc.to(out_dtype)

    # Store result
    c_ptrs = C_ptr + offs_m[:, None] * c_stride_r + offs_n[None, :] * c_stride_c
    tl.store(c_ptrs, acc, mask=tile_mask)


def ns_line_3(B: Tensor, X: Tensor, a: float, *, out: Tensor = None) -> Tensor:
    """Fused implementation of ``C = a * X + B @ X``.

    Args:
        B: Square matrix of shape ``[M, M]`` or batch ``[batch, M, M]``. A 2-D
            ``B`` is shared across the batch of a 3-D ``X`` (batch stride 0).
            The original contract describes ``B`` as symmetric.
        X: Matrix of shape ``[M, N]`` or batch ``[batch, M, N]``.
        a: Scalar multiplier of the ``X`` term.
        out: Optional destination with the same shape as ``X``. Allocated
            with ``torch.empty_like(X)`` when omitted.

    Returns:
        ``out``, holding ``a * X + B @ X``.

    Raises:
        ValueError: If ``B`` is not square, ``B`` and ``X`` disagree on the
            number of rows, either input is not 2-D or 3-D, the batch sizes
            of 3-D ``B`` and ``X`` differ, or ``out`` does not have the shape
            of ``X``.
    """
    if B.shape[-2] != B.shape[-1]:
        raise ValueError("B must be square")

    if B.shape[-2] != X.shape[-2]:
        raise ValueError("B and X must have the same number of rows")

    # Only a single optional leading batch dimension is supported; any other
    # rank would be silently treated as batch == 1 by the stride logic below.
    if X.ndim not in (2, 3) or B.ndim not in (2, 3):
        raise ValueError("B and X must be 2-D or 3-D tensors")

    # A batched B is indexed with the batch index derived from X, so the two
    # batch sizes have to agree to stay inside B's storage.
    if B.ndim == 3 and X.ndim == 3 and B.shape[0] != X.shape[0]:
        raise ValueError("B and X must have the same batch size")

    M, N = X.shape[-2:]
    batch = X.shape[0] if X.ndim == 3 else 1

    if out is None:
        out = torch.empty_like(X)
    elif out.shape != X.shape:
        raise ValueError("out must have the same shape as X")

    def grid(meta):
        # One program per output tile per batch entry (1-D launch grid).
        tiles_m = triton.cdiv(M, meta["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (batch * tiles_m * tiles_n,)

    ns_line_3_kernel[grid](
        B_ptr=B,
        X_ptr=X,
        C_ptr=out,
        M=M,
        N=N,
        # A batch stride of 0 makes a 2-D operand act as a broadcast batch.
        b_stride_b=B.stride(0) if B.ndim == 3 else 0,
        b_stride_r=B.stride(-2),
        b_stride_c=B.stride(-1),
        x_stride_b=X.stride(0) if X.ndim == 3 else 0,
        x_stride_r=X.stride(-2),
        x_stride_c=X.stride(-1),
        c_stride_b=out.stride(0) if out.ndim == 3 else 0,
        c_stride_r=out.stride(-2),
        c_stride_c=out.stride(-1),
        alpha=a,
    )
    return out