Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 41 additions & 27 deletions examples/mlp-mpi/mlp-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,26 @@ def allocate_inputs(self, execution_engine: ExecutionEngine):
):
yield self.input_memrefs

def _gather(
    self,
    memref: ctypes.Structure,
    execution_engine: ExecutionEngine,
    gather_func: str,
) -> ctypes.Structure:
    """Collect a distributed memref into one freshly allocated descriptor.

    Invokes the compiled ``gather_func`` entry point on *execution_engine*
    with a new rank-2 memref descriptor (element type taken from
    ``self.dtype``) as the output argument and *memref* as the input.

    Returns the newly allocated, gathered memref descriptor.
    """
    # Output descriptor the compiled gather function will fill in.
    result = make_nd_memref_descriptor(2, as_ctype(self.dtype))()
    call_args = (memref_to_ctype(result), memref_to_ctype(memref))
    execution_engine.invoke(gather_func, *call_args)
    return result

def _reference_solution(self, execution_engine: ExecutionEngine) -> np.ndarray:
rprint(" * Gathering input data...")
gathered = []
for i, v in enumerate(["act", "win", "wout"]):
memref = make_nd_memref_descriptor(2, as_ctype(self.dtype))()
execution_engine.invoke(
f"gather_{v}",
memref_to_ctype(memref),
memref_to_ctype(self.input_memrefs[i + 1]),
)
gathered.append(memref)
gathered = [
self._gather(self.input_memrefs[i + 1], execution_engine, f"gather_{v}")
for i, v in enumerate(["act", "win", "wout"])
]

rprint(" * Computing reference solution...")

Expand All @@ -171,14 +180,16 @@ def sigmoid(z):
def check_correctness(
self, execution_engine: ExecutionEngine, verbose: int = 0
) -> bool:
R = ranked_memref_to_numpy([self.input_memrefs[0]])
R_ref = self._reference_solution(execution_engine)
if verbose > 1:
rprint("Reference solution:")
rprint(R_ref)
rprint("Computed solution:")
rprint(R)
success = np.allclose(R, R_ref)
gathered = self._gather(self.input_memrefs[0], execution_engine, "gather_act")
with deallocate_memrefs_on_exit([gathered], execution_engine, "dealloc_2d"):
R = ranked_memref_to_numpy([gathered])
R_ref = self._reference_solution(execution_engine)
if verbose > 1:
rprint("Reference solution:")
rprint(R_ref)
rprint("Computed solution:")
rprint(R)
success = np.allclose(R, R_ref)
success = MPI.COMM_WORLD.allreduce(success, op=MPI.LAND)
if success:
rprint("PASSED")
Expand Down Expand Up @@ -238,13 +249,9 @@ def find_factors(n):
"split_act": "[[], [0]]",
"split_win": "[[], [0]]",
"split_wout": "[[0], []]",
"split_mm0_a": "[[]]",
"split_mm0_b": "[[], [0]]",
"split_mm0a_mm1c": "[[]]",
"split_mm0_c": "[[], [0]]",
"split_sigmoid": "[[], [0]]",
"split_mm1_a": "[[], [0]]",
"split_mm1_b": "[[0], []]",
"split_mm1_c": "[[]]",
}
)
else:
Expand All @@ -253,13 +260,9 @@ def find_factors(n):
"split_act": "[[], [0, 1]]",
"split_win": "[[0], [1]]",
"split_wout": "[[1], [0]]",
"split_mm0_a": "[[], [0]]",
"split_mm0_b": "[[0], [1]]",
"split_mm0a_mm1c": "[[], [0]]",
"split_mm0_c": "[[], [1]]",
"split_sigmoid": "[[], [1, 0]]",
"split_mm1_a": "[[], [1]]",
"split_mm1_b": "[[1], [0]]",
"split_mm1_c": "[[], [0]]",
}
)
txt = txt.format_map(format_values)
Expand Down Expand Up @@ -290,10 +293,21 @@ def schedule_module(
with ir.InsertionPoint(named_sequence.body):
anytype = transform.AnyOpType.get()
func = match(named_sequence.bodyTarget, ops={"func.func"})
func = apply_registered_pass(
func,
"sharding-propagation",
options={"traversal": "forward-backward"},
)
if self.verbose > 0:
transform.PrintOp(target=func)
func = apply_registered_pass(func, "shard-partition")
func = apply_registered_pass(func, "canonicalize")
if self.verbose > 0:
transform.PrintOp(target=func)
func = apply_registered_pass(func, "convert-shard-to-mpi")
func = apply_registered_pass(func, "canonicalize")
if self.verbose > 0:
transform.PrintOp(target=func)
func = apply_registered_pass(func, "tosa-to-linalg")
mod = transform.get_parent_op(
anytype, func, op_name="builtin.module", deduplicate=True
Expand Down
50 changes: 18 additions & 32 deletions examples/mlp-mpi/mlp_weight_stationary.mlir
Original file line number Diff line number Diff line change
@@ -1,81 +1,65 @@
module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:comm_world_size" = {P}, "MPI:comm_world_rank" = {R}> }} {{
shard.grid @grid0(shape = {grid}) {{sym_visibility = "private"}}
func.func @{func_name}(%arg0: tensor<{M}x{K}xf32>, %arg1: tensor<{K}x{N}xf32>, %arg2: tensor<{N}x{K}xf32>) -> tensor<{M}x{K}xf32> attributes {{llvm.emit_c_interface}} {{
%cst = arith.constant 0.000000e+00 : f32

%sharding_arg0 = shard.sharding @grid0 split_axes = {split_act} : !shard.sharding
%sharding_arg1 = shard.sharding @grid0 split_axes = {split_win} : !shard.sharding
%sharding_arg2 = shard.sharding @grid0 split_axes = {split_wout} : !shard.sharding

%sharding_mm0_a = shard.sharding @grid0 split_axes = {split_mm0_a} : !shard.sharding
%sharding_mm0_b = shard.sharding @grid0 split_axes = {split_mm0_b} : !shard.sharding
%sharding_mm0a_mm1c = shard.sharding @grid0 split_axes = {split_mm0a_mm1c} : !shard.sharding
%sharding_mm0_c = shard.sharding @grid0 split_axes = {split_mm0_c} : !shard.sharding

%sharding_sigmoid = shard.sharding @grid0 split_axes = {split_sigmoid} : !shard.sharding

%sharding_mm1_a = shard.sharding @grid0 split_axes = {split_mm1_a} : !shard.sharding
%sharding_mm1_b = shard.sharding @grid0 split_axes = {split_mm1_b} : !shard.sharding
%sharding_mm1_c = shard.sharding @grid0 split_axes = {split_mm1_c} : !shard.sharding

%sharding_r = shard.sharding @grid0 split_axes = {split_r} : !shard.sharding

%sharded = shard.shard %arg0 to %sharding_arg0 : tensor<{M}x{K}xf32>
%sharded_6 = shard.shard %arg1 to %sharding_arg1 : tensor<{K}x{N}xf32>
%sharded_7 = shard.shard %arg2 to %sharding_arg2 : tensor<{N}x{K}xf32>

%0 = tensor.empty() : tensor<{M}x{N}xf32>
%sharded_8 = shard.shard %0 to %sharding_mm0_c : tensor<{M}x{N}xf32>

%cst = arith.constant 0.000000e+00 : f32
%sharded_9 = shard.shard %sharded_8 to %sharding_mm0_c annotate_for_users : tensor<{M}x{N}xf32>
%1 = linalg.fill ins(%cst : f32) outs(%sharded_9 : tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32>
%sharded_10 = shard.shard %1 to %sharding_mm0_c : tensor<{M}x{N}xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32>

%sharded_11 = shard.shard %sharded to %sharding_mm0_a annotate_for_users : tensor<{M}x{K}xf32>
%sharded_12 = shard.shard %sharded_6 to %sharding_mm0_b annotate_for_users : tensor<{K}x{N}xf32>
%sharded_13 = shard.shard %sharded_10 to %sharding_mm0_c annotate_for_users : tensor<{M}x{N}xf32>
%2 = linalg.matmul ins(%sharded_11, %sharded_12 : tensor<{M}x{K}xf32>, tensor<{K}x{N}xf32>) outs(%sharded_13 : tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32>
%sharded_14 = shard.shard %2 to %sharding_mm0_c : tensor<{M}x{N}xf32>
%sharded_11 = shard.shard %sharded to %sharding_mm0a_mm1c annotate_for_users : tensor<{M}x{K}xf32>
%sharded_13 = shard.shard %1 to %sharding_mm0_c annotate_for_users : tensor<{M}x{N}xf32>
%2 = linalg.matmul ins(%sharded_11, %sharded_6 : tensor<{M}x{K}xf32>, tensor<{K}x{N}xf32>) outs(%sharded_13 : tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32>

%sharded_15 = shard.shard %sharded_14 to %sharding_sigmoid annotate_for_users : tensor<{M}x{N}xf32>
%sharded_15 = shard.shard %2 to %sharding_sigmoid annotate_for_users : tensor<{M}x{N}xf32>
%3 = tosa.sigmoid %sharded_15 : (tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32>
%sharded_16 = shard.shard %3 to %sharding_sigmoid : tensor<{M}x{N}xf32>

%4 = tensor.empty() : tensor<{M}x{K}xf32>
%sharded_17 = shard.shard %4 to %sharding_mm1_c : tensor<{M}x{K}xf32>
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<{M}x{K}xf32>) -> tensor<{M}x{K}xf32>

%sharded_18 = shard.shard %sharded_17 to %sharding_mm1_c annotate_for_users : tensor<{M}x{K}xf32>
%5 = linalg.fill ins(%cst : f32) outs(%sharded_18 : tensor<{M}x{K}xf32>) -> tensor<{M}x{K}xf32>
%sharded_19 = shard.shard %5 to %sharding_mm1_c : tensor<{M}x{K}xf32>
%sharded_22 = shard.shard %5 to %sharding_mm0a_mm1c annotate_for_users : tensor<{M}x{K}xf32>
%6 = linalg.matmul ins(%3, %sharded_7 : tensor<{M}x{N}xf32>, tensor<{N}x{K}xf32>) outs(%sharded_22 : tensor<{M}x{K}xf32>) -> tensor<{M}x{K}xf32>

%sharded_20 = shard.shard %sharded_16 to %sharding_mm1_a annotate_for_users : tensor<{M}x{N}xf32>
%sharded_21 = shard.shard %sharded_7 to %sharding_mm1_b annotate_for_users : tensor<{N}x{K}xf32>
%sharded_22 = shard.shard %sharded_19 to %sharding_mm1_c annotate_for_users : tensor<{M}x{K}xf32>
%6 = linalg.matmul ins(%sharded_20, %sharded_21 : tensor<{M}x{N}xf32>, tensor<{N}x{K}xf32>) outs(%sharded_22 : tensor<{M}x{K}xf32>) -> tensor<{M}x{K}xf32>
%sharded_23 = shard.shard %6 to %sharding_mm1_c : tensor<{M}x{K}xf32>

%sharded_24 = shard.shard %sharded_23 to %sharding_r annotate_for_users : tensor<{M}x{K}xf32>
%sharded_24 = shard.shard %6 to %sharding_arg0 annotate_for_users : tensor<{M}x{K}xf32>
return %sharded_24 : tensor<{M}x{K}xf32>
}}

// Allocate an uninitialized activation tensor and tag it with the "act"
// sharding (both as produced value and for its users) so the shard passes
// materialize only the local partition on each rank.
func.func @alloc_act() -> (tensor<{M}x{K}xf32>) attributes {{llvm.emit_c_interface}} {{
  %a = tensor.empty() : tensor<{M}x{K}xf32>
  %sharding_act = shard.sharding @grid0 split_axes = {split_act} : !shard.sharding
  %sharded_act = shard.shard %a to %sharding_act : tensor<{M}x{K}xf32>
  // annotate_for_users fixes the sharding seen by consumers of the result.
  %ret_a = shard.shard %sharded_act to %sharding_act annotate_for_users : tensor<{M}x{K}xf32>
  return %ret_a : tensor<{M}x{K}xf32>
}}

// Allocate an uninitialized input-weight tensor and tag it with the "win"
// sharding (as produced value and for its users) so each rank only holds
// its local partition.
func.func @alloc_win() -> (tensor<{K}x{N}xf32>) attributes {{llvm.emit_c_interface}} {{
  %b = tensor.empty() : tensor<{K}x{N}xf32>
  %sharding_win = shard.sharding @grid0 split_axes = {split_win} : !shard.sharding
  %sharded_win = shard.shard %b to %sharding_win : tensor<{K}x{N}xf32>
  // annotate_for_users fixes the sharding seen by consumers of the result.
  %ret_win = shard.shard %sharded_win to %sharding_win annotate_for_users : tensor<{K}x{N}xf32>
  return %ret_win : tensor<{K}x{N}xf32>
}}

// Allocate an uninitialized output-weight tensor and tag it with the "wout"
// sharding (as produced value and for its users) so each rank only holds
// its local partition.
func.func @alloc_wout() -> (tensor<{N}x{K}xf32>) attributes {{llvm.emit_c_interface}} {{
  %c = tensor.empty() : tensor<{N}x{K}xf32>
  %sharding_wout = shard.sharding @grid0 split_axes = {split_wout} : !shard.sharding
  %sharded_wout = shard.shard %c to %sharding_wout : tensor<{N}x{K}xf32>
  // annotate_for_users fixes the sharding seen by consumers of the result.
  %ret_wout = shard.shard %sharded_wout to %sharding_wout annotate_for_users : tensor<{N}x{K}xf32>
  return %ret_wout : tensor<{N}x{K}xf32>
}}

func.func @alloc_r() -> (tensor<{M}x{K}xf32>) attributes {{llvm.emit_c_interface}} {{
%a = tensor.empty() : tensor<{M}x{K}xf32>
%sharding_r = shard.sharding @grid0 split_axes = {split_r} : !shard.sharding
Expand All @@ -96,13 +80,15 @@ module attributes {{mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:co
%sharded_g = shard.shard %sharded to %sharding_g annotate_for_users : tensor<{M}x{K}xf32>
return %sharded_g : tensor<{M}x{K}xf32>
}}

// Gather the distributed input-weight tensor: the argument carries the
// "win" sharding, the result is annotated with an empty split (replicated
// on every rank), so partitioning lowers this to a collective gather.
func.func @gather_win(%arg0: tensor<{K}x{N}xf32>) -> tensor<{K}x{N}xf32> attributes {{llvm.emit_c_interface}} {{
  %sharding = shard.sharding @grid0 split_axes = {split_win} : !shard.sharding
  // Empty split_axes: fully replicated target sharding.
  %sharding_g = shard.sharding @grid0 split_axes = [[]] : !shard.sharding
  %sharded = shard.shard %arg0 to %sharding : tensor<{K}x{N}xf32>
  %sharded_g = shard.shard %sharded to %sharding_g annotate_for_users : tensor<{K}x{N}xf32>
  return %sharded_g : tensor<{K}x{N}xf32>
}}

func.func @gather_wout(%arg0: tensor<{N}x{K}xf32>) -> tensor<{N}x{K}xf32> attributes {{llvm.emit_c_interface}} {{
%sharding = shard.sharding @grid0 split_axes = {split_wout} : !shard.sharding
%sharding_g = shard.sharding @grid0 split_axes = [[]] : !shard.sharding
Expand Down
Loading