From 34849c6aa4b89847768e964f92e6dbf510154347 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 24 Feb 2026 04:15:03 -0800 Subject: [PATCH 1/3] simplifying ingress sharding and using sharding propagation --- examples/mlp-mpi/mlp-mpi.py | 54 +++++++++++++-------- examples/mlp-mpi/mlp_weight_stationary.mlir | 48 +++++++----------- 2 files changed, 52 insertions(+), 50 deletions(-) diff --git a/examples/mlp-mpi/mlp-mpi.py b/examples/mlp-mpi/mlp-mpi.py index 73aed65..773519c 100644 --- a/examples/mlp-mpi/mlp-mpi.py +++ b/examples/mlp-mpi/mlp-mpi.py @@ -146,17 +146,26 @@ def allocate_inputs(self, execution_engine: ExecutionEngine): ): yield self.input_memrefs + def _gather( + self, + memref: ctypes.Structure, + execution_engine: ExecutionEngine, + gather_func: str, + ) -> np.ndarray: + gathered_memref = make_nd_memref_descriptor(2, as_ctype(self.dtype))() + execution_engine.invoke( + gather_func, + memref_to_ctype(gathered_memref), + memref_to_ctype(memref), + ) + return gathered_memref + def _reference_solution(self, execution_engine: ExecutionEngine) -> np.ndarray: rprint(" * Gathering input data...") - gathered = [] - for i, v in enumerate(["act", "win", "wout"]): - memref = make_nd_memref_descriptor(2, as_ctype(self.dtype))() - execution_engine.invoke( - f"gather_{v}", - memref_to_ctype(memref), - memref_to_ctype(self.input_memrefs[i + 1]), - ) - gathered.append(memref) + gathered = [ + self._gather(self.input_memrefs[i + 1], execution_engine, f"gather_{v}") + for i, v in enumerate(["act", "win", "wout"]) + ] rprint(" * Computing reference solution...") @@ -171,7 +180,9 @@ def sigmoid(z): def check_correctness( self, execution_engine: ExecutionEngine, verbose: int = 0 ) -> bool: - R = ranked_memref_to_numpy([self.input_memrefs[0]]) + R = ranked_memref_to_numpy( + [self._gather(self.input_memrefs[0], execution_engine, "gather_act")] + ) R_ref = self._reference_solution(execution_engine) if verbose > 1: rprint("Reference solution:") @@ -238,13 +249,9 @@ def find_factors(n): "split_act": "[[], [0]]", "split_win": "[[], [0]]", "split_wout": "[[0], []]", - "split_mm0_a": "[[]]", - "split_mm0_b": "[[], [0]]", + "split_mm0a_mm1c": "[[]]", "split_mm0_c": "[[], [0]]", "split_sigmoid": "[[], [0]]", - "split_mm1_a": "[[], [0]]", - "split_mm1_b": "[[0], []]", - "split_mm1_c": "[[]]", } ) else: @@ -253,13 +260,9 @@ def find_factors(n): "split_act": "[[], [0, 1]]", "split_win": "[[0], [1]]", "split_wout": "[[1], [0]]", - "split_mm0_a": "[[], [0]]", - "split_mm0_b": "[[0], [1]]", + "split_mm0a_mm1c": "[[], [0]]", "split_mm0_c": "[[], [1]]", "split_sigmoid": "[[], [1, 0]]", - "split_mm1_a": "[[], [1]]", - "split_mm1_b": "[[1], [0]]", - "split_mm1_c": "[[], [0]]", } ) txt = txt.format_map(format_values) @@ -290,10 +293,21 @@ def schedule_module( with ir.InsertionPoint(named_sequence.body): anytype = transform.AnyOpType.get() func = match(named_sequence.bodyTarget, ops={"func.func"}) + func = apply_registered_pass( + func, + "sharding-propagation", + options={"traversal": "forward-backward"}, + ) + if self.verbose > 0: + transform.PrintOp(target=func) func = apply_registered_pass(func, "shard-partition") func = apply_registered_pass(func, "canonicalize") + if self.verbose > 0: + transform.PrintOp(target=func) func = apply_registered_pass(func, "convert-shard-to-mpi") func = apply_registered_pass(func, "canonicalize") + if self.verbose > 0: + transform.PrintOp(target=func) func = apply_registered_pass(func, "tosa-to-linalg") mod = transform.get_parent_op( anytype, func, op_name="builtin.module", deduplicate=True diff --git a/examples/mlp-mpi/mlp_weight_stationary.mlir b/examples/mlp-mpi/mlp_weight_stationary.mlir index 887bf4e..a228cb1 100644 --- a/examples/mlp-mpi/mlp_weight_stationary.mlir +++ b/examples/mlp-mpi/mlp_weight_stationary.mlir @@ -1,20 +1,17 @@ module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:comm_world_size" = {P}, "MPI:comm_world_rank" = {R}> }} {{ shard.grid @grid0(shape = {grid}) {{sym_visibility = "private"}} func.func @{func_name}(%arg0: tensor<{M}x{K}xf32>, %arg1: tensor<{K}x{N}xf32>, %arg2: tensor<{N}x{K}xf32>) -> tensor<{M}x{K}xf32> attributes {{llvm.emit_c_interface}} {{ + %cst = arith.constant 0.000000e+00 : f32 + %sharding_arg0 = shard.sharding @grid0 split_axes = {split_act} : !shard.sharding %sharding_arg1 = shard.sharding @grid0 split_axes = {split_win} : !shard.sharding %sharding_arg2 = shard.sharding @grid0 split_axes = {split_wout} : !shard.sharding - %sharding_mm0_a = shard.sharding @grid0 split_axes = {split_mm0_a} : !shard.sharding - %sharding_mm0_b = shard.sharding @grid0 split_axes = {split_mm0_b} : !shard.sharding + %sharding_mm0a_mm1c = shard.sharding @grid0 split_axes = {split_mm0a_mm1c} : !shard.sharding %sharding_mm0_c = shard.sharding @grid0 split_axes = {split_mm0_c} : !shard.sharding %sharding_sigmoid = shard.sharding @grid0 split_axes = {split_sigmoid} : !shard.sharding - %sharding_mm1_a = shard.sharding @grid0 split_axes = {split_mm1_a} : !shard.sharding - %sharding_mm1_b = shard.sharding @grid0 split_axes = {split_mm1_b} : !shard.sharding - %sharding_mm1_c = shard.sharding @grid0 split_axes = {split_mm1_c} : !shard.sharding - %sharding_r = shard.sharding @grid0 split_axes = {split_r} : !shard.sharding %sharded = shard.shard %arg0 to %sharding_arg0 : tensor<{M}x{K}xf32> @@ -22,39 +19,25 @@ module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:co %sharded_7 = shard.shard %arg2 to %sharding_arg2 : tensor<{N}x{K}xf32> %0 = tensor.empty() : tensor<{M}x{N}xf32> - %sharded_8 = shard.shard %0 to %sharding_mm0_c : tensor<{M}x{N}xf32> - - %cst = arith.constant 0.000000e+00 : f32 - %sharded_9 = shard.shard %sharded_8 to %sharding_mm0_c annotate_for_users : tensor<{M}x{N}xf32> - %1 = linalg.fill ins(%cst : f32) outs(%sharded_9 : tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32> - %sharded_10 = shard.shard %1 to %sharding_mm0_c : tensor<{M}x{N}xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32> - %sharded_11 = shard.shard %sharded to %sharding_mm0_a annotate_for_users : tensor<{M}x{K}xf32> - %sharded_12 = shard.shard %sharded_6 to %sharding_mm0_b annotate_for_users : tensor<{K}x{N}xf32> - %sharded_13 = shard.shard %sharded_10 to %sharding_mm0_c annotate_for_users : tensor<{M}x{N}xf32> - %2 = linalg.matmul ins(%sharded_11, %sharded_12 : tensor<{M}x{K}xf32>, tensor<{K}x{N}xf32>) outs(%sharded_13 : tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32> - %sharded_14 = shard.shard %2 to %sharding_mm0_c : tensor<{M}x{N}xf32> + %sharded_11 = shard.shard %sharded to %sharding_mm0a_mm1c annotate_for_users : tensor<{M}x{K}xf32> + %sharded_13 = shard.shard %1 to %sharding_mm0_c annotate_for_users : tensor<{M}x{N}xf32> + %2 = linalg.matmul ins(%sharded_11, %sharded_6 : tensor<{M}x{K}xf32>, tensor<{K}x{N}xf32>) outs(%sharded_13 : tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32> - %sharded_15 = shard.shard %sharded_14 to %sharding_sigmoid annotate_for_users : tensor<{M}x{N}xf32> + %sharded_15 = shard.shard %2 to %sharding_sigmoid annotate_for_users : tensor<{M}x{N}xf32> %3 = tosa.sigmoid %sharded_15 : (tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32> - %sharded_16 = shard.shard %3 to %sharding_sigmoid : tensor<{M}x{N}xf32> %4 = tensor.empty() : tensor<{M}x{K}xf32> - %sharded_17 = shard.shard %4 to %sharding_mm1_c : tensor<{M}x{K}xf32> - - %sharded_18 = shard.shard %sharded_17 to %sharding_mm1_c annotate_for_users : tensor<{M}x{K}xf32> - %5 = linalg.fill ins(%cst : f32) outs(%sharded_18 : tensor<{M}x{K}xf32>) -> tensor<{M}x{K}xf32> - %sharded_19 = shard.shard %5 to %sharding_mm1_c : tensor<{M}x{K}xf32> + %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<{M}x{K}xf32>) -> tensor<{M}x{K}xf32> - %sharded_20 = shard.shard %sharded_16 to %sharding_mm1_a annotate_for_users : tensor<{M}x{N}xf32> - %sharded_21 = shard.shard %sharded_7 to %sharding_mm1_b annotate_for_users : tensor<{N}x{K}xf32> - %sharded_22 = shard.shard %sharded_19 to %sharding_mm1_c annotate_for_users : tensor<{M}x{K}xf32> - %6 = linalg.matmul ins(%sharded_20, %sharded_21 : tensor<{M}x{N}xf32>, tensor<{N}x{K}xf32>) outs(%sharded_22 : tensor<{M}x{K}xf32>) -> tensor<{M}x{K}xf32> - %sharded_23 = shard.shard %6 to %sharding_mm1_c : tensor<{M}x{K}xf32> + %sharded_22 = shard.shard %5 to %sharding_mm0a_mm1c annotate_for_users : tensor<{M}x{K}xf32> + %6 = linalg.matmul ins(%3, %sharded_7 : tensor<{M}x{N}xf32>, tensor<{N}x{K}xf32>) outs(%sharded_22 : tensor<{M}x{K}xf32>) -> tensor<{M}x{K}xf32> - %sharded_24 = shard.shard %sharded_23 to %sharding_r annotate_for_users : tensor<{M}x{K}xf32> + %sharded_24 = shard.shard %6 to %sharding_arg0 annotate_for_users : tensor<{M}x{K}xf32> return %sharded_24 : tensor<{M}x{K}xf32> }} + func.func @alloc_act() -> (tensor<{M}x{K}xf32>) attributes {{llvm.emit_c_interface}} {{ %a = tensor.empty() : tensor<{M}x{K}xf32> %sharding_act = shard.sharding @grid0 split_axes = {split_act} : !shard.sharding @@ -62,6 +45,7 @@ module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:co %ret_a = shard.shard %sharded_act to %sharding_act annotate_for_users : tensor<{M}x{K}xf32> return %ret_a : tensor<{M}x{K}xf32> }} + func.func @alloc_win() -> (tensor<{K}x{N}xf32>) attributes {{llvm.emit_c_interface}} {{ %b = tensor.empty() : tensor<{K}x{N}xf32> %sharding_win = shard.sharding @grid0 split_axes = {split_win} : !shard.sharding @@ -69,6 +53,7 @@ module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:co %ret_win = shard.shard %sharded_win to %sharding_win annotate_for_users : tensor<{K}x{N}xf32> return %ret_win : tensor<{K}x{N}xf32> }} + func.func @alloc_wout() -> (tensor<{N}x{K}xf32>) attributes {{llvm.emit_c_interface}} {{ %c = tensor.empty() : tensor<{N}x{K}xf32> %sharding_wout = shard.sharding @grid0 split_axes = {split_wout} : !shard.sharding @@ -76,6 +61,7 @@ module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:co %ret_wout = shard.shard %sharded_wout to %sharding_wout annotate_for_users : tensor<{N}x{K}xf32> return %ret_wout : tensor<{N}x{K}xf32> }} + func.func @alloc_r() -> (tensor<{M}x{K}xf32>) attributes {{llvm.emit_c_interface}} {{ %a = tensor.empty() : tensor<{M}x{K}xf32> %sharding_r = shard.sharding @grid0 split_axes = {split_r} : !shard.sharding @@ -96,6 +82,7 @@ module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:co %sharded_g = shard.shard %sharded to %sharding_g annotate_for_users : tensor<{M}x{K}xf32> return %sharded_g : tensor<{M}x{K}xf32> }} + func.func @gather_win(%arg0: tensor<{K}x{N}xf32>) -> tensor<{K}x{N}xf32> attributes {{llvm.emit_c_interface}} {{ %sharding = shard.sharding @grid0 split_axes = {split_win} : !shard.sharding %sharding_g = shard.sharding @grid0 split_axes = [[]] : !shard.sharding @@ -103,6 +90,7 @@ module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:co %sharded_g = shard.shard %sharded to %sharding_g annotate_for_users : tensor<{K}x{N}xf32> return %sharded_g : tensor<{K}x{N}xf32> }} + func.func @gather_wout(%arg0: tensor<{N}x{K}xf32>) -> tensor<{N}x{K}xf32> attributes {{llvm.emit_c_interface}} {{ %sharding = shard.sharding @grid0 split_axes = {split_wout} : !shard.sharding %sharding_g = shard.sharding @grid0 split_axes = [[]] : !shard.sharding From 904db5a6a0071d260adc7a36597c998ef44c244a Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 24 Feb 2026 04:27:05 -0800 Subject: [PATCH 2/3] removing unused sharding --- examples/mlp-mpi/mlp_weight_stationary.mlir | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/mlp-mpi/mlp_weight_stationary.mlir b/examples/mlp-mpi/mlp_weight_stationary.mlir index a228cb1..54aacd8 100644 --- a/examples/mlp-mpi/mlp_weight_stationary.mlir +++ b/examples/mlp-mpi/mlp_weight_stationary.mlir @@ -12,8 +12,6 @@ module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:co %sharding_sigmoid = shard.sharding @grid0 split_axes = {split_sigmoid} : !shard.sharding - %sharding_r = shard.sharding @grid0 split_axes = {split_r} : !shard.sharding - %sharded = shard.shard %arg0 to %sharding_arg0 : tensor<{M}x{K}xf32> %sharded_6 = shard.shard %arg1 to %sharding_arg1 : tensor<{K}x{N}xf32> %sharded_7 = shard.shard %arg2 to %sharding_arg2 : tensor<{N}x{K}xf32> From eefe269d80c5d763ce2ca60e3e9dbd96bc6a3608 Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Tue, 24 Feb 2026 04:34:47 -0800 Subject: [PATCH 3/3] GC --- examples/mlp-mpi/mlp-mpi.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/mlp-mpi/mlp-mpi.py b/examples/mlp-mpi/mlp-mpi.py index 773519c..6ce776f 100644 --- a/examples/mlp-mpi/mlp-mpi.py +++ b/examples/mlp-mpi/mlp-mpi.py @@ -151,7 +151,7 @@ def _gather( memref: ctypes.Structure, execution_engine: ExecutionEngine, gather_func: str, - ) -> np.ndarray: + ) -> ctypes.Structure: gathered_memref = make_nd_memref_descriptor(2, as_ctype(self.dtype))() execution_engine.invoke( gather_func, @@ -180,16 +180,16 @@ def sigmoid(z): def check_correctness( self, execution_engine: ExecutionEngine, verbose: int = 0 ) -> bool: - R = ranked_memref_to_numpy( - [self._gather(self.input_memrefs[0], execution_engine, "gather_act")] - ) - R_ref = self._reference_solution(execution_engine) - if verbose > 1: - rprint("Reference solution:") - rprint(R_ref) - rprint("Computed solution:") - rprint(R) - success = np.allclose(R, R_ref) + gathered = self._gather(self.input_memrefs[0], execution_engine, "gather_act") + with deallocate_memrefs_on_exit([gathered], execution_engine, "dealloc_2d"): + R = ranked_memref_to_numpy([gathered]) + R_ref = self._reference_solution(execution_engine) + if verbose > 1: + rprint("Reference solution:") + rprint(R_ref) + rprint("Computed solution:") + rprint(R) + success = np.allclose(R, R_ref) success = MPI.COMM_WORLD.allreduce(success, op=MPI.LAND) if success: rprint("PASSED")