Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/xegpu_matmul/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Set `LLVM_INSTALL_DIR` and use the below script to checkout and compile LLVM loc

```bash
export LLVM_INSTALL_DIR=<...>
export LLVM_VERSION=83765f435d1c
export LLVM_VERSION=45bee6efe9d6

git clone https://github.com/llvm/llvm-project.git
cd llvm-project
Expand All @@ -34,7 +34,6 @@ cmake ../llvm -G Ninja \
-DLLVM_BUILD_EXAMPLES=OFF \
-DLLVM_TARGETS_TO_BUILD="host" \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON \
-DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="SPIRV" \
-DLLVM_INSTALL_GTEST=ON \
-DMLIR_ENABLE_LEVELZERO_RUNNER=1 \
Expand Down
80 changes: 32 additions & 48 deletions examples/xegpu_matmul/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,22 +271,27 @@ def convert_layout(value, input, target):
tile_a,
nb_prefetch=nb_prefetch,
)
xegpu.set_desc_layout(
desc_prefetch_a,
sg_layout=prefetch_layout_a,
sg_data=prefetch_tile_a,
inst_data=prefetch_inst_data,
)
layout_prefetch_a = {
"sg_layout": prefetch_layout_a,
"sg_data": prefetch_tile_a,
"inst_data": prefetch_inst_data,
}
pf_ops = transform.get_consumers_of_result(anytype, desc_prefetch_a, 0)
for pf in transform.split_handle((anytype,) * (nb_prefetch + 1), pf_ops):
xegpu.set_op_layout_attr(pf, **layout_prefetch_a)

desc_prefetch_b = xegpu.insert_prefetch(
tile_b,
nb_prefetch=nb_prefetch,
)
xegpu.set_desc_layout(
desc_prefetch_b,
sg_layout=prefetch_layout_b,
sg_data=prefetch_tile_b,
inst_data=prefetch_inst_data,
)
layout_prefetch_b = {
"sg_layout": prefetch_layout_b,
"sg_data": prefetch_tile_b,
"inst_data": prefetch_inst_data,
}
pf_ops = transform.get_consumers_of_result(anytype, desc_prefetch_b, 0)
for pf in transform.split_handle((anytype,) * (nb_prefetch + 1), pf_ops):
xegpu.set_op_layout_attr(pf, **layout_prefetch_b)

# A tile load layout
layout_load_a = {
Expand All @@ -295,10 +300,9 @@ def convert_layout(value, input, target):
"inst_data": load_tile_a,
}
desc_op_a = xegpu.get_desc_op(tile_a)
desc_op_a = xegpu.set_desc_layout(
target=desc_op_a,
**layout_load_a,
)
# A tile load op anchor layout
load_op_a = transform.get_consumers_of_result(anytype, desc_op_a, 0)
xegpu.set_op_layout_attr(load_op_a, **layout_load_a)
# A tile dpas layout
layout_dpas_a = layout_load_a.copy()
layout_dpas_a["inst_data"] = dpas_shape_a
Expand All @@ -311,10 +315,9 @@ def convert_layout(value, input, target):
"inst_data": load_tile_b,
}
desc_op_b = xegpu.get_desc_op(tile_b)
desc_op_b = xegpu.set_desc_layout(
target=desc_op_b,
**layout_load_b,
)
# B tile load op anchor layout
load_op_b = transform.get_consumers_of_result(anytype, desc_op_b, 0)
xegpu.set_op_layout_attr(load_op_b, **layout_load_b)
# B tile dpas layout
layout_dpas_b = layout_load_b.copy()
layout_dpas_b["inst_data"] = dpas_shape_b
Expand All @@ -327,42 +330,23 @@ def convert_layout(value, input, target):
"inst_data": dpas_shape_c,
}
desc_op_c = xegpu.get_desc_op(tile_c)
desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout)
# C tile dpas layout
xegpu.set_op_layout_attr(dpas_op, result=True, index=0, **output_layout)
# C tile load/store op anchor layout
desc_c_users = transform.get_consumers_of_result(anytype, desc_op_c, 0)
load_op_c, store_op_c = transform.split_handle((anytype, anytype), desc_c_users)
xegpu.set_op_layout_attr(load_op_c, **output_layout)
# dpas op anchor layouts: A (index 0), B (index 1), and C (index 2)
xegpu.set_op_layout_attr(dpas_op, index=0, **layout_dpas_a)
xegpu.set_op_layout_attr(dpas_op, index=1, **layout_dpas_b)
xegpu.set_op_layout_attr(dpas_op, index=2, **output_layout)

if has_relu:
# for post ops we need to add C layout manually
max_op = match(gpu_func, ops={"arith.maximumf"})
xegpu.set_op_layout_attr(max_op, result=True, index=0, **output_layout)
# find zero constant buffer and annotate it
const_buffer = transform.get_producer_of_operand(anytype, max_op, 1)
xegpu.set_op_layout_attr(const_buffer, result=True, index=0, **output_layout)
if has_bias:
# for post ops we need to add C layout manually
# annotate the 1d load of the broadcast op with a slice layout
add_op = match(gpu_func, ops={"arith.addf"})
xegpu.set_op_layout_attr(add_op, result=True, index=0, **output_layout)

# annotate broadcast op operands
bcast_op = transform.get_producer_of_operand(anytype, add_op, 0)
xegpu.set_op_layout_attr(bcast_op, result=True, index=0, **output_layout)
bcast_load = transform.get_producer_of_operand(anytype, bcast_op, 0)
xegpu.set_op_layout_attr(
bcast_load, result=True, index=0, **output_layout, slice_dims=[0]
)
output_layout_dim1 = {
"sg_layout": [sg_layout[1]],
"sg_data": [sg_tile[1]],
"inst_data": [dpas_shape_c[1]],
}
offset = transform.get_producer_of_operand(anytype, bcast_load, 1)
xegpu.set_op_layout_attr(offset, result=True, index=0, **output_layout_dim1)
aux1 = transform.get_producer_of_operand(anytype, offset, 0)
xegpu.set_op_layout_attr(aux1, result=True, index=0, **output_layout_dim1)
aux2 = transform.get_producer_of_operand(anytype, offset, 1)
xegpu.set_op_layout_attr(aux2, result=True, index=0, **output_layout_dim1)
mask = transform.get_producer_of_operand(anytype, bcast_load, 2)
xegpu.set_op_layout_attr(mask, result=True, index=0, **output_layout_dim1)
raise NotImplementedError("Bias layout propagation is not supported.")
transform.apply_cse(gpu_func)
canonicalize(gpu_func)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "lighthouse"
dynamic = ["version"]
requires-python = ">=3.10,<3.13" # Bounds are due to torch-mlir's packaging
dependencies = [
"mlir-python-bindings==20260211+f932646bf"
"mlir-python-bindings==20260215+45bee6efe"
]

[dependency-groups]
Expand Down