16 changes: 10 additions & 6 deletions kernel_course/triton_ops/axpby.py
@@ -11,8 +11,8 @@
 )
 @triton.jit
 def axpby_kernel(
-    x_ptr,
-    y_ptr,
+    X,
+    Y,
     alpha,
     beta,
     n_elements,
@@ -26,15 +26,19 @@ def axpby_kernel(
     # Note that offsets is a list of pointers
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Initialize pointers to the start of the blocks
+    x_ptr = X + offsets
+    y_ptr = Y + offsets
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
-    x = tl.load(x_ptr + offsets, mask=mask)
-    y = tl.load(y_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr, mask=mask)
+    y = tl.load(y_ptr, mask=mask)
     # Compute y = alpha * x + beta * y
-    y = alpha * x + beta * y
+    y = beta * y
+    y += alpha * x
     # Write y back to DRAM
-    tl.store(y_ptr + offsets, y, mask=mask)
+    tl.store(y_ptr, y, mask=mask)
 
 
 def axpby(
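The `axpby()` wrapper below the kernel is collapsed in this diff, so as a reading aid here is a minimal host-side launch sketch matching the new kernel signature. The grid construction and the assumption that `BLOCK_SIZE` comes from an autotune config (hence not passed at the call site) are mine, not the repository's code.

```python
import torch
import triton

from kernel_course.triton_ops.axpby import axpby_kernel

def axpby_launch(alpha: float, x: torch.Tensor, beta: float, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper: updates y <- alpha * x + beta * y in place."""
    assert x.is_cuda and y.is_cuda and x.numel() == y.numel()
    n_elements = x.numel()
    # One program per BLOCK_SIZE-sized chunk; the autotuner fills in meta["BLOCK_SIZE"].
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    axpby_kernel[grid](x, y, alpha, beta, n_elements)
    return y
```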
11 changes: 7 additions & 4 deletions kernel_course/triton_ops/copy.py
@@ -11,8 +11,8 @@
 )
 @triton.jit
 def copy_kernel(
-    x_ptr,
-    y_ptr,
+    X,
+    Y,
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -24,12 +24,15 @@ def copy_kernel(
     # Note that offsets is a list of pointers
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Initialize pointers to the start of the blocks
+    x_ptr = X + offsets
+    y_ptr = Y + offsets
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load x from DRAM, masking out any extra elements in case the input is not a multiple of the block size
-    x = tl.load(x_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr, mask=mask)
     # Write x to y in DRAM
-    tl.store(y_ptr + offsets, x, mask=mask)
+    tl.store(y_ptr, x, mask=mask)
 
 
 def copy(
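Since `copy_kernel` is the simplest of these kernels, a one-line sanity check makes the intended semantics concrete; the `copy(x, y)` call signature is an assumption, as the wrapper body is collapsed above.

```python
import torch
from kernel_course.triton_ops.copy import copy

x = torch.randn(10_000, device="cuda")
y = torch.empty_like(x)
copy(x, y)  # assumed signature: copy x into y element-wise
torch.testing.assert_close(y, x)
```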
16 changes: 10 additions & 6 deletions kernel_course/triton_ops/dot.py
@@ -11,9 +11,9 @@
 )
 @triton.jit
 def dot_kernel(
-    x_ptr,
-    y_ptr,
-    z_ptr,
+    X,
+    Y,
+    Z,
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -25,15 +25,19 @@ def dot_kernel(
     # Note that offsets is a list of pointers
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Initialize pointers to the start of the blocks
+    x_ptr = X + offsets
+    y_ptr = Y + offsets
+    z_ptr = Z + offsets
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
-    x = tl.load(x_ptr + offsets, mask=mask)
-    y = tl.load(y_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr, mask=mask)
+    y = tl.load(y_ptr, mask=mask)
     # Compute z = x \cdot y
     z = tl.sum(x * y)
     # Write z back to DRAM
-    tl.store(z_ptr + offsets, z, mask=mask)
+    tl.store(z_ptr, z, mask=mask)
 
 
 def dot(
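One detail worth noting in `dot_kernel`: `tl.sum(x * y)` reduces the block to a single scalar, and the masked store then broadcasts that scalar into every slot of `Z` the program owns, so `Z` ends up holding one partial sum per block, repeated across the block's positions. A hedged reference check, assuming the collapsed `dot()` wrapper reduces those per-block partials to the final scalar:

```python
import torch
from kernel_course.triton_ops.dot import dot

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
# Compare against PyTorch's dot product; tolerances allow for float32 reduction order.
torch.testing.assert_close(dot(x, y), x @ y, rtol=1e-3, atol=1e-3)
```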
35 changes: 19 additions & 16 deletions kernel_course/triton_ops/gemv.py
@@ -17,15 +17,15 @@
 )
 @triton.jit
 def gemv_kernel(
-    A_ptr,
-    x_ptr,
-    y_ptr,
-    alpha,
-    beta,
+    A,
+    X,
+    Y,
     stride_am,
     stride_an,
     stride_x,
     stride_y,
+    alpha,
+    beta,
     n_elements_M,
     n_elements_N,
     BLOCK_M: tl.constexpr,
@@ -39,6 +39,10 @@ def gemv_kernel(
     # This program will process inputs that are offset from the initial pointer
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
+    # Initialize pointers to the start of the blocks
+    A_ptr = A + offs_m[:, None] * stride_am
+    x_ptr = X
+    y_ptr = Y + offs_m * stride_y
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask_m = offs_m < n_elements_M
     # Initialize the accumulator to zero for each row
@@ -54,12 +58,10 @@ def gemv_kernel(
         mask_n = offs_n < n_elements_N
         # Load a block of A and x from DRAM, masking out any extra elements in case the input is not a multiple of the block size
         if EVEN_N & EVEN_M:
-            a = tl.load(
-                A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
-            )
+            a = tl.load(A_ptr + offs_n[None, :] * stride_an)
         else:
             a = tl.load(
-                A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an,
+                A_ptr + offs_n[None, :] * stride_an,
                 mask=mask_m[:, None] & mask_n[None, :],
                 other=0.0,
             )
@@ -71,16 +73,17 @@ def gemv_kernel(
         acc += tl.sum(a * x[None, :], axis=1)
     # Load y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
     if EVEN_M:
-        y = tl.load(y_ptr + offs_m * stride_y)
+        y = tl.load(y_ptr)
     else:
-        y = tl.load(y_ptr + offs_m * stride_y, mask=mask_m, other=0.0)
+        y = tl.load(y_ptr, mask=mask_m, other=0.0)
     # Compute y = alpha * A * x + beta * y
-    y_new = (alpha * acc + beta * y).to(y.dtype)
+    y_new = beta * y
+    y_new += alpha * acc
     # Write y back to DRAM
     if EVEN_M:
-        tl.store(y_ptr + offs_m * stride_y, y_new)
+        tl.store(y_ptr, y_new)
     else:
-        tl.store(y_ptr + offs_m * stride_y, y_new, mask=mask_m)
+        tl.store(y_ptr, y_new, mask=mask_m)
 
 
 def gemv(
@@ -117,12 +120,12 @@ def grid(meta):
         A,
         x,
         y,
-        alpha,
-        beta,
         A.stride(0),
         A.stride(1),
         x.stride(0),
         y.stride(0),
+        alpha,
+        beta,
         n_elements_M,
         n_elements_N,
     )
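Because this change rewrites both the EVEN and masked paths, a check with sizes that are not multiples of `BLOCK_M`/`BLOCK_N` exercises the `else` branches as well. This is a sketch: `gemv`'s exact argument order is collapsed above and assumed here, as is the in-place update of `y`.

```python
import torch
from kernel_course.triton_ops.gemv import gemv

M, N = 513, 257  # deliberately odd sizes so EVEN_M / EVEN_N are False
A = torch.randn(M, N, device="cuda")
x = torch.randn(N, device="cuda")
y = torch.randn(M, device="cuda")
expected = 2.0 * (A @ x) + 3.0 * y
gemv(A, x, y, 2.0, 3.0)  # assumed to update y in place, mirroring the kernel
torch.testing.assert_close(y, expected, rtol=1e-3, atol=1e-3)
```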
8 changes: 5 additions & 3 deletions kernel_course/triton_ops/scal.py
@@ -11,7 +11,7 @@
 )
 @triton.jit
 def scal_kernel(
-    y_ptr,
+    Y,
     alpha,
     n_elements,
     BLOCK_SIZE: tl.constexpr,
@@ -24,14 +24,16 @@ def scal_kernel(
     # Note that offsets is a list of pointers
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Initialize pointers to the start of the blocks
+    y_ptr = Y + offsets
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
-    y = tl.load(y_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr, mask=mask)
     # Scale y by alpha
     y = y * alpha
     # Write y back to DRAM
-    tl.store(y_ptr + offsets, y, mask=mask)
+    tl.store(y_ptr, y, mask=mask)
 
 
 def scal(
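For completeness, a usage sketch of `scal` (in-place `y *= alpha`); the BLAS-style `scal(alpha, y)` argument order is an assumption, since the wrapper is collapsed above.

```python
import torch
from kernel_course.triton_ops.scal import scal

y = torch.ones(5_000, device="cuda")
scal(2.5, y)  # assumed argument order, mirroring BLAS sscal
assert torch.all(y == 2.5)
```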
15 changes: 9 additions & 6 deletions kernel_course/triton_ops/swap.py
@@ -11,8 +11,8 @@
 )
 @triton.jit
 def swap_kernel(
-    x_ptr,
-    y_ptr,
+    X,
+    Y,
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -24,14 +24,17 @@ def swap_kernel(
     # Note that offsets is a list of pointers
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Initialize pointers to the start of the blocks
+    x_ptr = X + offsets
+    y_ptr = Y + offsets
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
-    x = tl.load(x_ptr + offsets, mask=mask)
-    y = tl.load(y_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr, mask=mask)
+    y = tl.load(y_ptr, mask=mask)
     # Write y to x and x to y in DRAM
-    tl.store(x_ptr + offsets, y, mask=mask)
-    tl.store(y_ptr + offsets, x, mask=mask)
+    tl.store(x_ptr, y, mask=mask)
+    tl.store(y_ptr, x, mask=mask)
 
 
 def swap(
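And a matching sketch for `swap`, which exchanges the two tensors element-wise in place; `swap(x, y)` is the assumed signature.

```python
import torch
from kernel_course.triton_ops.swap import swap

x = torch.zeros(3_000, device="cuda")
y = torch.ones(3_000, device="cuda")
swap(x, y)
assert torch.all(x == 1.0) and torch.all(y == 0.0)
```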