diff --git a/examples/xegpu_matmul/README.md b/examples/xegpu_matmul/README.md index 9de0271..3acf9ad 100644 --- a/examples/xegpu_matmul/README.md +++ b/examples/xegpu_matmul/README.md @@ -20,7 +20,7 @@ Set `LLVM_INSTALL_DIR` and use the below script to checkout and compile LLVM loc ```bash export LLVM_INSTALL_DIR=<...> -export LLVM_VERSION=83765f435d1c +export LLVM_VERSION=45bee6efe9d6 git clone https://github.com/llvm/llvm-project.git cd llvm-project @@ -34,7 +34,6 @@ cmake ../llvm -G Ninja \ -DLLVM_BUILD_EXAMPLES=OFF \ -DLLVM_TARGETS_TO_BUILD="host" \ -DLLVM_ENABLE_ASSERTIONS=ON \ - -DLLVM_ENABLE_RTTI=ON \ -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="SPIRV" \ -DLLVM_INSTALL_GTEST=ON \ -DMLIR_ENABLE_LEVELZERO_RUNNER=1 \ diff --git a/examples/xegpu_matmul/schedule.py b/examples/xegpu_matmul/schedule.py index 1dac3d9..b5827be 100644 --- a/examples/xegpu_matmul/schedule.py +++ b/examples/xegpu_matmul/schedule.py @@ -271,22 +271,27 @@ def convert_layout(value, input, target): tile_a, nb_prefetch=nb_prefetch, ) - xegpu.set_desc_layout( - desc_prefetch_a, - sg_layout=prefetch_layout_a, - sg_data=prefetch_tile_a, - inst_data=prefetch_inst_data, - ) + layout_prefetch_a = { + "sg_layout": prefetch_layout_a, + "sg_data": prefetch_tile_a, + "inst_data": prefetch_inst_data, + } + pf_ops = transform.get_consumers_of_result(anytype, desc_prefetch_a, 0) + for pf in transform.split_handle((anytype,) * (nb_prefetch + 1), pf_ops): + xegpu.set_op_layout_attr(pf, **layout_prefetch_a) + desc_prefetch_b = xegpu.insert_prefetch( tile_b, nb_prefetch=nb_prefetch, ) - xegpu.set_desc_layout( - desc_prefetch_b, - sg_layout=prefetch_layout_b, - sg_data=prefetch_tile_b, - inst_data=prefetch_inst_data, - ) + layout_prefetch_b = { + "sg_layout": prefetch_layout_b, + "sg_data": prefetch_tile_b, + "inst_data": prefetch_inst_data, + } + pf_ops = transform.get_consumers_of_result(anytype, desc_prefetch_b, 0) + for pf in transform.split_handle((anytype,) * (nb_prefetch + 1), pf_ops): + xegpu.set_op_layout_attr(pf, **layout_prefetch_b) # A tile load layout layout_load_a = { @@ -295,10 +300,9 @@ def convert_layout(value, input, target): "inst_data": load_tile_a, } desc_op_a = xegpu.get_desc_op(tile_a) - desc_op_a = xegpu.set_desc_layout( - target=desc_op_a, - **layout_load_a, - ) + # A tile load op anchor layout + load_op_a = transform.get_consumers_of_result(anytype, desc_op_a, 0) + xegpu.set_op_layout_attr(load_op_a, **layout_load_a) # A tile dpas layout layout_dpas_a = layout_load_a.copy() layout_dpas_a["inst_data"] = dpas_shape_a @@ -311,10 +315,9 @@ def convert_layout(value, input, target): "inst_data": load_tile_b, } desc_op_b = xegpu.get_desc_op(tile_b) - desc_op_b = xegpu.set_desc_layout( - target=desc_op_b, - **layout_load_b, - ) + # B tile load op anchor layout + load_op_b = transform.get_consumers_of_result(anytype, desc_op_b, 0) + xegpu.set_op_layout_attr(load_op_b, **layout_load_b) # B tile dpas layout layout_dpas_b = layout_load_b.copy() layout_dpas_b["inst_data"] = dpas_shape_b @@ -327,42 +330,23 @@ def convert_layout(value, input, target): "inst_data": dpas_shape_c, } desc_op_c = xegpu.get_desc_op(tile_c) - desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout) - # C tile dpas layout - xegpu.set_op_layout_attr(dpas_op, result=True, index=0, **output_layout) + # C tile load/store op anchor layout + desc_c_users = transform.get_consumers_of_result(anytype, desc_op_c, 0) + load_op_c, store_op_c = transform.split_handle((anytype, anytype), desc_c_users) + xegpu.set_op_layout_attr(load_op_c, **output_layout) + # C tile dpas anchor layout + xegpu.set_op_layout_attr(dpas_op, index=0, **layout_dpas_a) + xegpu.set_op_layout_attr(dpas_op, index=1, **layout_dpas_b) + xegpu.set_op_layout_attr(dpas_op, index=2, **output_layout) - if has_relu: - # for post ops we need to add C layout manually - max_op = match(gpu_func, ops={"arith.maximumf"}) - xegpu.set_op_layout_attr(max_op, result=True, index=0, **output_layout) - # find zero constant buffer and annotate it - const_buffer = transform.get_producer_of_operand(anytype, max_op, 1) - xegpu.set_op_layout_attr(const_buffer, result=True, index=0, **output_layout) if has_bias: - # for post ops we need to add C layout manually + # annotate the 1d load of the broadcast op with a slice layout add_op = match(gpu_func, ops={"arith.addf"}) - xegpu.set_op_layout_attr(add_op, result=True, index=0, **output_layout) - - # annotate broadcast op operands bcast_op = transform.get_producer_of_operand(anytype, add_op, 0) - xegpu.set_op_layout_attr(bcast_op, result=True, index=0, **output_layout) bcast_load = transform.get_producer_of_operand(anytype, bcast_op, 0) xegpu.set_op_layout_attr( bcast_load, result=True, index=0, **output_layout, slice_dims=[0] ) - output_layout_dim1 = { - "sg_layout": [sg_layout[1]], - "sg_data": [sg_tile[1]], - "inst_data": [dpas_shape_c[1]], - } - offset = transform.get_producer_of_operand(anytype, bcast_load, 1) - xegpu.set_op_layout_attr(offset, result=True, index=0, **output_layout_dim1) - aux1 = transform.get_producer_of_operand(anytype, offset, 0) - xegpu.set_op_layout_attr(aux1, result=True, index=0, **output_layout_dim1) - aux2 = transform.get_producer_of_operand(anytype, offset, 1) - xegpu.set_op_layout_attr(aux2, result=True, index=0, **output_layout_dim1) - mask = transform.get_producer_of_operand(anytype, bcast_load, 2) - xegpu.set_op_layout_attr(mask, result=True, index=0, **output_layout_dim1) raise NotImplementedError("Bias layout propagation is not supported.") transform.apply_cse(gpu_func) canonicalize(gpu_func) diff --git a/pyproject.toml b/pyproject.toml index d81e33a..9814bee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "lighthouse" dynamic = ["version"] requires-python = ">=3.10,<3.13" # Bounds are due to torch-mlir's packaging dependencies = [ - "mlir-python-bindings==20260211+f932646bf" + "mlir-python-bindings==20260215+45bee6efe" ] [dependency-groups]