Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/xegpu_matmul/lit.local.cfg

This file was deleted.

63 changes: 37 additions & 26 deletions examples/xegpu_matmul/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,16 @@
import numpy as np
from mlir import ir
from mlir.runtime.np_to_memref import (
get_ranked_memref_descriptor,
make_nd_memref_descriptor,
as_ctype,
)
from mlir.execution_engine import ExecutionEngine

from lighthouse.workload import Workload, benchmark
from lighthouse.utils.memref import get_packed_arg, to_ctype as memref_to_ctype

# Import from sibling files:
from schedule import get_schedule_module
from payload import generate_matmul_payload


def numpy_to_ctype(arr: np.ndarray) -> ctypes._Pointer:
    """Return a ctypes ``void**`` pointer to a memref view of ``arr``.

    The array is first wrapped in a ranked memref descriptor, and that
    descriptor is then converted to its ctypes pointer form.
    """
    descriptor = get_ranked_memref_descriptor(arr)
    return memref_to_ctype(descriptor)
from lighthouse.utils.numpy import numpy_to_ctype
from lighthouse.schedule.xegpu.matmul_schedule import get_schedule_module
from lighthouse.ingress.gpu import generate_matmul_payload


class XeGPUMatMul(Workload):
Expand All @@ -54,6 +47,7 @@ def __init__(
c_type: str = "f32",
has_bias: bool = False,
has_relu: bool = False,
accumulate_c: bool = True,
):
self.M = M
self.N = N
Expand All @@ -73,6 +67,7 @@ def __init__(
self.c_dtype = type_str_to_numpy[c_type]
self.has_bias = has_bias
self.has_relu = has_relu
self.accumulate_c = accumulate_c
if has_bias:
raise NotImplementedError("Bias is not implemented yet")
# cache allocated memrefs
Expand Down Expand Up @@ -136,7 +131,9 @@ def _reference_solution(self) -> np.ndarray:
A, B, C = self._initial_host_arrays
# use float32 data type for efficiency
f32 = np.float32
C_ref = A.astype(f32) @ B.astype(f32) + C.astype(f32)
C_ref = A.astype(f32) @ B.astype(f32)
if self.accumulate_c:
C_ref += C.astype(f32)
if self.has_relu:
C_ref = np.maximum(C_ref, 0)
if self.has_bias:
Expand Down Expand Up @@ -196,6 +193,10 @@ def get_complexity(self) -> tuple[int, int, int]:
nbytes_ab = np.dtype(self.ab_dtype).itemsize
nbytes_c = np.dtype(self.c_dtype).itemsize
memory_reads = (M * K + K * N) * nbytes_ab # read A and B
if self.accumulate_c:
memory_reads += M * N * nbytes_c # read C for accumulation
if self.has_bias:
memory_reads += N * nbytes_c # read bias
memory_writes = M * N * nbytes_c # write C
return (flop_count, memory_reads, memory_writes)

Expand All @@ -209,6 +210,7 @@ def payload_module(self) -> ir.Module:
c_type_str=self.c_type,
has_bias=self.has_bias,
has_relu=self.has_relu,
accumulate_c=self.accumulate_c,
)
return mod

Expand All @@ -218,8 +220,11 @@ def schedule_module(
return get_schedule_module(
has_bias=self.has_bias,
has_relu=self.has_relu,
has_convert_c=False,
accumulate_c=self.accumulate_c,
stop_at_stage=stop_at_stage,
params=parameters,
nlayers=1,
params={"layer_0": parameters},
)

def shared_libs(self) -> list[str]:
Expand Down Expand Up @@ -309,6 +314,11 @@ def parse_cli():
action="store_true",
help="Add relu op after the matrix multiplication (and bias if any).",
)
parser.add_argument(
"--no-accumulate-c",
action="store_true",
help="Compute plain matrix-multiply C=A*B instead of matrix-multiply-accumulate C+=A*B.",
)
parser.add_argument(
"--check-result",
action="store_true",
Expand Down Expand Up @@ -342,20 +352,20 @@ def parse_cli():
args = parse_cli()

params = {
"auto_wg_d0": args.wg_tile[0],
"auto_wg_d1": args.wg_tile[1],
"auto_sg_d0": args.sg_tile[0],
"auto_sg_d1": args.sg_tile[1],
"auto_k": args.k_tile,
"auto_load_a_d0": args.load_tile_a[0],
"auto_load_a_d1": args.load_tile_a[1],
"auto_load_b_d0": args.load_tile_b[0],
"auto_load_b_d1": args.load_tile_b[1],
"auto_prefetch_a_d0": args.prefetch_tile_a[0],
"auto_prefetch_a_d1": args.prefetch_tile_a[1],
"auto_prefetch_b_d0": args.prefetch_tile_b[0],
"auto_prefetch_b_d1": args.prefetch_tile_b[1],
"auto_nb_prefetch": args.nb_prefetch,
"wg_m": args.wg_tile[0],
"wg_n": args.wg_tile[1],
"sg_m": args.sg_tile[0],
"sg_n": args.sg_tile[1],
"k": args.k_tile,
"load_a_m": args.load_tile_a[0],
"load_a_k": args.load_tile_a[1],
"load_b_k": args.load_tile_b[0],
"load_b_n": args.load_tile_b[1],
"pf_a_m": args.prefetch_tile_a[0],
"pf_a_k": args.prefetch_tile_a[1],
"pf_b_k": args.prefetch_tile_b[0],
"pf_b_n": args.prefetch_tile_b[1],
"pf_nb": args.nb_prefetch,
}

M, N, K = args.sizes
Expand All @@ -371,6 +381,7 @@ def parse_cli():
c_type=c_type,
has_bias=False,
has_relu=args.relu,
accumulate_c=not args.no_accumulate_c,
)

if args.dump_kernel or args.dump_schedule:
Expand Down
124 changes: 0 additions & 124 deletions examples/xegpu_matmul/payload.py

This file was deleted.

Loading