From 458eaef479a7ba1f1035836307cdbe47d3e91850 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 20 Feb 2026 10:27:23 +0200 Subject: [PATCH 1/7] matmul: more concise knob names --- examples/xegpu_matmul/matmul.py | 28 ++++++++++++++-------------- examples/xegpu_matmul/schedule.py | 19 +++++++++---------- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py index 277f9c5..2e7da58 100644 --- a/examples/xegpu_matmul/matmul.py +++ b/examples/xegpu_matmul/matmul.py @@ -342,20 +342,20 @@ def parse_cli(): args = parse_cli() params = { - "auto_wg_d0": args.wg_tile[0], - "auto_wg_d1": args.wg_tile[1], - "auto_sg_d0": args.sg_tile[0], - "auto_sg_d1": args.sg_tile[1], - "auto_k": args.k_tile, - "auto_load_a_d0": args.load_tile_a[0], - "auto_load_a_d1": args.load_tile_a[1], - "auto_load_b_d0": args.load_tile_b[0], - "auto_load_b_d1": args.load_tile_b[1], - "auto_prefetch_a_d0": args.prefetch_tile_a[0], - "auto_prefetch_a_d1": args.prefetch_tile_a[1], - "auto_prefetch_b_d0": args.prefetch_tile_b[0], - "auto_prefetch_b_d1": args.prefetch_tile_b[1], - "auto_nb_prefetch": args.nb_prefetch, + "wg_m": args.wg_tile[0], + "wg_n": args.wg_tile[1], + "sg_m": args.sg_tile[0], + "sg_n": args.sg_tile[1], + "k": args.k_tile, + "load_a_m": args.load_tile_a[0], + "load_a_k": args.load_tile_a[1], + "load_b_k": args.load_tile_b[0], + "load_b_n": args.load_tile_b[1], + "pf_a_m": args.prefetch_tile_a[0], + "pf_a_k": args.prefetch_tile_a[1], + "pf_b_k": args.prefetch_tile_b[0], + "pf_b_n": args.prefetch_tile_b[1], + "pf_nb": args.nb_prefetch, } M, N, K = args.sizes diff --git a/examples/xegpu_matmul/schedule.py b/examples/xegpu_matmul/schedule.py index b5827be..1a8a7a6 100644 --- a/examples/xegpu_matmul/schedule.py +++ b/examples/xegpu_matmul/schedule.py @@ -101,16 +101,15 @@ def bundle_xepu_matmul_schedule( raise ValueError("Schedule parameters must be provided.") # tunable parameters - wg_tile = 
[params["auto_wg_d0"], params["auto_wg_d1"]] - sg_tile = [params["auto_sg_d0"], params["auto_sg_d1"]] - k_tile = params["auto_k"] - - load_tile_a = [params["auto_load_a_d0"], params["auto_load_a_d1"]] - load_tile_b = [params["auto_load_b_d0"], params["auto_load_b_d1"]] - - prefetch_tile_a = [params["auto_prefetch_a_d0"], params["auto_prefetch_a_d1"]] - prefetch_tile_b = [params["auto_prefetch_b_d0"], params["auto_prefetch_b_d1"]] - nb_prefetch = params["auto_nb_prefetch"] + wg_tile = [params["wg_m"], params["wg_n"]] + sg_tile = [params["sg_m"], params["sg_n"]] + k_tile = params["k"] + + load_tile_a = [params["load_a_m"], params["load_a_k"]] + load_tile_b = [params["load_b_k"], params["load_b_n"]] + prefetch_tile_a = [params["pf_a_m"], params["pf_a_k"]] + prefetch_tile_b = [params["pf_b_k"], params["pf_b_n"]] + nb_prefetch = params["pf_nb"] # derived parameters sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] From d12489153bea935148e509bbdfb1e4f87f8fd928 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 15 Jan 2026 21:12:21 +0200 Subject: [PATCH 2/7] add xegpu mlp example --- examples/xegpu_matmul/mlp.py | 585 ++++++++++++++++++++++++++++++ examples/xegpu_matmul/payload.py | 207 ++++++++++- examples/xegpu_matmul/schedule.py | 413 +++++++++++++++++++++ 3 files changed, 1190 insertions(+), 15 deletions(-) create mode 100644 examples/xegpu_matmul/mlp.py diff --git a/examples/xegpu_matmul/mlp.py b/examples/xegpu_matmul/mlp.py new file mode 100644 index 0000000..71907b6 --- /dev/null +++ b/examples/xegpu_matmul/mlp.py @@ -0,0 +1,585 @@ +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + +""" +XeGPU matrix multiplication benchmark. 
+""" + +import argparse +import ctypes +from typing import Optional +from contextlib import contextmanager +from functools import cached_property + +import numpy as np +from mlir import ir +from mlir.runtime.np_to_memref import ( + get_ranked_memref_descriptor, + make_nd_memref_descriptor, + as_ctype, +) +from mlir.execution_engine import ExecutionEngine + +from lighthouse.workload import Workload, benchmark +from lighthouse.utils.memref import get_packed_arg, to_ctype as memref_to_ctype + +# Import from sibling files: +from schedule import get_schedule_module_mlp +from payload import generate_mlp_payload + + +def numpy_to_ctype(arr: np.ndarray) -> ctypes._Pointer: + """Convert numpy array to memref and ctypes **void pointer.""" + return memref_to_ctype(get_ranked_memref_descriptor(arr)) + + +class XeGPUMLP(Workload): + """ + Multi-layer perceptron (MLP) workload on XeGPU. + + Optionally adds a ReLU operation after each layer. + Optionally adds a bias term in each layer (not implemented yet). 
+ """ + + payload_function_name = "payload" + + def __init__( + self, + batch_size: int, + input_size: int, + output_size: int, + hidden_layer_sizes: Optional[list[int]] = None, + ab_type: str = "f16", + c_type: str = "f32", + has_bias: bool = False, + has_relu: bool = False, + ): + self.batch_size = batch_size + self.input_size = input_size + self.output_size = output_size + self.hidden_layer_sizes = hidden_layer_sizes or [] + self.input_shape = (self.batch_size, self.input_size) + self.output_shape = (self.batch_size, self.output_size) + layer_sizes = [self.input_size] + self.hidden_layer_sizes + [self.output_size] + self.weight_shapes = list(zip(layer_sizes[:-1], layer_sizes[1:])) + self.matmul_layers = [(self.batch_size, o, i) for i, o in self.weight_shapes] + + assert ab_type == "f16", "Only f16 type is supported for A and B" + assert c_type == "f32", "Only f32 type is supported for C" + self.ab_type = ab_type + self.c_type = c_type + type_str_to_numpy = { + "f16": np.float16, + "f32": np.float32, + } + self.ab_dtype = type_str_to_numpy[ab_type] + self.c_dtype = type_str_to_numpy[c_type] + self.has_bias = has_bias + self.has_relu = has_relu + if has_bias: + raise NotImplementedError("Bias is not implemented yet") + # cache allocated memrefs + self.gpu_memrefs = {} + + def _allocate_array( + self, + name: str, + shape: tuple[int, ...], + dtype_str: str, + execution_engine: ExecutionEngine, + ) -> ctypes.Structure: + key = (name, dtype_str) + if key in self.gpu_memrefs: + return self.gpu_memrefs[key] + dtype = { + "f16": np.float16, + "f32": np.float32, + }[dtype_str] + alloc_func = execution_engine.lookup("gpu_alloc_" + dtype_str) + mref = make_nd_memref_descriptor(len(shape), as_ctype(dtype))() + ptr_mref = ctypes.pointer(ctypes.pointer(mref)) + ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape] + alloc_func(get_packed_arg([ptr_mref] + ptr_dims)) + self.gpu_memrefs[key] = mref + return mref + + def _deallocate_all(self, execution_engine: 
ExecutionEngine): + for (_, dtype_str), mref in self.gpu_memrefs.items(): + dealloc_func = execution_engine.lookup("gpu_dealloc_" + dtype_str) + ptr_mref = ctypes.pointer(ctypes.pointer(mref)) + dealloc_func(get_packed_arg([ptr_mref])) + self.gpu_memrefs = {} + + @contextmanager + def allocate_inputs(self, execution_engine: ExecutionEngine): + try: + yield self._get_input_arrays(execution_engine) + finally: + self._deallocate_all(execution_engine) + + @cached_property + def _initial_host_arrays(self) -> list[np.ndarray]: + """Generate initial values on host with numpy.""" + + # use integer values to avoid f16/f32 floating point discrepancies + def gen_random(shape, dtype): + # generate random {-1, 1} values + a = np.round(np.random.random_sample(shape)) + a[a == 0] = -1 + return a.astype(dtype) + + np.random.seed(2) + input_array = gen_random(self.input_shape, self.ab_dtype) + output_array = np.zeros(self.output_shape, self.ab_dtype) + weights = [] + for i, o in self.weight_shapes: + W = gen_random((i, o), self.ab_dtype) + weights.append(W) + + if self.has_bias: + raise NotImplementedError("Bias initialization not implemented") + + return input_array, output_array, *weights + + @cached_property + def _reference_solution(self) -> np.ndarray: + """Compute reference solution on host with numpy.""" + # NOTE for large problems the solution can overflow float16 range + host_arrays = self._initial_host_arrays + # use float32 data type for efficiency + host_arrays = [arr.astype(np.float32) for arr in host_arrays] + input_array = host_arrays[0] + output_array = host_arrays[1] + weights = host_arrays[2:] + + a_array = input_array + for W in weights: + C_ref = a_array @ W + if self.has_relu: + C_ref = np.maximum(C_ref, 0) + if self.has_bias: + raise NotImplementedError("Bias verification not implemented") + a_array = C_ref.astype(self.ab_dtype).astype(np.float32) + + C_ref += output_array + return C_ref.astype(self.ab_dtype) + + def _get_input_arrays( + self, 
execution_engine: ExecutionEngine + ) -> list[ctypes.Structure]: + if self.has_bias: + raise NotImplementedError("Bias allocation not implemented yet") + + # allocate arrays on device + input_gpu = self._allocate_array( + "input", self.input_shape, self.ab_type, execution_engine + ) + output_gpu = self._allocate_array( + "output", self.output_shape, self.ab_type, execution_engine + ) + gpu_arrays = [input_gpu, output_gpu] + for i, (in_size, out_size) in enumerate(self.weight_shapes): + W_gpu = self._allocate_array( + f"weight_{i}", (in_size, out_size), self.ab_type, execution_engine + ) + gpu_arrays.append(W_gpu) + + # get initial host arrays + host_arrays = self._initial_host_arrays + # copy initial values to device + copy_func_ab = execution_engine.lookup("gpu_copy_" + self.ab_type) + for host_arr, gpu_arr in zip(host_arrays, gpu_arrays): + copy_func_ab( + get_packed_arg([numpy_to_ctype(host_arr), memref_to_ctype(gpu_arr)]) + ) + + # return memrefs for the payload function + return gpu_arrays + + def check_correctness( + self, execution_engine: ExecutionEngine, verbose: int = 0 + ) -> bool: + # copy result from device to host + res_gpu = self.gpu_memrefs[("output", self.ab_type)] + res_host_copy = np.zeros(self.output_shape, dtype=self.ab_dtype) + copy_func = execution_engine.lookup("gpu_copy_" + self.ab_type) + copy_func( + get_packed_arg([memref_to_ctype(res_gpu), numpy_to_ctype(res_host_copy)]) + ) + + res_host_ref = self._reference_solution + res_host = res_host_copy + if verbose > 1: + print("Reference solution:") + print(res_host_ref) + print("Computed solution:") + print(res_host) + success = np.allclose(res_host, res_host_ref) + + if verbose: + if success: + print("PASSED") + else: + print("FAILED Result mismatch!") + print(f"Max absolute error: {np.max(np.abs(res_host - res_host_ref))}") + num_diff = np.sum(np.abs(res_host - res_host_ref) > 1e-3) + print(f"Number of differing elements: {num_diff}") + return success + + def get_complexity(self) -> 
tuple[int, int, int]: + nbytes_ab = np.dtype(self.ab_dtype).itemsize + nbytes_c = np.dtype(self.c_dtype).itemsize + + def matmul_complexity(M, N, K, has_bias, has_relu): + flop_count = 2 * M * N * K + memory_reads = (M * K + K * N) * nbytes_ab # read A and B + memory_writes = M * N * nbytes_c # write C + if has_bias: + flop_count += M * N + memory_reads += N * nbytes_c # read bias + if has_relu: + flop_count += M * N + return flop_count, memory_reads, memory_writes + + flop_count = 0 + memory_reads = 0 + memory_writes = 0 + for M, N, K in self.matmul_layers: + f, r, w = matmul_complexity(M, N, K, self.has_bias, self.has_relu) + flop_count += f + memory_reads += r + memory_writes += w + return (flop_count, memory_reads, memory_writes) + + def payload_module(self) -> ir.Module: + mod = generate_mlp_payload( + func_name=self.payload_function_name, + batch_size=self.batch_size, + input_size=self.input_size, + output_size=self.output_size, + hidden_layer_sizes=self.hidden_layer_sizes, + ab_type_str=self.ab_type, + c_type_str=self.c_type, + has_bias=self.has_bias, + has_relu=self.has_relu, + ) + return mod + + def schedule_module( + self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None + ) -> ir.Module: + return get_schedule_module_mlp( + has_bias=self.has_bias, + has_relu=self.has_relu, + stop_at_stage=stop_at_stage, + nlayers=len(self.matmul_layers), + params=parameters, + ) + + def shared_libs(self) -> list[str]: + return ["libmlir_levelzero_runtime.so"] + + +matmul_param_db = { + (4096, 4096, 4096): { + "wg_m": 256, + "wg_n": 256, + "sg_m": 32, + "sg_n": 32, + "k": 64, + "load_a_m": 8, + "load_a_k": 16, + "load_b_k": 16, + "load_b_n": 16, + "pf_a_m": 8, + "pf_a_k": 32, + "pf_b_k": 8, + "pf_b_n": 32, + "pf_nb": 1, + }, + (128, 16384, 16384): { + "wg_m": 128, + "wg_n": 256, + "sg_m": 32, + "sg_n": 32, + "k": 256, + "load_a_m": 8, + "load_a_k": 16, + "load_b_k": 32, + "load_b_n": 16, + "pf_a_m": 8, + "pf_a_k": 16, + "pf_b_k": 8, + "pf_b_n": 16, + 
"pf_nb": 1, + }, + (128, 8192, 16384): { + "wg_m": 64, + "wg_n": 128, + "sg_m": 32, + "sg_n": 32, + "k": 128, + "load_a_m": 16, + "load_a_k": 16, + "load_b_k": 16, + "load_b_n": 16, + "pf_a_m": 32, + "pf_a_k": 16, + "pf_b_k": 16, + "pf_b_n": 32, + "pf_nb": 1, + }, + (128, 32768, 16384): { + "wg_m": 128, + "wg_n": 128, + "sg_m": 32, + "sg_n": 32, + "k": 256, + "load_a_m": 8, + "load_a_k": 16, + "load_b_k": 16, + "load_b_n": 16, + "pf_a_m": 16, + "pf_a_k": 32, + "pf_b_k": 8, + "pf_b_n": 32, + "pf_nb": 1, + }, + (128, 16384, 32768): { + "wg_m": 128, + "wg_n": 128, + "sg_m": 32, + "sg_n": 32, + "k": 256, + "load_a_m": 8, + "load_a_k": 16, + "load_b_k": 16, + "load_b_n": 16, + "pf_a_m": 32, + "pf_a_k": 32, + "pf_b_k": 8, + "pf_b_n": 16, + "pf_nb": 1, + }, + (128, 32768, 32768): { + "wg_m": 128, + "wg_n": 256, + "sg_m": 32, + "sg_n": 32, + "k": 256, + "load_a_m": 8, + "load_a_k": 16, + "load_b_k": 16, + "load_b_n": 16, + "pf_a_m": 16, + "pf_a_k": 32, + "pf_b_k": 32, + "pf_b_n": 32, + "pf_nb": 1, + }, + (1024, 1024, 8192): { + "wg_m": 256, + "wg_n": 128, + "sg_m": 32, + "sg_n": 32, + "k": 32, + "load_a_m": 8, + "load_a_k": 16, + "load_b_k": 32, + "load_b_n": 16, + "pf_a_m": 8, + "pf_a_k": 16, + "pf_b_k": 8, + "pf_b_n": 16, + "pf_nb": 1, + }, + (1024, 8192, 1024): { + "wg_m": 256, + "wg_n": 128, + "sg_m": 32, + "sg_n": 32, + "k": 32, + "load_a_m": 16, + "load_a_k": 16, + "load_b_k": 32, + "load_b_n": 16, + "pf_a_m": 8, + "pf_a_k": 16, + "pf_b_k": 16, + "pf_b_n": 16, + "pf_nb": 1, + }, + (1024, 1024, 1024): { + "wg_m": 128, + "wg_n": 64, + "sg_m": 32, + "sg_n": 32, + "k": 32, + "load_a_m": 16, + "load_a_k": 16, + "load_b_k": 32, + "load_b_n": 16, + "pf_a_m": 8, + "pf_a_k": 32, + "pf_b_k": 8, + "pf_b_n": 16, + "pf_nb": 1, + }, +} + + +class ParameterOracleMLP: + def __init__(self, workload: XeGPUMLP): + self.param_db = matmul_param_db + self.workload = workload + + def get_parameters(self) -> dict[str, dict]: + parameters = {} + for i, shape in 
enumerate(self.workload.matmul_layers): + if shape in self.param_db: + params = self.param_db[shape] + else: + raise ValueError(f"No parameters found for matmul shape {shape}") + parameters[f"layer_{i}"] = params + return parameters + + +def parse_cli(): + parser = argparse.ArgumentParser( + description="Matrix Multiplication using MLIR", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-b", + "--batch-size", + type=int, + default=1024, + help="Batch size M. Input matrix has shape (M x K).", + ) + parser.add_argument( + "-i", + "--input-size", + type=int, + default=1024, + help="Number of input features K. Input matrix has shape (M x K).", + ) + parser.add_argument( + "-o", + "--output-size", + type=int, + default=1024, + help="Number of output features N. Output matrix has shape (M x N).", + ) + parser.add_argument( + "--hidden-sizes", + type=int, + nargs="+", + help="Number of features in each hidden layers.", + ) + parser.add_argument( + "--nruns", + type=int, + default=1000, + help="Number of runs to average the execution time.", + ) + parser.add_argument( + "--nwarmup", + type=int, + default=20, + help="Number of warm-up iterations before benchmarking.", + ) + parser.add_argument( + "--relu", + action="store_true", + help="Add relu op after the matrix multiplication (and bias if any).", + ) + parser.add_argument( + "--check-result", + action="store_true", + help="Check the result of the matrix multiplication.", + ) + parser.add_argument( + "--dump-kernel", + type=str, + choices=[ + "initial", + "tiled", + "vectorized", + "bufferized", + "xegpu-initial", + "xegpu-wg", + "xegpu-sg", + "xegpu-inst", + "final", + ], + help="Dump kernel IR at different stages of lowering.", + ) + parser.add_argument( + "--dump-schedule", + action="store_true", + help="Dump transform schedule.", + ) + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_cli() + + ab_type = "f16" + c_type = "f32" + + with 
ir.Context(), ir.Location.unknown(): + wload = XeGPUMLP( + batch_size=args.batch_size, + input_size=args.input_size, + output_size=args.output_size, + hidden_layer_sizes=args.hidden_sizes, + ab_type=ab_type, + c_type=c_type, + has_bias=False, + has_relu=args.relu, + ) + matmuls = wload.matmul_layers + print(f"MLP with {len(matmuls)} layers") + for i, (M, N, K) in enumerate(matmuls): + print(f" Layer {i}: M={M}, N={N}, K={K}") + + param_oracle = ParameterOracleMLP(wload) + params = param_oracle.get_parameters() + + if args.dump_kernel or args.dump_schedule: + wload.lower_payload( + dump_payload=args.dump_kernel, + dump_schedule=args.dump_schedule, + schedule_parameters=params, + ) + else: + times = benchmark( + wload, + nruns=args.nruns, + nwarmup=args.nwarmup, + schedule_parameters=params, + check_correctness=args.check_result, + verbose=2, + ) + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + def list2str(a): + return ",".join(map(str, a)) + + hidden_sizes = args.hidden_sizes if args.hidden_sizes else [] + parts = [ + f"b={args.batch_size}", + f"i={args.input_size}", + f"o={args.output_size}", + f"hs={list2str(hidden_sizes)}", + f"dt={ab_type},{c_type}", + f"time(us): {elapsed:.2f}", + f"GFLOPS: {gflops:.2f}", + ] + print(" ".join(parts)) diff --git a/examples/xegpu_matmul/payload.py b/examples/xegpu_matmul/payload.py index 0cf3a45..2e6615c 100644 --- a/examples/xegpu_matmul/payload.py +++ b/examples/xegpu_matmul/payload.py @@ -52,6 +52,64 @@ def emit_gpu_util_funcs(element_type: ir.Type): emit_gpu_copy(suffix, element_type) +def emit_mlp_layer( + a_tensor, + b_tensor, + c_tensor, + ab_type, + c_type, + bias_tensor=None, + has_relu=False, + convert_c_type=False, +) -> ir.Value: + M, N = c_tensor.type.shape + id_map = ir.AffineMap.get_identity(2) + par_iter = linalg.IteratorType.parallel + if convert_c_type: + empty = tensor.empty((M, N), c_type) + + 
@linalg.generic( + [c_tensor], + [empty], + [id_map, id_map], + [par_iter, par_iter], + ) + def f(a, b): + return arith.extf(c_type, a) + + input_c_tensor = f + else: + input_c_tensor = c_tensor + mmul = linalg.matmul(a_tensor, b_tensor, outs=[input_c_tensor]) + terminal = mmul + res_type = c_type + if convert_c_type: + res_type = ab_type + empty = tensor.empty((M, N), ab_type) + + @linalg.generic( + [terminal], + [empty], + [id_map, id_map], + [par_iter, par_iter], + ) + def f(a, b): + return arith.truncf(ab_type, a) + + terminal = f + if bias_tensor is not None: + empty = tensor.empty((M, N), res_type) + bcast = linalg.broadcast(bias_tensor, outs=[empty], dimensions=[0]) + terminal = linalg.add(bcast, terminal, outs=[empty]) + if has_relu: + zero = arith.constant(ab_type if convert_c_type else c_type, 0.0) + empty = tensor.empty((M, N), res_type) + zero_tensor = linalg.fill(zero, outs=[empty]) + terminal = linalg.max(terminal, zero_tensor, outs=[empty]) + + return terminal + + def generate_matmul_payload( func_name: str, M: int, @@ -72,7 +130,6 @@ def generate_matmul_payload( tensor_a_t = ir.RankedTensorType.get((M, K), ab_type) tensor_b_t = ir.RankedTensorType.get((K, N), ab_type) tensor_c_t = ir.RankedTensorType.get((M, N), c_type) - tensor_bias_t = ir.RankedTensorType.get((N,), c_type) memref_a_t = ir.MemRefType.get((M, K), ab_type) memref_b_t = ir.MemRefType.get((K, N), ab_type) memref_c_t = ir.MemRefType.get((M, N), c_type) @@ -89,30 +146,150 @@ def payload(*args): A = args[0] B = args[1] C = args[-1] + bias = args[2] if has_bias else None a_tensor = bufferization.to_tensor(tensor_a_t, A, restrict=True) b_tensor = bufferization.to_tensor(tensor_b_t, B, restrict=True) c_tensor = bufferization.to_tensor( tensor_c_t, C, restrict=True, writable=True ) - - mmul = linalg.matmul(a_tensor, b_tensor, outs=[c_tensor]) - terminal = mmul if has_bias: - bias = args[2] bias_tensor = bufferization.to_tensor( - tensor_bias_t, bias, restrict=True, writable=True + 
ir.RankedTensorType.get((N,), c_type), bias, restrict=True + ) + else: + bias_tensor = None + + output = emit_mlp_layer( + a_tensor, + b_tensor, + c_tensor, + ab_type, + c_type, + bias_tensor, + has_relu, + convert_c_type=False, + ) + bufferization.materialize_in_destination( + None, output, C, restrict=True, writable=True + ) + + payload.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + emit_gpu_util_funcs(ab_type) + if c_type != ab_type: + emit_gpu_util_funcs(c_type) + + return mod + + +def emit_buf_to_tensor(memref_value: ir.Value, **kwargs) -> ir.Value: + memref_type = memref_value.type + shape = memref_type.shape + element_type = memref_type.element_type + tensor_type = ir.RankedTensorType.get(shape, element_type) + return bufferization.to_tensor(tensor_type, memref_value, **kwargs) + + +def generate_mlp_payload( + func_name: str, + batch_size: int, + input_size: int, + output_size: int, + hidden_layer_sizes: list[int], + ab_type_str: str, + c_type_str: str, + has_bias: bool, + has_relu: bool, +) -> ir.Module: + """Generate payload function module.""" + get_ir_dtype = { + "f16": ir.F16Type.get(), + "f32": ir.F32Type.get(), + } + ab_type = get_ir_dtype[ab_type_str] + c_type = get_ir_dtype[c_type_str] + mod = ir.Module.create() + memref_in_t = ir.MemRefType.get((batch_size, input_size), ab_type) + memref_out_t = ir.MemRefType.get((batch_size, output_size), ab_type) + layer_sizes = [input_size] + hidden_layer_sizes + [output_size] + feature_sizes = list(zip(layer_sizes[:-1], layer_sizes[1:])) + weight_memref_types = [] + bias_memref_types = [] + for in_size, out_size in feature_sizes: + memref_t = ir.MemRefType.get((in_size, out_size), ab_type) + weight_memref_types.append(memref_t) + if has_bias: + memref_t = ir.MemRefType.get((out_size,), c_type) + bias_memref_types.append(memref_t) + with ir.InsertionPoint(mod.body): + # function argument order: + # input, output, weights_0, weights_1, ..., [bias_0, bias_1, ...] 
+ fargs = [memref_in_t, memref_out_t] + weight_memref_types + if has_bias: + fargs += bias_memref_types + + @func.func(*fargs, name=func_name) + def payload(*args): + input = args[0] + output = args[1] + nlayers = len(hidden_layer_sizes) + 1 + weights = args[2 : 2 + nlayers] + biases = args[2 + nlayers :] if has_bias else [None] * nlayers + input_tensor = emit_buf_to_tensor(input, restrict=True) + output_tensor = emit_buf_to_tensor(output, restrict=True) + weight_tensors = [] + for weight_memref in weights: + weight_tensor = emit_buf_to_tensor(weight_memref, restrict=True) + weight_tensors.append(weight_tensor) + bias_tensors = [] + for bias_memref in biases: + if has_bias: + bias_tensor = emit_buf_to_tensor(bias_memref, restrict=True) + else: + bias_tensor = None + bias_tensors.append(bias_tensor) + + layer_output = input_tensor + to_dealloc = None + for i, (weight, bias) in enumerate(zip(weight_tensors, bias_tensors)): + a_tensor = layer_output + b_tensor = weight + M, K = a_tensor.type.shape + _, N = b_tensor.type.shape + if i == nlayers - 1: + c_tensor = output_tensor + else: + # allocate intermediate buffer + memref_type = ir.MemRefType.get((M, N), ab_type) + c_memref = gpu.alloc(memref_type, None, [], [], []) + gpu.memset(None, [], c_memref, arith.constant(ab_type, 0.0)) + c_tensor = emit_buf_to_tensor( + c_memref, restrict=True, writable=True + ) + bias_tensor = bias + layer_output = emit_mlp_layer( + a_tensor, + b_tensor, + c_tensor, + ab_type, + c_type, + bias_tensor, + has_relu, + convert_c_type=True, ) - empty = tensor.empty((M, N), c_type) - bcast = linalg.broadcast(bias_tensor, outs=[empty], dimensions=[0]) - terminal = linalg.add(bcast, terminal, outs=[empty]) - if has_relu: - zero = arith.constant(c_type, 0.0) - empty = tensor.empty((M, N), c_type) - zero_tensor = linalg.fill(zero, outs=[empty]) - terminal = linalg.max(terminal, zero_tensor, outs=[empty]) + if to_dealloc is not None: + gpu.dealloc(None, [], to_dealloc) + to_dealloc = None + if i != 
nlayers - 1: + bufferization.materialize_in_destination( + None, layer_output, c_memref, restrict=True, writable=True + ) + # deallocate after next layer + to_dealloc = c_memref + + # finalize bufferization.materialize_in_destination( - None, terminal, C, restrict=True, writable=True + None, layer_output, output, restrict=True, writable=True ) payload.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() diff --git a/examples/xegpu_matmul/schedule.py b/examples/xegpu_matmul/schedule.py index 1a8a7a6..2ec1145 100644 --- a/examples/xegpu_matmul/schedule.py +++ b/examples/xegpu_matmul/schedule.py @@ -19,12 +19,425 @@ class PipelineInterrupt(Exception): pass +def match_and_split(*args, nhandles=1, **kwargs): + """Helper function that splits matched handles.""" + matched = match(*args, **kwargs) + anytype = transform.AnyOpType.get() + matched_ops = transform.split_handle((anytype,) * nhandles, matched) + if nhandles == 1: + matched_ops = [matched_ops] + return matched_ops + + # hardware constraints dpas_tile = [8, 16, 16] prefetch_inst_data = [8, 16] nb_workitems = 16 # workitems in subgroup +def get_schedule_module_mlp( + has_bias: bool = False, + has_relu: bool = False, + stop_at_stage: str = "", + nlayers: int = 1, + params: Optional[dict] = None, +) -> ir.Module: + """Generate transform schedule module.""" + mod = ir.Module.create() + mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() + with ir.InsertionPoint(mod.body): + named_sequence = transform.named_sequence( + "__transform_main", + [transform.AnyOpType.get()], # input types + [], # output types + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + ) + with ir.InsertionPoint(named_sequence.body): + # match the payload module + anytype = transform.AnyOpType.get() + func = match(named_sequence.bodyTarget, ops={"func.func"}) + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + xegpu_mlp_transform_schedule( 
payload_mod, + has_bias=has_bias, + has_relu=has_relu, + stop_at_stage=stop_at_stage, + nlayers=nlayers, + params=params, + ) + + return mod + + +def xegpu_mlp_transform_schedule( + mod: ir.Value, + has_bias: bool = False, + has_relu: bool = False, + has_convert_c: bool = True, + stop_at_stage: str = "", + nlayers: int = 1, + params: Optional[list[dict]] = None, +): + """Transform schedule for matmul-like payload.""" + try: + mod = bundle_xepu_mlp_schedule( + mod, + has_bias=has_bias, + has_relu=has_relu, + has_convert_c=has_convert_c, + stop_at_stage=stop_at_stage, + nlayers=nlayers, + params=params, + ) + + mod = bundle_xegpu_to_binary( + mod, + stop_at_stage=stop_at_stage, + ) + except PipelineInterrupt: + pass + finally: + transform.yield_() + + +def bundle_xepu_mlp_schedule( + mod: ir.Value, + has_bias: bool = False, + has_relu: bool = False, + has_convert_c: bool = True, + stop_at_stage: str = "", + nlayers: int = 1, + params: Optional[list[dict]] = None, +) -> ir.Module: + """Schedule for lowering matmul-like payload to xegpu wg level.""" + if params is None: + raise ValueError("Schedule parameters must be provided.") + + if stop_at_stage == "initial": + raise PipelineInterrupt() + + anytype = transform.AnyOpType.get() + anyvalue = transform.AnyValueType.get() + + for i in range(nlayers): + assert f"layer_{i}" in params, f"Missing parameters for 'layer_{i}'" + + dpas_shape_a = [dpas_tile[0], dpas_tile[2]] + dpas_shape_b = [dpas_tile[2], dpas_tile[1]] + dpas_shape_c = [dpas_tile[0], dpas_tile[1]] + + # wg tiling + + if has_relu: + terminal_ops = match_and_split(mod, ops={"linalg.max"}, nhandles=nlayers) + elif has_convert_c: + trunc_op = match(mod, ops={"arith.truncf"}) + terminal = transform.get_parent_op(anytype, trunc_op) + # split handle for each layer + terminal_ops = transform.split_handle((anytype,) * nlayers, terminal) + if nlayers == 1: + terminal_ops = [terminal_ops] + elif has_bias: + terminal_ops = match_and_split(mod, ops={"linalg.add"}, 
nhandles=nlayers) + else: + terminal_ops = match_and_split(mod, ops={"linalg.matmul"}, nhandles=nlayers) + + # tile each layer separately + for i_layer in range(nlayers): + layer_params = params[f"layer_{i_layer}"] + # tunable parameters: wg level tiling + wg_tile = [layer_params["wg_m"], layer_params["wg_n"]] + sg_tile = [layer_params["sg_m"], layer_params["sg_n"]] + k_tile = layer_params["k"] + + terminal = terminal_ops[i_layer] + # FIXME use structured.structured_fuse + _, wg_loop = structured.FuseOp( + terminal, tile_sizes=wg_tile, use_forall=True + ).results + transform.apply_cse(mod) + canonicalize(mod) + + # k loop tiling + wg_matmul = match(wg_loop, ops={"linalg.matmul"}) + # FIXME use structured.structured_tile_using_for + wgk_matmul, k_loop = structured.TileUsingForOp( + wg_matmul, sizes=[0, 0, k_tile] + ).results + + func = transform.get_parent_op( + anytype, + k_loop, + op_name="func.func", + deduplicate=True, + ) + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "tiled": + raise PipelineInterrupt() + + # vectorize + # FIXME use structured.structured_vectorize_children_and_apply_patterns + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + + # hoist loop invariant vector read/store ops + k_loop = match(func, ops={"scf.for"}) + loop.HoistLoopInvariantSubsetsOp(k_loop) + + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "vectorized": + raise PipelineInterrupt() + + # bufferize + + # eliminate empty tensors to avoid emitting extra copy ops + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + identity_layout = LayoutMapOption.IdentityLayoutMap + mod = bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + # fold memref.subviews into vector.transfer_read/write ops + mod = apply_registered_pass(mod, 
"fold-memref-alias-ops") + transform.apply_cse(mod) + canonicalize(mod) + + if stop_at_stage == "bufferized": + raise PipelineInterrupt() + + # convert forall to parallel + wg_loops = match_and_split(mod, ops={"scf.forall"}, nhandles=nlayers) + for wg_loop in wg_loops: + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + + # convert to scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + + # set correct number of gpu threads + launch_ops = match_and_split(mod, ops={"gpu.launch"}, nhandles=nlayers) + for i_layer, launch_op in enumerate(launch_ops): + layer_params = params[f"layer_{i_layer}"] + # tunable parameters + wg_tile = [layer_params["wg_m"], layer_params["wg_n"]] + sg_tile = [layer_params["sg_m"], layer_params["sg_n"]] + + # derived parameters + sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] + # number of threads collapsed to 1d layout + nb_threads = sg_layout[0] * sg_layout[1] * nb_workitems + + xegpu.set_gpu_launch_threads(launch_op, threads=[nb_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + transform.apply_cse(mod) + + # set xevm target + mod = apply_registered_pass( + mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # convert vector to xegpu + gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}, nhandles=nlayers) + for gpu_mod in gpu_mod_ops: + gpu_func = match(gpu_mod, ops={"gpu.func"}) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + + if stop_at_stage == "xegpu-initial": + raise 
PipelineInterrupt() + + for i_layer, gpu_mod in enumerate(gpu_mod_ops): + gpu_func = match(gpu_mod, ops={"gpu.func"}) + + # tunable parameters: xegpu layout + layer_params = params[f"layer_{i_layer}"] + + wg_tile = [layer_params["wg_m"], layer_params["wg_n"]] + sg_tile = [layer_params["sg_m"], layer_params["sg_n"]] + k_tile = layer_params["k"] + + load_tile_a = [layer_params["load_a_m"], layer_params["load_a_k"]] + load_tile_b = [layer_params["load_b_k"], layer_params["load_b_n"]] + prefetch_tile_a = [layer_params["pf_a_m"], layer_params["pf_a_k"]] + prefetch_tile_b = [layer_params["pf_b_k"], layer_params["pf_b_n"]] + nb_prefetch = layer_params["pf_nb"] + + prefetch_layout_a = [ + wg_tile[0] // prefetch_tile_a[0], + k_tile // prefetch_tile_a[1], + ] + prefetch_layout_b = [ + k_tile // prefetch_tile_b[0], + wg_tile[1] // prefetch_tile_b[1], + ] + + # matmul matrix shapes + sg_tile_a = [sg_tile[0], k_tile] + sg_tile_b = [k_tile, sg_tile[1]] + + # add layouts to DPAS op operands + k_loop = match(gpu_func, ops={"scf.for"}) + dpas_op = match(k_loop, ops={"xegpu.dpas"}) + tile_a = transform.get_operand(anyvalue, dpas_op, [0]) + tile_b = transform.get_operand(anyvalue, dpas_op, [1]) + tile_c = transform.get_operand(anyvalue, dpas_op, [2]) + + def convert_layout(value, input, target): + xegpu.convert_layout( + value, + input_sg_layout=input["sg_layout"], + input_sg_data=input["sg_data"], + input_inst_data=input["inst_data"], + target_sg_layout=target["sg_layout"], + target_sg_data=target["sg_data"], + target_inst_data=target["inst_data"], + ) + + # insert prefetch ops for DPAS A and B tiles + desc_prefetch_a = xegpu.insert_prefetch( + tile_a, + nb_prefetch=nb_prefetch, + ) + xegpu.set_desc_layout( + desc_prefetch_a, + sg_layout=prefetch_layout_a, + sg_data=prefetch_tile_a, + inst_data=prefetch_inst_data, + ) + desc_prefetch_b = xegpu.insert_prefetch( + tile_b, + nb_prefetch=nb_prefetch, + ) + xegpu.set_desc_layout( + desc_prefetch_b, + sg_layout=prefetch_layout_b, + 
sg_data=prefetch_tile_b, + inst_data=prefetch_inst_data, + ) + + # A tile load layout + layout_load_a = { + "sg_layout": sg_layout, + "sg_data": sg_tile_a, + "inst_data": load_tile_a, + } + desc_op_a = xegpu.get_desc_op(tile_a) + desc_op_a = xegpu.set_desc_layout( + target=desc_op_a, + **layout_load_a, + ) + # A tile dpas layout + layout_dpas_a = layout_load_a.copy() + layout_dpas_a["inst_data"] = dpas_shape_a + convert_layout(tile_a, layout_load_a, layout_dpas_a) + + # B tile load layout + layout_load_b = { + "sg_layout": sg_layout, + "sg_data": sg_tile_b, + "inst_data": load_tile_b, + } + desc_op_b = xegpu.get_desc_op(tile_b) + desc_op_b = xegpu.set_desc_layout( + target=desc_op_b, + **layout_load_b, + ) + # B tile dpas layout + layout_dpas_b = layout_load_b.copy() + layout_dpas_b["inst_data"] = dpas_shape_b + convert_layout(tile_b, layout_load_b, layout_dpas_b) + + # C tile layout + output_layout = { + "sg_layout": sg_layout, + "sg_data": sg_tile, + "inst_data": dpas_shape_c, + } + desc_op_c = xegpu.get_desc_op(tile_c) + desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout) + # C tile dpas layout + xegpu.set_op_layout_attr(dpas_op, result=True, index=0, **output_layout) + + if has_relu: + # for post ops we need to add C layout manually + max_op = match(gpu_func, ops={"arith.maximumf"}) + xegpu.set_op_layout_attr(max_op, result=True, index=0, **output_layout) + # find zero constant buffer and annotate it + const_buffer = transform.get_producer_of_operand(anytype, max_op, 1) + xegpu.set_op_layout_attr( + const_buffer, result=True, index=0, **output_layout + ) + if has_bias: + # for post ops we need to add C layout manually + add_op = match(gpu_func, ops={"arith.addf"}) + xegpu.set_op_layout_attr(add_op, result=True, index=0, **output_layout) + + # annotate broadcast op operands + bcast_op = transform.get_producer_of_operand(anytype, add_op, 0) + xegpu.set_op_layout_attr(bcast_op, result=True, index=0, **output_layout) + bcast_load = 
transform.get_producer_of_operand(anytype, bcast_op, 0) + xegpu.set_op_layout_attr( + bcast_load, result=True, index=0, **output_layout, slice_dims=[0] + ) + output_layout_dim1 = { + "sg_layout": [sg_layout[1]], + "sg_data": [sg_tile[1]], + "inst_data": [dpas_shape_c[1]], + } + offset = transform.get_producer_of_operand(anytype, bcast_load, 1) + xegpu.set_op_layout_attr(offset, result=True, index=0, **output_layout_dim1) + aux1 = transform.get_producer_of_operand(anytype, offset, 0) + xegpu.set_op_layout_attr(aux1, result=True, index=0, **output_layout_dim1) + aux2 = transform.get_producer_of_operand(anytype, offset, 1) + xegpu.set_op_layout_attr(aux2, result=True, index=0, **output_layout_dim1) + mask = transform.get_producer_of_operand(anytype, bcast_load, 2) + xegpu.set_op_layout_attr(mask, result=True, index=0, **output_layout_dim1) + raise NotImplementedError("Bias layout propagation is not supported.") + if has_convert_c: + ext_op = match(gpu_func, ops={"arith.extf"}) + xegpu.set_op_layout_attr(ext_op, result=True, index=0, **output_layout) + trunc_op = match(gpu_func, ops={"arith.truncf"}) + xegpu.set_op_layout_attr(trunc_op, result=True, index=0, **output_layout) + + transform.apply_cse(gpu_func) + canonicalize(gpu_func) + + # hoist desc ops out of reduction loop + transform.apply_licm(k_loop) + + canonicalize(gpu_func) + transform.apply_cse(gpu_func) + + if stop_at_stage == "xegpu-wg": + raise PipelineInterrupt() + + return mod + + def get_schedule_module( has_bias: bool = False, has_relu: bool = False, From b55e2b8854e9d17c19ac50d8e7432b79d4ad8c5e Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 23 Jan 2026 23:13:28 +0200 Subject: [PATCH 3/7] support non-accumulating matrix multiplications --- examples/xegpu_matmul/matmul.py | 18 ++++++++++++- examples/xegpu_matmul/mlp.py | 10 +++++++ examples/xegpu_matmul/payload.py | 22 ++++++++++++---- examples/xegpu_matmul/schedule.py | 44 ++++++++++++++++++++++++------- 4 files changed, 79 insertions(+), 15 
deletions(-) diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py index 2e7da58..3380985 100644 --- a/examples/xegpu_matmul/matmul.py +++ b/examples/xegpu_matmul/matmul.py @@ -54,6 +54,7 @@ def __init__( c_type: str = "f32", has_bias: bool = False, has_relu: bool = False, + accumulate_c: bool = True, ): self.M = M self.N = N @@ -73,6 +74,7 @@ def __init__( self.c_dtype = type_str_to_numpy[c_type] self.has_bias = has_bias self.has_relu = has_relu + self.accumulate_c = accumulate_c if has_bias: raise NotImplementedError("Bias is not implemented yet") # cache allocated memrefs @@ -136,7 +138,9 @@ def _reference_solution(self) -> np.ndarray: A, B, C = self._initial_host_arrays # use float32 data type for efficiency f32 = np.float32 - C_ref = A.astype(f32) @ B.astype(f32) + C.astype(f32) + C_ref = A.astype(f32) @ B.astype(f32) + if self.accumulate_c: + C_ref += C.astype(f32) if self.has_relu: C_ref = np.maximum(C_ref, 0) if self.has_bias: @@ -196,6 +200,10 @@ def get_complexity(self) -> tuple[int, int, int]: nbytes_ab = np.dtype(self.ab_dtype).itemsize nbytes_c = np.dtype(self.c_dtype).itemsize memory_reads = (M * K + K * N) * nbytes_ab # read A and B + if self.accumulate_c: + memory_reads += M * N * nbytes_c # read C for accumulation + if self.has_bias: + memory_reads += N * nbytes_c # read bias memory_writes = M * N * nbytes_c # write C return (flop_count, memory_reads, memory_writes) @@ -209,6 +217,7 @@ def payload_module(self) -> ir.Module: c_type_str=self.c_type, has_bias=self.has_bias, has_relu=self.has_relu, + accumulate_c=self.accumulate_c, ) return mod @@ -218,6 +227,7 @@ def schedule_module( return get_schedule_module( has_bias=self.has_bias, has_relu=self.has_relu, + accumulate_c=self.accumulate_c, stop_at_stage=stop_at_stage, params=parameters, ) @@ -309,6 +319,11 @@ def parse_cli(): action="store_true", help="Add relu op after the matrix multiplication (and bias if any).", ) + parser.add_argument( + "--no-accumulate-c", + 
action="store_true", + help="Compute plain matrix-multiply C=A*B instead of matrix-multiply-accumulate C+=A*B.", + ) parser.add_argument( "--check-result", action="store_true", @@ -371,6 +386,7 @@ def parse_cli(): c_type=c_type, has_bias=False, has_relu=args.relu, + accumulate_c=not args.no_accumulate_c, ) if args.dump_kernel or args.dump_schedule: diff --git a/examples/xegpu_matmul/mlp.py b/examples/xegpu_matmul/mlp.py index 71907b6..8e1dba1 100644 --- a/examples/xegpu_matmul/mlp.py +++ b/examples/xegpu_matmul/mlp.py @@ -53,6 +53,7 @@ def __init__( c_type: str = "f32", has_bias: bool = False, has_relu: bool = False, + accumulate_c: bool = False, ): self.batch_size = batch_size self.input_size = input_size @@ -76,6 +77,7 @@ def __init__( self.c_dtype = type_str_to_numpy[c_type] self.has_bias = has_bias self.has_relu = has_relu + self.accumulate_c = accumulate_c if has_bias: raise NotImplementedError("Bias is not implemented yet") # cache allocated memrefs @@ -262,6 +264,7 @@ def payload_module(self) -> ir.Module: c_type_str=self.c_type, has_bias=self.has_bias, has_relu=self.has_relu, + accumulate_c=self.accumulate_c, ) return mod @@ -271,6 +274,7 @@ def schedule_module( return get_schedule_module_mlp( has_bias=self.has_bias, has_relu=self.has_relu, + accumulate_c=self.accumulate_c, stop_at_stage=stop_at_stage, nlayers=len(self.matmul_layers), params=parameters, @@ -498,6 +502,11 @@ def parse_cli(): action="store_true", help="Check the result of the matrix multiplication.", ) + parser.add_argument( + "--accumulate-c", + action="store_true", + help="Use matrix-multiply-accumulate layers instead of initializing the accumulator tile with zeros.", + ) parser.add_argument( "--dump-kernel", type=str, @@ -540,6 +549,7 @@ def parse_cli(): c_type=c_type, has_bias=False, has_relu=args.relu, + accumulate_c=args.accumulate_c, ) matmuls = wload.matmul_layers print(f"MLP with {len(matmuls)} layers") diff --git a/examples/xegpu_matmul/payload.py b/examples/xegpu_matmul/payload.py 
index 2e6615c..4411fc8 100644 --- a/examples/xegpu_matmul/payload.py +++ b/examples/xegpu_matmul/payload.py @@ -60,12 +60,13 @@ def emit_mlp_layer( c_type, bias_tensor=None, has_relu=False, + accumulate_c=True, convert_c_type=False, ) -> ir.Value: M, N = c_tensor.type.shape id_map = ir.AffineMap.get_identity(2) par_iter = linalg.IteratorType.parallel - if convert_c_type: + if convert_c_type and accumulate_c: empty = tensor.empty((M, N), c_type) @linalg.generic( @@ -79,7 +80,13 @@ def f(a, b): input_c_tensor = f else: - input_c_tensor = c_tensor + if accumulate_c: + input_c_tensor = c_tensor + else: + zero = arith.constant(c_type, 0.0) + empty = tensor.empty((M, N), c_type) + zero_tensor = linalg.fill(zero, outs=[empty]) + input_c_tensor = zero_tensor mmul = linalg.matmul(a_tensor, b_tensor, outs=[input_c_tensor]) terminal = mmul res_type = c_type @@ -119,6 +126,7 @@ def generate_matmul_payload( c_type_str: str, has_bias: bool, has_relu: bool, + accumulate_c: bool, ) -> ir.Module: """Generate payload function module.""" get_ir_dtype = { @@ -167,6 +175,7 @@ def payload(*args): c_type, bias_tensor, has_relu, + accumulate_c=accumulate_c, convert_c_type=False, ) bufferization.materialize_in_destination( @@ -200,6 +209,7 @@ def generate_mlp_payload( c_type_str: str, has_bias: bool, has_relu: bool, + accumulate_c: bool, ) -> ir.Module: """Generate payload function module.""" get_ir_dtype = { @@ -275,15 +285,17 @@ def payload(*args): c_type, bias_tensor, has_relu, + accumulate_c=accumulate_c, convert_c_type=True, ) - if to_dealloc is not None: - gpu.dealloc(None, [], to_dealloc) - to_dealloc = None if i != nlayers - 1: bufferization.materialize_in_destination( None, layer_output, c_memref, restrict=True, writable=True ) + if to_dealloc is not None: + gpu.dealloc(None, [], to_dealloc) + to_dealloc = None + if i != nlayers - 1: # deallocate after next layer to_dealloc = c_memref diff --git a/examples/xegpu_matmul/schedule.py b/examples/xegpu_matmul/schedule.py index 
2ec1145..f2f17a8 100644 --- a/examples/xegpu_matmul/schedule.py +++ b/examples/xegpu_matmul/schedule.py @@ -38,6 +38,7 @@ def match_and_split(*args, nhandles=1, **kwargs): def get_schedule_module_mlp( has_bias: bool = False, has_relu: bool = False, + accumulate_c: bool = False, stop_at_stage: str = "", nlayers: int = 1, params: Optional[dict] = None, @@ -66,6 +67,7 @@ def get_schedule_module_mlp( payload_mod, has_bias=has_bias, has_relu=has_relu, + accumulate_c=accumulate_c, stop_at_stage=stop_at_stage, nlayers=nlayers, params=params, @@ -78,6 +80,7 @@ def xegpu_mlp_transform_schedule( mod: ir.Value, has_bias: bool = False, has_relu: bool = False, + accumulate_c: bool = False, has_convert_c: bool = True, stop_at_stage: str = "", nlayers: int = 1, @@ -89,6 +92,7 @@ def xegpu_mlp_transform_schedule( mod, has_bias=has_bias, has_relu=has_relu, + accumulate_c=accumulate_c, has_convert_c=has_convert_c, stop_at_stage=stop_at_stage, nlayers=nlayers, @@ -109,6 +113,7 @@ def bundle_xepu_mlp_schedule( mod: ir.Value, has_bias: bool = False, has_relu: bool = False, + accumulate_c: bool = False, has_convert_c: bool = True, stop_at_stage: str = "", nlayers: int = 1, @@ -377,10 +382,24 @@ def convert_layout(value, input, target): "sg_data": sg_tile, "inst_data": dpas_shape_c, } - desc_op_c = xegpu.get_desc_op(tile_c) - desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout) # C tile dpas layout xegpu.set_op_layout_attr(dpas_op, result=True, index=0, **output_layout) + if accumulate_c: + desc_op_c = xegpu.get_desc_op(tile_c) + desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout) + else: + # match the const zero tile + acc_type = ir.F32Type.get() + vtype = ir.VectorType.get([wg_tile[0], wg_tile[1]], acc_type) + zero_tile = match( + gpu_func, ops={"arith.constant"}, filter_result_type=vtype + ) + xegpu.set_op_layout_attr(zero_tile, result=True, index=0, **output_layout) + # annotate store op + store_op = match(gpu_func, ops={"xegpu.store_nd"}) + tile_c = 
transform.get_operand(anyvalue, store_op, [1]) + desc_op_c = xegpu.get_desc_op(tile_c) + desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout) if has_relu: # for post ops we need to add C layout manually @@ -418,8 +437,9 @@ def convert_layout(value, input, target): xegpu.set_op_layout_attr(mask, result=True, index=0, **output_layout_dim1) raise NotImplementedError("Bias layout propagation is not supported.") if has_convert_c: - ext_op = match(gpu_func, ops={"arith.extf"}) - xegpu.set_op_layout_attr(ext_op, result=True, index=0, **output_layout) + if accumulate_c: + ext_op = match(gpu_func, ops={"arith.extf"}) + xegpu.set_op_layout_attr(ext_op, result=True, index=0, **output_layout) trunc_op = match(gpu_func, ops={"arith.truncf"}) xegpu.set_op_layout_attr(trunc_op, result=True, index=0, **output_layout) @@ -441,6 +461,7 @@ def convert_layout(value, input, target): def get_schedule_module( has_bias: bool = False, has_relu: bool = False, + accumulate_c: bool = False, stop_at_stage: str = "", params: Optional[dict] = None, ) -> ir.Module: @@ -468,6 +489,7 @@ def get_schedule_module( payload_mod, has_bias=has_bias, has_relu=has_relu, + accumulate_c=accumulate_c, stop_at_stage=stop_at_stage, params=params, ) @@ -479,6 +501,7 @@ def xegpu_matmul_transform_schedule( mod: ir.Value, has_bias: bool = False, has_relu: bool = False, + accumulate_c: bool = False, stop_at_stage: str = "", params: Optional[dict] = None, ): @@ -488,6 +511,7 @@ def xegpu_matmul_transform_schedule( mod, has_bias=has_bias, has_relu=has_relu, + accumulate_c=accumulate_c, stop_at_stage=stop_at_stage, params=params, ) @@ -506,6 +530,7 @@ def bundle_xepu_matmul_schedule( mod, has_bias: bool = False, has_relu: bool = False, + accumulate_c: bool = False, stop_at_stage: str = "", params: Optional[dict] = None, ) -> ir.Module: @@ -741,15 +766,16 @@ def convert_layout(value, input, target): "sg_data": sg_tile, "inst_data": dpas_shape_c, } - desc_op_c = xegpu.get_desc_op(tile_c) - # C tile load/store op 
anchor layout - desc_c_users = transform.get_consumers_of_result(anytype, desc_op_c, 0) - load_op_c, store_op_c = transform.split_handle((anytype, anytype), desc_c_users) - xegpu.set_op_layout_attr(load_op_c, **output_layout) # C tile dpas anchor layout xegpu.set_op_layout_attr(dpas_op, index=0, **layout_dpas_a) xegpu.set_op_layout_attr(dpas_op, index=1, **layout_dpas_b) xegpu.set_op_layout_attr(dpas_op, index=2, **output_layout) + if accumulate_c: + desc_op_c = xegpu.get_desc_op(tile_c) + # C tile load/store op anchor layout + desc_c_users = transform.get_consumers_of_result(anytype, desc_op_c, 0) + load_op_c, store_op_c = transform.split_handle((anytype, anytype), desc_c_users) + xegpu.set_op_layout_attr(load_op_c, **output_layout) if has_bias: # annotate the 1d load of the broadcast op with a slice layout From 48044e92d7654041c5dc7db58f190255dfdb9f79 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Wed, 18 Feb 2026 10:30:08 +0200 Subject: [PATCH 4/7] schedule: use anchor layouts, matmul uses mlp schedule --- examples/xegpu_matmul/matmul.py | 8 +- examples/xegpu_matmul/schedule.py | 452 +++--------------------------- 2 files changed, 41 insertions(+), 419 deletions(-) diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py index 3380985..d65e85d 100644 --- a/examples/xegpu_matmul/matmul.py +++ b/examples/xegpu_matmul/matmul.py @@ -24,7 +24,7 @@ from lighthouse.utils.memref import get_packed_arg, to_ctype as memref_to_ctype # Import from sibling files: -from schedule import get_schedule_module +from schedule import get_schedule_module_mlp from payload import generate_matmul_payload @@ -224,12 +224,14 @@ def payload_module(self) -> ir.Module: def schedule_module( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None ) -> ir.Module: - return get_schedule_module( + return get_schedule_module_mlp( has_bias=self.has_bias, has_relu=self.has_relu, + has_convert_c=False, accumulate_c=self.accumulate_c, 
stop_at_stage=stop_at_stage, - params=parameters, + nlayers=1, + params={"layer_0": parameters}, ) def shared_libs(self) -> list[str]: diff --git a/examples/xegpu_matmul/schedule.py b/examples/xegpu_matmul/schedule.py index f2f17a8..dcf5695 100644 --- a/examples/xegpu_matmul/schedule.py +++ b/examples/xegpu_matmul/schedule.py @@ -38,6 +38,7 @@ def match_and_split(*args, nhandles=1, **kwargs): def get_schedule_module_mlp( has_bias: bool = False, has_relu: bool = False, + has_convert_c: bool = True, accumulate_c: bool = False, stop_at_stage: str = "", nlayers: int = 1, @@ -67,6 +68,7 @@ def get_schedule_module_mlp( payload_mod, has_bias=has_bias, has_relu=has_relu, + has_convert_c=has_convert_c, accumulate_c=accumulate_c, stop_at_stage=stop_at_stage, nlayers=nlayers, @@ -285,6 +287,8 @@ def bundle_xepu_mlp_schedule( sg_tile = [layer_params["sg_m"], layer_params["sg_n"]] k_tile = layer_params["k"] + sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] + load_tile_a = [layer_params["load_a_m"], layer_params["load_a_k"]] load_tile_b = [layer_params["load_b_k"], layer_params["load_b_n"]] prefetch_tile_a = [layer_params["pf_a_m"], layer_params["pf_a_k"]] @@ -309,7 +313,6 @@ def bundle_xepu_mlp_schedule( dpas_op = match(k_loop, ops={"xegpu.dpas"}) tile_a = transform.get_operand(anyvalue, dpas_op, [0]) tile_b = transform.get_operand(anyvalue, dpas_op, [1]) - tile_c = transform.get_operand(anyvalue, dpas_op, [2]) def convert_layout(value, input, target): xegpu.convert_layout( @@ -327,22 +330,27 @@ def convert_layout(value, input, target): tile_a, nb_prefetch=nb_prefetch, ) - xegpu.set_desc_layout( - desc_prefetch_a, - sg_layout=prefetch_layout_a, - sg_data=prefetch_tile_a, - inst_data=prefetch_inst_data, - ) + layout_prefetch_a = { + "sg_layout": prefetch_layout_a, + "sg_data": prefetch_tile_a, + "inst_data": prefetch_inst_data, + } + pf_ops = transform.get_consumers_of_result(anytype, desc_prefetch_a, 0) + for pf in transform.split_handle((anytype,) * 
(nb_prefetch + 1), pf_ops): + xegpu.set_op_layout_attr(pf, **layout_prefetch_a) + desc_prefetch_b = xegpu.insert_prefetch( tile_b, nb_prefetch=nb_prefetch, ) - xegpu.set_desc_layout( - desc_prefetch_b, - sg_layout=prefetch_layout_b, - sg_data=prefetch_tile_b, - inst_data=prefetch_inst_data, - ) + layout_prefetch_b = { + "sg_layout": prefetch_layout_b, + "sg_data": prefetch_tile_b, + "inst_data": prefetch_inst_data, + } + pf_ops = transform.get_consumers_of_result(anytype, desc_prefetch_b, 0) + for pf in transform.split_handle((anytype,) * (nb_prefetch + 1), pf_ops): + xegpu.set_op_layout_attr(pf, **layout_prefetch_b) # A tile load layout layout_load_a = { @@ -351,10 +359,9 @@ def convert_layout(value, input, target): "inst_data": load_tile_a, } desc_op_a = xegpu.get_desc_op(tile_a) - desc_op_a = xegpu.set_desc_layout( - target=desc_op_a, - **layout_load_a, - ) + # A tile load op anchor layout + load_op_a = transform.get_consumers_of_result(anytype, desc_op_a, 0) + xegpu.set_op_layout_attr(load_op_a, **layout_load_a) # A tile dpas layout layout_dpas_a = layout_load_a.copy() layout_dpas_a["inst_data"] = dpas_shape_a @@ -367,10 +374,9 @@ def convert_layout(value, input, target): "inst_data": load_tile_b, } desc_op_b = xegpu.get_desc_op(tile_b) - desc_op_b = xegpu.set_desc_layout( - target=desc_op_b, - **layout_load_b, - ) + # B tile load op anchor layout + load_op_b = transform.get_consumers_of_result(anytype, desc_op_b, 0) + xegpu.set_op_layout_attr(load_op_b, **layout_load_b) # B tile dpas layout layout_dpas_b = layout_load_b.copy() layout_dpas_b["inst_data"] = dpas_shape_b @@ -382,66 +388,23 @@ def convert_layout(value, input, target): "sg_data": sg_tile, "inst_data": dpas_shape_c, } - # C tile dpas layout - xegpu.set_op_layout_attr(dpas_op, result=True, index=0, **output_layout) - if accumulate_c: - desc_op_c = xegpu.get_desc_op(tile_c) - desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout) - else: - # match the const zero tile - acc_type = 
ir.F32Type.get() - vtype = ir.VectorType.get([wg_tile[0], wg_tile[1]], acc_type) - zero_tile = match( - gpu_func, ops={"arith.constant"}, filter_result_type=vtype - ) - xegpu.set_op_layout_attr(zero_tile, result=True, index=0, **output_layout) - # annotate store op - store_op = match(gpu_func, ops={"xegpu.store_nd"}) - tile_c = transform.get_operand(anyvalue, store_op, [1]) - desc_op_c = xegpu.get_desc_op(tile_c) - desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout) - - if has_relu: - # for post ops we need to add C layout manually - max_op = match(gpu_func, ops={"arith.maximumf"}) - xegpu.set_op_layout_attr(max_op, result=True, index=0, **output_layout) - # find zero constant buffer and annotate it - const_buffer = transform.get_producer_of_operand(anytype, max_op, 1) - xegpu.set_op_layout_attr( - const_buffer, result=True, index=0, **output_layout - ) + # C tile dpas anchor layout + xegpu.set_op_layout_attr(dpas_op, index=0, **layout_dpas_a) + xegpu.set_op_layout_attr(dpas_op, index=1, **layout_dpas_b) + xegpu.set_op_layout_attr(dpas_op, index=2, **output_layout) + # annotate store op + store_op_c = match(gpu_func, ops={"xegpu.store_nd"}) + xegpu.set_op_layout_attr(store_op_c, **output_layout) + if has_bias: - # for post ops we need to add C layout manually + # annotate the 1d load of the broadcast op with a slice layout add_op = match(gpu_func, ops={"arith.addf"}) - xegpu.set_op_layout_attr(add_op, result=True, index=0, **output_layout) - - # annotate broadcast op operands bcast_op = transform.get_producer_of_operand(anytype, add_op, 0) - xegpu.set_op_layout_attr(bcast_op, result=True, index=0, **output_layout) bcast_load = transform.get_producer_of_operand(anytype, bcast_op, 0) xegpu.set_op_layout_attr( bcast_load, result=True, index=0, **output_layout, slice_dims=[0] ) - output_layout_dim1 = { - "sg_layout": [sg_layout[1]], - "sg_data": [sg_tile[1]], - "inst_data": [dpas_shape_c[1]], - } - offset = transform.get_producer_of_operand(anytype, 
bcast_load, 1) - xegpu.set_op_layout_attr(offset, result=True, index=0, **output_layout_dim1) - aux1 = transform.get_producer_of_operand(anytype, offset, 0) - xegpu.set_op_layout_attr(aux1, result=True, index=0, **output_layout_dim1) - aux2 = transform.get_producer_of_operand(anytype, offset, 1) - xegpu.set_op_layout_attr(aux2, result=True, index=0, **output_layout_dim1) - mask = transform.get_producer_of_operand(anytype, bcast_load, 2) - xegpu.set_op_layout_attr(mask, result=True, index=0, **output_layout_dim1) raise NotImplementedError("Bias layout propagation is not supported.") - if has_convert_c: - if accumulate_c: - ext_op = match(gpu_func, ops={"arith.extf"}) - xegpu.set_op_layout_attr(ext_op, result=True, index=0, **output_layout) - trunc_op = match(gpu_func, ops={"arith.truncf"}) - xegpu.set_op_layout_attr(trunc_op, result=True, index=0, **output_layout) transform.apply_cse(gpu_func) canonicalize(gpu_func) @@ -458,349 +421,6 @@ def convert_layout(value, input, target): return mod -def get_schedule_module( - has_bias: bool = False, - has_relu: bool = False, - accumulate_c: bool = False, - stop_at_stage: str = "", - params: Optional[dict] = None, -) -> ir.Module: - """Generate transform schedule module.""" - mod = ir.Module.create() - mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get() - with ir.InsertionPoint(mod.body): - named_sequence = transform.named_sequence( - "__transform_main", - [transform.AnyOpType.get()], # input types - [], # output types - arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], - ) - with ir.InsertionPoint(named_sequence.body): - # match the payload module - anytype = transform.AnyOpType.get() - func = match(named_sequence.bodyTarget, ops={"func.func"}) - payload_mod = transform.get_parent_op( - anytype, - func, - op_name="builtin.module", - deduplicate=True, - ) - xegpu_matmul_transform_schedule( - payload_mod, - has_bias=has_bias, - has_relu=has_relu, - accumulate_c=accumulate_c, - 
stop_at_stage=stop_at_stage, - params=params, - ) - - return mod - - -def xegpu_matmul_transform_schedule( - mod: ir.Value, - has_bias: bool = False, - has_relu: bool = False, - accumulate_c: bool = False, - stop_at_stage: str = "", - params: Optional[dict] = None, -): - """Transform schedule for matmul-like payload.""" - try: - mod = bundle_xepu_matmul_schedule( - mod, - has_bias=has_bias, - has_relu=has_relu, - accumulate_c=accumulate_c, - stop_at_stage=stop_at_stage, - params=params, - ) - - mod = bundle_xegpu_to_binary( - mod, - stop_at_stage=stop_at_stage, - ) - except PipelineInterrupt: - pass - finally: - transform.yield_() - - -def bundle_xepu_matmul_schedule( - mod, - has_bias: bool = False, - has_relu: bool = False, - accumulate_c: bool = False, - stop_at_stage: str = "", - params: Optional[dict] = None, -) -> ir.Module: - """Schedule for lowering matmul-like payload to xegpu wg level.""" - if params is None: - raise ValueError("Schedule parameters must be provided.") - - # tunable parameters - wg_tile = [params["wg_m"], params["wg_n"]] - sg_tile = [params["sg_m"], params["sg_n"]] - k_tile = params["k"] - - load_tile_a = [params["load_a_m"], params["load_a_k"]] - load_tile_b = [params["load_b_k"], params["load_b_n"]] - prefetch_tile_a = [params["pf_a_m"], params["pf_a_k"]] - prefetch_tile_b = [params["pf_b_k"], params["pf_b_n"]] - nb_prefetch = params["pf_nb"] - - # derived parameters - sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] - # number of threads collapsed to 1d layout - nb_threads = sg_layout[0] * sg_layout[1] * nb_workitems - prefetch_layout_a = [ - wg_tile[0] // prefetch_tile_a[0], - k_tile // prefetch_tile_a[1], - ] - prefetch_layout_b = [ - k_tile // prefetch_tile_b[0], - wg_tile[1] // prefetch_tile_b[1], - ] - - # matmul matrix shapes - sg_tile_a = [sg_tile[0], k_tile] - sg_tile_b = [k_tile, sg_tile[1]] - - if stop_at_stage == "initial": - raise PipelineInterrupt() - - anytype = transform.AnyOpType.get() - anyvalue = 
transform.AnyValueType.get() - - # match the payload function - anchor = match(mod, ops={"linalg.matmul"}) - func = transform.get_parent_op( - anytype, - anchor, - op_name="func.func", - deduplicate=True, - ) - - dpas_shape_a = [dpas_tile[0], dpas_tile[2]] - dpas_shape_b = [dpas_tile[2], dpas_tile[1]] - dpas_shape_c = [dpas_tile[0], dpas_tile[1]] - - # wg tiling - if has_relu: - terminal = match(mod, ops={"linalg.max"}) - elif has_bias: - terminal = match(mod, ops={"linalg.add"}) - else: - terminal = match(mod, ops={"linalg.matmul"}) - # FIXME use structured.structured_fuse - structured.FuseOp(terminal, tile_sizes=wg_tile, use_forall=True) - transform.apply_cse(mod) - canonicalize(mod) - - # k loop tiling - wg_matmul = match(mod, ops={"linalg.matmul"}) - # FIXME use structured.structured_tile_using_for - wgk_matmul, k_loop = structured.TileUsingForOp( - wg_matmul, sizes=[0, 0, k_tile] - ).results - - transform.apply_cse(func) - canonicalize(func) - - if stop_at_stage == "tiled": - raise PipelineInterrupt() - - # vectorize - # FIXME use structured.structured_vectorize_children_and_apply_patterns - func = structured.VectorizeChildrenAndApplyPatternsOp( - func, - fold_type_extensions_into_contract=True, - ).result - - # hoist loop invariant vector read/store ops - k_loop = match(func, ops={"scf.for"}) - loop.HoistLoopInvariantSubsetsOp(k_loop) - - transform.apply_cse(func) - canonicalize(func) - - if stop_at_stage == "vectorized": - raise PipelineInterrupt() - - # bufferize - - # eliminate empty tensors to avoid emitting extra copy ops - mod = apply_registered_pass(mod, "eliminate-empty-tensors") - identity_layout = LayoutMapOption.IdentityLayoutMap - mod = bufferization.OneShotBufferizeOp( - mod, - allow_return_allocs_from_loops=True, - bufferize_function_boundaries=True, - function_boundary_type_conversion=identity_layout, - ).result - # fold memref.subviews into vector.transfer_read/write ops - mod = apply_registered_pass(mod, "fold-memref-alias-ops") - 
transform.apply_cse(mod) - canonicalize(mod) - - if stop_at_stage == "bufferized": - raise PipelineInterrupt() - - # convert forall to parallel - wg_loop = match(mod, ops={"scf.forall"}) - wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) - func = transform.get_parent_op(anytype, wg_loop) - - # convert to scf.parallel to gpu.launch - func = apply_registered_pass(func, "gpu-map-parallel-loops") - func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") - func = apply_registered_pass(func, "lower-affine") - transform.apply_cse(func) - canonicalize(func) - - # set correct number of gpu threads - launch_op = match(func, ops={"gpu.launch"}) - xegpu.set_gpu_launch_threads(launch_op, threads=[nb_threads, 1, 1]) - - # outline gpu func - func = apply_registered_pass(func, "lower-affine") - canonicalize(func) - func = apply_registered_pass(func, "gpu-launch-sink-index-computations") - mod = apply_registered_pass(mod, "gpu-kernel-outlining") - transform.apply_cse(mod) - - # set xevm target - mod = apply_registered_pass( - mod, - "xevm-attach-target", - options={"O": "3", "chip": "bmg"}, - ) - - # convert vector to xegpu - gpu_mod = match(mod, ops={"gpu.module"}) - gpu_func = match(gpu_mod, ops={"gpu.func"}) - gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") - transform.apply_cse(gpu_func) - - if stop_at_stage == "xegpu-initial": - raise PipelineInterrupt() - - # add layouts to DPAS op operands - k_loop = match(gpu_func, ops={"scf.for"}) - dpas_op = match(k_loop, ops={"xegpu.dpas"}) - tile_a = transform.get_operand(anyvalue, dpas_op, [0]) - tile_b = transform.get_operand(anyvalue, dpas_op, [1]) - tile_c = transform.get_operand(anyvalue, dpas_op, [2]) - - def convert_layout(value, input, target): - xegpu.convert_layout( - value, - input_sg_layout=input["sg_layout"], - input_sg_data=input["sg_data"], - input_inst_data=input["inst_data"], - target_sg_layout=target["sg_layout"], - target_sg_data=target["sg_data"], - 
target_inst_data=target["inst_data"], - ) - - # insert prefetch ops for DPAS A and B tiles - desc_prefetch_a = xegpu.insert_prefetch( - tile_a, - nb_prefetch=nb_prefetch, - ) - layout_prefetch_a = { - "sg_layout": prefetch_layout_a, - "sg_data": prefetch_tile_a, - "inst_data": prefetch_inst_data, - } - pf_ops = transform.get_consumers_of_result(anytype, desc_prefetch_a, 0) - for pf in transform.split_handle((anytype,) * (nb_prefetch + 1), pf_ops): - xegpu.set_op_layout_attr(pf, **layout_prefetch_a) - - desc_prefetch_b = xegpu.insert_prefetch( - tile_b, - nb_prefetch=nb_prefetch, - ) - layout_prefetch_b = { - "sg_layout": prefetch_layout_b, - "sg_data": prefetch_tile_b, - "inst_data": prefetch_inst_data, - } - pf_ops = transform.get_consumers_of_result(anytype, desc_prefetch_b, 0) - for pf in transform.split_handle((anytype,) * (nb_prefetch + 1), pf_ops): - xegpu.set_op_layout_attr(pf, **layout_prefetch_b) - - # A tile load layout - layout_load_a = { - "sg_layout": sg_layout, - "sg_data": sg_tile_a, - "inst_data": load_tile_a, - } - desc_op_a = xegpu.get_desc_op(tile_a) - # A tile load op anchor layout - load_op_a = transform.get_consumers_of_result(anytype, desc_op_a, 0) - xegpu.set_op_layout_attr(load_op_a, **layout_load_a) - # A tile dpas layout - layout_dpas_a = layout_load_a.copy() - layout_dpas_a["inst_data"] = dpas_shape_a - convert_layout(tile_a, layout_load_a, layout_dpas_a) - - # B tile load layout - layout_load_b = { - "sg_layout": sg_layout, - "sg_data": sg_tile_b, - "inst_data": load_tile_b, - } - desc_op_b = xegpu.get_desc_op(tile_b) - # B tile load op anchor layout - load_op_b = transform.get_consumers_of_result(anytype, desc_op_b, 0) - xegpu.set_op_layout_attr(load_op_b, **layout_load_b) - # B tile dpas layout - layout_dpas_b = layout_load_b.copy() - layout_dpas_b["inst_data"] = dpas_shape_b - convert_layout(tile_b, layout_load_b, layout_dpas_b) - - # C tile layout - output_layout = { - "sg_layout": sg_layout, - "sg_data": sg_tile, - "inst_data": 
dpas_shape_c, - } - # C tile dpas anchor layout - xegpu.set_op_layout_attr(dpas_op, index=0, **layout_dpas_a) - xegpu.set_op_layout_attr(dpas_op, index=1, **layout_dpas_b) - xegpu.set_op_layout_attr(dpas_op, index=2, **output_layout) - if accumulate_c: - desc_op_c = xegpu.get_desc_op(tile_c) - # C tile load/store op anchor layout - desc_c_users = transform.get_consumers_of_result(anytype, desc_op_c, 0) - load_op_c, store_op_c = transform.split_handle((anytype, anytype), desc_c_users) - xegpu.set_op_layout_attr(load_op_c, **output_layout) - - if has_bias: - # annotate the 1d load of the broadcast op with a slice layout - add_op = match(gpu_func, ops={"arith.addf"}) - bcast_op = transform.get_producer_of_operand(anytype, add_op, 0) - bcast_load = transform.get_producer_of_operand(anytype, bcast_op, 0) - xegpu.set_op_layout_attr( - bcast_load, result=True, index=0, **output_layout, slice_dims=[0] - ) - raise NotImplementedError("Bias layout propagation is not supported.") - transform.apply_cse(gpu_func) - canonicalize(gpu_func) - - # hoist desc ops out of reduction loop - transform.apply_licm(k_loop) - - canonicalize(gpu_func) - transform.apply_cse(gpu_func) - - if stop_at_stage == "xegpu-wg": - raise PipelineInterrupt() - - return mod - - def bundle_xegpu_to_binary(mod, stop_at_stage: str = "") -> ir.Module: """Schedule for lowering xegpu wg level to binary.""" # upstream xegpu/xevm pipeline is payload independent. 
From 1057a28a183e7337ba74659f8ca7b2a5f25d2261 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Mon, 23 Feb 2026 13:49:07 +0200 Subject: [PATCH 5/7] mlp: add identity-weights option --- examples/xegpu_matmul/mlp.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/examples/xegpu_matmul/mlp.py b/examples/xegpu_matmul/mlp.py index 8e1dba1..fff13c9 100644 --- a/examples/xegpu_matmul/mlp.py +++ b/examples/xegpu_matmul/mlp.py @@ -54,6 +54,7 @@ def __init__( has_bias: bool = False, has_relu: bool = False, accumulate_c: bool = False, + identity_weights: bool = False, ): self.batch_size = batch_size self.input_size = input_size @@ -64,6 +65,7 @@ def __init__( layer_sizes = [self.input_size] + self.hidden_layer_sizes + [self.output_size] self.weight_shapes = list(zip(layer_sizes[:-1], layer_sizes[1:])) self.matmul_layers = [(self.batch_size, o, i) for i, o in self.weight_shapes] + self.identity_weights = identity_weights assert ab_type == "f16", "Only f16 type is supported for A and B" assert c_type == "f32", "Only f32 type is supported for C" @@ -125,17 +127,28 @@ def _initial_host_arrays(self) -> list[np.ndarray]: # use integer values to avoid f16/f32 floating point discrepancies def gen_random(shape, dtype): - # generate random {-1, 1} values - a = np.round(np.random.random_sample(shape)) - a[a == 0] = -1 + # generate values in range [-3, 3] + a = np.round(6 * np.random.random_sample(shape)) - 3 return a.astype(dtype) + def gen_identity(shape, dtype): + # identity matrix, if cols > rows wrap to fill all columns + a = np.zeros(shape, dtype=dtype) + np.fill_diagonal(a, 1) + if shape[1] > shape[0]: + second_block = a[:, shape[0] :] + np.fill_diagonal(second_block, 1) + return a + np.random.seed(2) input_array = gen_random(self.input_shape, self.ab_dtype) output_array = np.zeros(self.output_shape, self.ab_dtype) weights = [] for i, o in self.weight_shapes: - W = gen_random((i, o), self.ab_dtype) + if 
self.identity_weights: + W = gen_identity((i, o), self.ab_dtype) + else: + W = gen_random((i, o), self.ab_dtype) weights.append(W) if self.has_bias: @@ -497,15 +510,23 @@ def parse_cli(): action="store_true", help="Add relu op after the matrix multiplication (and bias if any).", ) + parser.add_argument( + "--accumulate-c", + action="store_true", + help="Use matrix-multiply-accumulate layers instead of initializing the " + "accumulator tile with zeros.", + ) parser.add_argument( "--check-result", action="store_true", - help="Check the result of the matrix multiplication.", + help="Check the result of the MLP model. If the result overflows to " + "inf/nan values, use --identity-weights option.", ) parser.add_argument( - "--accumulate-c", + "--identity-weights", action="store_true", - help="Use matrix-multiply-accumulate layers instead of initializing the accumulator tile with zeros.", + help="Initialize weights as (extended) identity matrix, useful for " + "correctness test. Can skew performance measurement.", ) parser.add_argument( "--dump-kernel", @@ -550,6 +571,7 @@ def parse_cli(): has_bias=False, has_relu=args.relu, accumulate_c=args.accumulate_c, + identity_weights=args.identity_weights, ) matmuls = wload.matmul_layers print(f"MLP with {len(matmuls)} layers") From f22946985c3655ffac6740b3ab5d842d090e2b36 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Mon, 23 Feb 2026 17:00:47 +0200 Subject: [PATCH 6/7] move mlp to separate dir, payload and schedule to lighthouse --- examples/xegpu_matmul/lit.local.cfg | 1 - examples/xegpu_matmul/matmul.py | 15 ++---- examples/xegpu_mlp/README.md | 46 +++++++++++++++++++ examples/{xegpu_matmul => xegpu_mlp}/mlp.py | 17 ++----- lighthouse/ingress/gpu/__init__.py | 3 ++ .../ingress/gpu/matmul.py | 0 lighthouse/schedule/__init__.py | 0 lighthouse/schedule/xegpu/__init__.py | 0 .../schedule/xegpu/matmul_schedule.py | 2 +- lighthouse/utils/numpy.py | 10 ++++ 10 files changed, 69 insertions(+), 25 deletions(-) delete mode 100644 
examples/xegpu_matmul/lit.local.cfg create mode 100644 examples/xegpu_mlp/README.md rename examples/{xegpu_matmul => xegpu_mlp}/mlp.py (97%) create mode 100644 lighthouse/ingress/gpu/__init__.py rename examples/xegpu_matmul/payload.py => lighthouse/ingress/gpu/matmul.py (100%) create mode 100644 lighthouse/schedule/__init__.py create mode 100644 lighthouse/schedule/xegpu/__init__.py rename examples/xegpu_matmul/schedule.py => lighthouse/schedule/xegpu/matmul_schedule.py (99%) create mode 100644 lighthouse/utils/numpy.py diff --git a/examples/xegpu_matmul/lit.local.cfg b/examples/xegpu_matmul/lit.local.cfg deleted file mode 100644 index b310830..0000000 --- a/examples/xegpu_matmul/lit.local.cfg +++ /dev/null @@ -1 +0,0 @@ -config.excludes = ["mlir_utils.py", "payload.py", "runner.py", "schedule.py"] diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py index d65e85d..8a2dd9c 100644 --- a/examples/xegpu_matmul/matmul.py +++ b/examples/xegpu_matmul/matmul.py @@ -14,7 +14,6 @@ import numpy as np from mlir import ir from mlir.runtime.np_to_memref import ( - get_ranked_memref_descriptor, make_nd_memref_descriptor, as_ctype, ) @@ -22,15 +21,9 @@ from lighthouse.workload import Workload, benchmark from lighthouse.utils.memref import get_packed_arg, to_ctype as memref_to_ctype - -# Import from sibling files: -from schedule import get_schedule_module_mlp -from payload import generate_matmul_payload - - -def numpy_to_ctype(arr: np.ndarray) -> ctypes._Pointer: - """Convert numpy array to memref and ctypes **void pointer.""" - return memref_to_ctype(get_ranked_memref_descriptor(arr)) +from lighthouse.utils.numpy import numpy_to_ctype +from lighthouse.schedule.xegpu.matmul_schedule import get_schedule_module +from lighthouse.ingress.gpu import generate_matmul_payload class XeGPUMatMul(Workload): @@ -224,7 +217,7 @@ def payload_module(self) -> ir.Module: def schedule_module( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None ) -> 
ir.Module: - return get_schedule_module_mlp( + return get_schedule_module( has_bias=self.has_bias, has_relu=self.has_relu, has_convert_c=False, diff --git a/examples/xegpu_mlp/README.md b/examples/xegpu_mlp/README.md new file mode 100644 index 0000000..6257240 --- /dev/null +++ b/examples/xegpu_mlp/README.md @@ -0,0 +1,46 @@ +# XeGPU Multilayer Perceptron (MLP) benchmark + +## Installation + +To install Lighthouse with XeGPU support, see installation instructions in [xegpu_matmul/README.md](../xegpu_matmul/README.md). + +## Usage + +Run the default single layer MLP (batch=1024, input_features=1024, output_features=1024) benchmark with correctness test: + +```bash +python mlp.py --check-result +``` + +which is equivalent to + +```bash +python mlp.py -b 1024 -i 1024 -o 1024 --check-result +``` + +Run a 3-layer MLP with batch size 128: + +```bash +python mlp.py -b 128 -i 16384 -o 8192 --hidden-sizes 16384 16384 ... +``` + +which corresponds to + +```txt +MLP with 3 layers + Layer 0: M=128, N=16384, K=16384 + Layer 1: M=128, N=16384, K=16384 + Layer 2: M=128, N=8192, K=16384 +``` + +Add ReLU to all layers: + +```bash +python mlp.py --relu ... +``` + +See all command line arguments: + +```bash +python mlp.py --help +``` diff --git a/examples/xegpu_matmul/mlp.py b/examples/xegpu_mlp/mlp.py similarity index 97% rename from examples/xegpu_matmul/mlp.py rename to examples/xegpu_mlp/mlp.py index fff13c9..321ff3b 100644 --- a/examples/xegpu_matmul/mlp.py +++ b/examples/xegpu_mlp/mlp.py @@ -2,7 +2,7 @@ # CHECK: module attributes {gpu.container_module} { """ -XeGPU matrix multiplication benchmark. +XeGPU MLP benchmark. 
""" import argparse @@ -14,7 +14,6 @@ import numpy as np from mlir import ir from mlir.runtime.np_to_memref import ( - get_ranked_memref_descriptor, make_nd_memref_descriptor, as_ctype, ) @@ -22,15 +21,9 @@ from lighthouse.workload import Workload, benchmark from lighthouse.utils.memref import get_packed_arg, to_ctype as memref_to_ctype - -# Import from sibling files: -from schedule import get_schedule_module_mlp -from payload import generate_mlp_payload - - -def numpy_to_ctype(arr: np.ndarray) -> ctypes._Pointer: - """Convert numpy array to memref and ctypes **void pointer.""" - return memref_to_ctype(get_ranked_memref_descriptor(arr)) +from lighthouse.utils.numpy import numpy_to_ctype +from lighthouse.schedule.xegpu.matmul_schedule import get_schedule_module +from lighthouse.ingress.gpu import generate_mlp_payload class XeGPUMLP(Workload): @@ -284,7 +277,7 @@ def payload_module(self) -> ir.Module: def schedule_module( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None ) -> ir.Module: - return get_schedule_module_mlp( + return get_schedule_module( has_bias=self.has_bias, has_relu=self.has_relu, accumulate_c=self.accumulate_c, diff --git a/lighthouse/ingress/gpu/__init__.py b/lighthouse/ingress/gpu/__init__.py new file mode 100644 index 0000000..af910e2 --- /dev/null +++ b/lighthouse/ingress/gpu/__init__.py @@ -0,0 +1,3 @@ +from .matmul import generate_matmul_payload, generate_mlp_payload + +__all__ = ["generate_matmul_payload", "generate_mlp_payload"] diff --git a/examples/xegpu_matmul/payload.py b/lighthouse/ingress/gpu/matmul.py similarity index 100% rename from examples/xegpu_matmul/payload.py rename to lighthouse/ingress/gpu/matmul.py diff --git a/lighthouse/schedule/__init__.py b/lighthouse/schedule/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lighthouse/schedule/xegpu/__init__.py b/lighthouse/schedule/xegpu/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/xegpu_matmul/schedule.py 
b/lighthouse/schedule/xegpu/matmul_schedule.py similarity index 99% rename from examples/xegpu_matmul/schedule.py rename to lighthouse/schedule/xegpu/matmul_schedule.py index dcf5695..47be998 100644 --- a/examples/xegpu_matmul/schedule.py +++ b/lighthouse/schedule/xegpu/matmul_schedule.py @@ -35,7 +35,7 @@ def match_and_split(*args, nhandles=1, **kwargs): nb_workitems = 16 # workitems in subgroup -def get_schedule_module_mlp( +def get_schedule_module( has_bias: bool = False, has_relu: bool = False, has_convert_c: bool = True, diff --git a/lighthouse/utils/numpy.py b/lighthouse/utils/numpy.py new file mode 100644 index 0000000..6ce67ec --- /dev/null +++ b/lighthouse/utils/numpy.py @@ -0,0 +1,10 @@ +import ctypes + +import numpy as np +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor +from lighthouse.utils.memref import to_ctype + + +def numpy_to_ctype(arr: np.ndarray) -> ctypes._Pointer: + """Convert numpy array to memref and ctypes **void pointer.""" + return to_ctype(get_ranked_memref_descriptor(arr)) From 40a7544e79c49acb2c266cae41cec0eed3ca346e Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Tue, 24 Feb 2026 18:42:58 +0200 Subject: [PATCH 7/7] mlp: do not emit ReLU for the output layer --- examples/xegpu_mlp/mlp.py | 16 +++++++++++----- lighthouse/ingress/gpu/matmul.py | 4 +++- lighthouse/schedule/xegpu/matmul_schedule.py | 10 ++++++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/examples/xegpu_mlp/mlp.py b/examples/xegpu_mlp/mlp.py index 321ff3b..d8671ce 100644 --- a/examples/xegpu_mlp/mlp.py +++ b/examples/xegpu_mlp/mlp.py @@ -10,6 +10,7 @@ from typing import Optional from contextlib import contextmanager from functools import cached_property +import warnings import numpy as np from mlir import ir @@ -75,6 +76,10 @@ def __init__( self.accumulate_c = accumulate_c if has_bias: raise NotImplementedError("Bias is not implemented yet") + + if len(self.matmul_layers) == 1 and self.has_relu: + warnings.warn("Using ReLU on a 
single layer model has no effect.") + # cache allocated memrefs self.gpu_memrefs = {} @@ -161,9 +166,9 @@ def _reference_solution(self) -> np.ndarray: weights = host_arrays[2:] a_array = input_array - for W in weights: + for i, W in enumerate(weights): C_ref = a_array @ W - if self.has_relu: + if self.has_relu and i < len(weights) - 1: C_ref = np.maximum(C_ref, 0) if self.has_bias: raise NotImplementedError("Bias verification not implemented") @@ -252,8 +257,9 @@ def matmul_complexity(M, N, K, has_bias, has_relu): flop_count = 0 memory_reads = 0 memory_writes = 0 - for M, N, K in self.matmul_layers: - f, r, w = matmul_complexity(M, N, K, self.has_bias, self.has_relu) + for i, (M, N, K) in enumerate(self.matmul_layers): + relu = self.has_relu if i < len(self.matmul_layers) - 1 else False + f, r, w = matmul_complexity(M, N, K, self.has_bias, relu) flop_count += f memory_reads += r memory_writes += w @@ -501,7 +507,7 @@ def parse_cli(): parser.add_argument( "--relu", action="store_true", - help="Add relu op after the matrix multiplication (and bias if any).", + help="Add ReLU activation function to each layer except the output layer.", ) parser.add_argument( "--accumulate-c", diff --git a/lighthouse/ingress/gpu/matmul.py b/lighthouse/ingress/gpu/matmul.py index 4411fc8..f100dfe 100644 --- a/lighthouse/ingress/gpu/matmul.py +++ b/lighthouse/ingress/gpu/matmul.py @@ -277,6 +277,8 @@ def payload(*args): c_memref, restrict=True, writable=True ) bias_tensor = bias + # skip relu for final layer + emit_relu = has_relu if i < nlayers - 1 else False layer_output = emit_mlp_layer( a_tensor, b_tensor, @@ -284,7 +286,7 @@ def payload(*args): ab_type, c_type, bias_tensor, - has_relu, + emit_relu, accumulate_c=accumulate_c, convert_c_type=True, ) diff --git a/lighthouse/schedule/xegpu/matmul_schedule.py b/lighthouse/schedule/xegpu/matmul_schedule.py index 47be998..77ec013 100644 --- a/lighthouse/schedule/xegpu/matmul_schedule.py +++ b/lighthouse/schedule/xegpu/matmul_schedule.py @@ 
-139,10 +139,7 @@ def bundle_xepu_mlp_schedule( dpas_shape_c = [dpas_tile[0], dpas_tile[1]] # wg tiling - - if has_relu: - terminal_ops = match_and_split(mod, ops={"linalg.max"}, nhandles=nlayers) - elif has_convert_c: + if has_convert_c: trunc_op = match(mod, ops={"arith.truncf"}) terminal = transform.get_parent_op(anytype, trunc_op) # split handle for each layer @@ -153,6 +150,11 @@ def bundle_xepu_mlp_schedule( terminal_ops = match_and_split(mod, ops={"linalg.add"}, nhandles=nlayers) else: terminal_ops = match_and_split(mod, ops={"linalg.matmul"}, nhandles=nlayers) + if has_relu and nlayers > 1: + # intermediate layers have relu activation function + relu_ops = match_and_split(mod, ops={"linalg.max"}, nhandles=nlayers - 1) + # the final layer does not have relu + terminal_ops = list(relu_ops) + [terminal_ops[-1]] # tile each layer separately for i_layer in range(nlayers):