diff --git a/kernel_course/triton_ops/axpby.py b/kernel_course/triton_ops/axpby.py
index 85f918f..af12da0 100644
--- a/kernel_course/triton_ops/axpby.py
+++ b/kernel_course/triton_ops/axpby.py
@@ -11,8 +11,8 @@
 )
 @triton.jit
 def axpby_kernel(
-    x_ptr,
-    y_ptr,
+    X,
+    Y,
     alpha,
     beta,
     n_elements,
@@ -26,15 +26,19 @@ def axpby_kernel(
     # Note that offsets is a vector of element offsets
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Initialize pointers to the start of the blocks
+    x_ptr = X + offsets
+    y_ptr = Y + offsets
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
-    x = tl.load(x_ptr + offsets, mask=mask)
-    y = tl.load(y_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr, mask=mask)
+    y = tl.load(y_ptr, mask=mask)
     # Compute y = alpha * x + beta * y
-    y = alpha * x + beta * y
+    y = beta * y
+    y += alpha * x
     # Write y back to DRAM
-    tl.store(y_ptr + offsets, y, mask=mask)
+    tl.store(y_ptr, y, mask=mask)


 def axpby(
diff --git a/kernel_course/triton_ops/copy.py b/kernel_course/triton_ops/copy.py
index 2de1626..909371f 100644
--- a/kernel_course/triton_ops/copy.py
+++ b/kernel_course/triton_ops/copy.py
@@ -11,8 +11,8 @@
 )
 @triton.jit
 def copy_kernel(
-    x_ptr,
-    y_ptr,
+    X,
+    Y,
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -24,12 +24,15 @@ def copy_kernel(
     # Note that offsets is a vector of element offsets
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Initialize pointers to the start of the blocks
+    x_ptr = X + offsets
+    y_ptr = Y + offsets
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load x from DRAM, masking out any extra elements in case the input is not a multiple of the block size
-    x = tl.load(x_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr, mask=mask)
     # Write x to y in DRAM
-    tl.store(y_ptr + offsets, x, mask=mask)
+    tl.store(y_ptr, x, mask=mask)


 def copy(
diff --git a/kernel_course/triton_ops/dot.py b/kernel_course/triton_ops/dot.py
index 2644d7a..427d8f0 100644
--- a/kernel_course/triton_ops/dot.py
+++ b/kernel_course/triton_ops/dot.py
@@ -11,9 +11,9 @@
 )
 @triton.jit
 def dot_kernel(
-    x_ptr,
-    y_ptr,
-    z_ptr,
+    X,
+    Y,
+    Z,
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -25,15 +25,19 @@ def dot_kernel(
     # Note that offsets is a vector of element offsets
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Initialize pointers to the start of the blocks
+    x_ptr = X + offsets
+    y_ptr = Y + offsets
+    z_ptr = Z + offsets
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
-    x = tl.load(x_ptr + offsets, mask=mask)
-    y = tl.load(y_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr, mask=mask)
+    y = tl.load(y_ptr, mask=mask)
     # Compute z = x \cdot y
     z = tl.sum(x * y)
     # Write z back to DRAM
-    tl.store(z_ptr + offsets, z, mask=mask)
+    tl.store(z_ptr, z, mask=mask)


 def dot(
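The three hunks above apply the same refactor: the base-plus-offsets pointer arithmetic is hoisted into per-block pointers once, and every later load and store reuses them. Below is a minimal self-contained sketch of that pattern, modeled on axpby_kernel; the demo name, the BLOCK_SIZE=1024 value, and the grid lambda are illustrative assumptions, not this repo's autotuned configuration.

import torch
import triton
import triton.language as tl

@triton.jit
def axpby_demo_kernel(X, Y, alpha, beta, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Hoist the base-plus-offset arithmetic so each load/store reuses it
    x_ptr = X + offsets
    y_ptr = Y + offsets
    mask = offsets < n_elements
    x = tl.load(x_ptr, mask=mask)
    y = tl.load(y_ptr, mask=mask)
    y = beta * y
    y += alpha * x
    tl.store(y_ptr, y, mask=mask)

x = torch.randn(4097, device="cuda")  # not a multiple of BLOCK_SIZE, exercises the mask
y = torch.randn(4097, device="cuda")
expected = 2.0 * x + 0.5 * y          # reference, taken before the in-place update
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
axpby_demo_kernel[grid](x, y, 2.0, 0.5, x.numel(), BLOCK_SIZE=1024)
torch.testing.assert_close(y, expected)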
diff --git a/kernel_course/triton_ops/gemv.py b/kernel_course/triton_ops/gemv.py
index 1cb00b6..349da3e 100644
--- a/kernel_course/triton_ops/gemv.py
+++ b/kernel_course/triton_ops/gemv.py
@@ -17,15 +17,15 @@
 )
 @triton.jit
 def gemv_kernel(
-    A_ptr,
-    x_ptr,
-    y_ptr,
-    alpha,
-    beta,
+    A,
+    X,
+    Y,
     stride_am,
     stride_an,
     stride_x,
     stride_y,
+    alpha,
+    beta,
     n_elements_M,
     n_elements_N,
     BLOCK_M: tl.constexpr,
@@ -39,6 +39,10 @@ def gemv_kernel(
     # This program will process inputs that are offset from the initial pointer
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
+    # Initialize pointers to the start of the blocks
+    A_ptr = A + offs_m[:, None] * stride_am
+    x_ptr = X
+    y_ptr = Y + offs_m * stride_y
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask_m = offs_m < n_elements_M
     # Initialize the accumulator to zero for each row
@@ -54,12 +58,10 @@
         mask_n = offs_n < n_elements_N
         # Load a block of A and x from DRAM, masking out any extra elements in case the input is not a multiple of the block size
         if EVEN_N & EVEN_M:
-            a = tl.load(
-                A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
-            )
+            a = tl.load(A_ptr + offs_n[None, :] * stride_an)
         else:
             a = tl.load(
-                A_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an,
+                A_ptr + offs_n[None, :] * stride_an,
                 mask=mask_m[:, None] & mask_n[None, :],
                 other=0.0,
             )
@@ -71,16 +73,17 @@
         acc += tl.sum(a * x[None, :], axis=1)
     # Load y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
     if EVEN_M:
-        y = tl.load(y_ptr + offs_m * stride_y)
+        y = tl.load(y_ptr)
     else:
-        y = tl.load(y_ptr + offs_m * stride_y, mask=mask_m, other=0.0)
+        y = tl.load(y_ptr, mask=mask_m, other=0.0)
     # Compute y = alpha * A * x + beta * y
-    y_new = (alpha * acc + beta * y).to(y.dtype)
+    y_new = beta * y
+    y_new += alpha * acc
     # Write y back to DRAM
     if EVEN_M:
-        tl.store(y_ptr + offs_m * stride_y, y_new)
+        tl.store(y_ptr, y_new)
     else:
-        tl.store(y_ptr + offs_m * stride_y, y_new, mask=mask_m)
+        tl.store(y_ptr, y_new, mask=mask_m)


 def gemv(
@@ -117,12 +120,12 @@ def grid(meta):
         A,
         x,
         y,
-        alpha,
-        beta,
         A.stride(0),
         A.stride(1),
         x.stride(0),
         y.stride(0),
+        alpha,
+        beta,
         n_elements_M,
         n_elements_N,
     )
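The signature reorder above (strides now precede alpha and beta) is mirrored at the call site in the final hunk, so the launch arguments stay positionally correct. A hedged smoke test of the gemv wrapper against a PyTorch reference follows; the import path, the in-place update semantics, and the shapes and scalars are assumptions about this repo, not taken from it.

import torch
from kernel_course.triton_ops.gemv import gemv  # assumed module path

A = torch.randn(513, 257, device="cuda")  # odd shapes exercise the masked paths
x = torch.randn(257, device="cuda")
y = torch.randn(513, device="cuda")
# Reference: y_new = beta * y + alpha * (A @ x), computed before the in-place call
expected = torch.addmv(y, A, x, beta=0.5, alpha=2.0)
gemv(A, x, y, alpha=2.0, beta=0.5)  # assumed wrapper signature; writes into y
torch.testing.assert_close(y, expected, rtol=1e-4, atol=1e-4)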
diff --git a/kernel_course/triton_ops/scal.py b/kernel_course/triton_ops/scal.py
index 0ab4bda..4ef30f5 100644
--- a/kernel_course/triton_ops/scal.py
+++ b/kernel_course/triton_ops/scal.py
@@ -11,7 +11,7 @@ )
 @triton.jit
 def scal_kernel(
-    y_ptr,
+    Y,
     alpha,
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -24,14 +24,16 @@ def scal_kernel(
     # Note that offsets is a vector of element offsets
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Initialize pointers to the start of the blocks
+    y_ptr = Y + offsets
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
-    y = tl.load(y_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr, mask=mask)
     # Scale y by alpha
     y = y * alpha
     # Write y back to DRAM
-    tl.store(y_ptr + offsets, y, mask=mask)
+    tl.store(y_ptr, y, mask=mask)


 def scal(
diff --git a/kernel_course/triton_ops/swap.py b/kernel_course/triton_ops/swap.py
index 0cff6b9..e36176c 100644
--- a/kernel_course/triton_ops/swap.py
+++ b/kernel_course/triton_ops/swap.py
@@ -11,8 +11,8 @@
 )
 @triton.jit
 def swap_kernel(
-    x_ptr,
-    y_ptr,
+    X,
+    Y,
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -24,14 +24,17 @@ def swap_kernel(
     # Note that offsets is a vector of element offsets
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Initialize pointers to the start of the blocks
+    x_ptr = X + offsets
+    y_ptr = Y + offsets
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
     # Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size
-    x = tl.load(x_ptr + offsets, mask=mask)
-    y = tl.load(y_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr, mask=mask)
+    y = tl.load(y_ptr, mask=mask)
     # Write y to x and x to y in DRAM
-    tl.store(x_ptr + offsets, y, mask=mask)
-    tl.store(y_ptr + offsets, x, mask=mask)
+    tl.store(x_ptr, y, mask=mask)
+    tl.store(y_ptr, x, mask=mask)


 def swap(
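Taken together, the six files receive one mechanical refactor: rename the raw pointer arguments (X, Y, Z, A), materialize per-block pointers once, and split the fused multiply-add into two steps. A hedged end-to-end check for the in-place elementwise kernels follows; the import paths and wrapper signatures (copy(x, y) writing into y, swap(a, b) exchanging contents in place) are assumptions about this repo.

import torch
from kernel_course.triton_ops.copy import copy  # assumed module paths
from kernel_course.triton_ops.swap import swap

x = torch.randn(10_001, device="cuda")  # deliberately not a multiple of any block size
y = torch.empty_like(x)
copy(x, y)                              # y <- x (assumed argument order)
assert torch.equal(y, x)

a = torch.randn(5, device="cuda")
b = torch.randn(5, device="cuda")
a0, b0 = a.clone(), b.clone()
swap(a, b)                              # exchange contents of a and b in place
assert torch.equal(a, b0) and torch.equal(b, a0)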