MLP with MPI through MLIR#50

Merged
fschlimb merged 15 commits into llvm:main from fschlimb:mlp-mpi
Feb 23, 2026
Conversation

@fschlimb
Contributor

@fschlimb fschlimb commented Feb 4, 2026

This demonstrates the distributed sharding infrastructure of MLIR using a single MLP, lowering sharding annotations to MPI.
Currently this uses the lower end of the pipeline; sharding-propagation is not fit for it yet.
Two different sharding policies can be used by selecting a 1d or a 2d device grid.

Note: some fixes to make this work have not yet landed on MLIR main.

@rolfmorel @tkarna @rengolin Is there any CI? If so, what is the recommended way of integrating this?

@rolfmorel
Contributor

For CI, I think you will want:

Add an optional Python dependency group for MPI, with mpi4py and probably mpich in it (though maybe there should be multiple mutually incompatible dependency groups, as we have for torch, to also target running with OpenMPI): e.g. like

[project.optional-dependencies]
ingress_torch_mlir = [
"torch-mlir==20260125.703",
"ml_dtypes",
]
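A sketch of what such an MPI group could look like, modeled on the torch group above (the `runtime_mpi` name matches the group the PR summary later mentions; leaving the packages unpinned is an assumption):

```toml
[project.optional-dependencies]
runtime_mpi = [
    "mpi4py",
    "mpich",
]
```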

Add an mpi/mpi-mpich "feature" to the llvm-lit config, like:

lighthouse/lit.cfg.py

Lines 23 to 24 in d32d1cf

if importlib.util.find_spec("torch"):
config.available_features.add("torch")
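Generalizing the existing torch check above, the detection logic can be sketched as a small helper; in lit.cfg.py the result would feed `config.available_features` (the feature name `mpi` and the helper itself are illustrative, not part of the PR):

```python
import importlib.util


def detect_features(module_to_feature):
    """Return the lit feature names whose Python package is importable."""
    features = set()
    for module, feature in module_to_feature.items():
        if importlib.util.find_spec(module):
            features.add(feature)
    return features


# In lit.cfg.py: config.available_features |= detect_features(...)
print(sorted(detect_features({"torch": "torch", "mpi4py": "mpi"})))
```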

Add a REQUIRES: mpi (or REQUIRES: mpi-mpich) line to your python file that is to be CHECKed. E.g., like

# RUN: %PYTHON %s
# REQUIRES: torch

If the CI machines need to have certain (system) libraries installed, you will want to modify the following file:

sudo apt-get install -y llvm-dev # Obtain FileCheck, used in testing.

Note that if Ubuntu packages exist, this is easy. If not, this becomes quite a bit more involved.
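For MPICH, the corresponding workflow step could gain a line like the following (package names assumed from the standard Ubuntu repositories):

```shell
sudo apt-get install -y mpich libmpich-dev  # MPI runtime, launcher, and headers
```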

@fschlimb
Contributor Author

@rolfmorel @tkarna @rengolin this is ready for review

Once llvm/llvm-project#180962 has been merged, we should update the mlir dep and also run the 2d-grid case in CI.


Copilot AI left a comment


Pull request overview

This PR integrates MPI-based distributed computing capabilities into the MLIR infrastructure, enabling multi-rank execution of neural network operations. It adds support for running a Multi-Layer Perceptron (MLP) across multiple MPI processes with different sharding strategies.

Changes:

  • Updated MLIR and torch-mlir dependencies to newer versions
  • Added MPI runtime dependencies (mpi4py, mpich) as optional extras
  • Implemented a distributed MLP example with weight-stationary sharding strategies
  • Extended lit test infrastructure to detect and run MPI-enabled tests

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
pyproject.toml Updated dependency versions and added runtime_mpi optional dependency group
lit.cfg.py Extended test detection to support MPI packages and added VIRTUAL_ENV substitution
lighthouse/workload/runner.py Modified shared library path resolution to support absolute paths
examples/mlp-mpi/mlp_weight_stationary.mlir New MLIR template for distributed MLP with sharding annotations
examples/mlp-mpi/mlp-mpi.py New Python implementation of distributed MLP workload with MPI execution
examples/mlp-mpi/README.md Documentation for running the MPI-based MLP example
.github/workflows/examples.yml Added CI workflow step for MPI-enabled examples
Comments suppressed due to low confidence (1)

examples/mlp-mpi/mlp-mpi.py:1

  • The hardcoded library name 'libmpi.so.12' is brittle and will fail on systems with different MPI versions or implementations (e.g., libmpi.so.40 for newer MPICH, or different naming for OpenMPI). Consider using a more flexible approach such as detecting the library at runtime or using a configurable variable.
# REQUIRES: mpi4py
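One way to address the hardcoded `libmpi.so.12` would be an environment override plus a system library lookup; a minimal sketch (the `LIGHTHOUSE_MPI_LIB` variable name is hypothetical):

```python
import ctypes.util
import os

# Prefer an explicit override, then fall back to the system's library
# search; find_library returns None when no MPI library can be located.
libmpi = os.environ.get("LIGHTHOUSE_MPI_LIB") or ctypes.util.find_library("mpi")
print(libmpi)
```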


Comment on lines 190 to 191
self.M * self.N * self.K * 2 + self.M * self.K * 4
) # matmuls + sigmoid

Copilot AI Feb 11, 2026


The FLOP count calculation is incorrect. The MLP performs two matmuls: A@B (M×K by K×N = 2MNK FLOPs) and result@C (M×N by N×K = 2MNK FLOPs), totaling 4MNK FLOPs. The sigmoid operation on M×N elements requires approximately 5MN FLOPs (not 4MK). The formula should be '4 * self.M * self.N * self.K + 5 * self.M * self.N'.

Suggested change
self.M * self.N * self.K * 2 + self.M * self.K * 4
) # matmuls + sigmoid
4 * self.M * self.N * self.K + 5 * self.M * self.N
) # 2 matmuls (4MNK) + sigmoid (≈5MN)
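A quick numeric sanity check of the corrected count, assuming the conventional 2·M·N·K FLOPs per matmul and ~5 FLOPs per sigmoid element as stated above (sizes are illustrative):

```python
M, N, K = 4, 3, 2  # small illustrative sizes

flops_mm1 = 2 * M * N * K   # A (MxK) @ B (KxN)
flops_mm2 = 2 * M * N * K   # (MxN) @ C (NxK)
flops_sigmoid = 5 * M * N   # elementwise on the MxN intermediate

total = flops_mm1 + flops_mm2 + flops_sigmoid
assert total == 4 * M * N * K + 5 * M * N
print(total)  # → 156
```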

Contributor

@rolfmorel rolfmorel left a comment


Looks fine to me.

The CI changes look good. Maybe @tkarna can check the workload usage and take a more detailed look at the shard & MPI specifics.

Comment on lines 359 to 373
# rprint(" Execute 2 ".center(60, "-"))
# execute(wload, verbose=1)

# rprint(" Benchmark ".center(60, "-"))
# times = benchmark(wload)
# times *= 1e6 # convert to microseconds
# compute statistics
# mean = np.mean(times)
# min = np.min(times)
# max = np.max(times)
# std = np.std(times)
# rprint(f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
# flop_count = wload.get_complexity()[0]
# gflops = flop_count / (mean * 1e-6) / 1e9
# rprint(f"Throughput: {gflops:.2f} GFLOPS")
Contributor


Suggested change
# rprint(" Execute 2 ".center(60, "-"))
# execute(wload, verbose=1)
# rprint(" Benchmark ".center(60, "-"))
# times = benchmark(wload)
# times *= 1e6 # convert to microseconds
# compute statistics
# mean = np.mean(times)
# min = np.min(times)
# max = np.max(times)
# std = np.std(times)
# rprint(f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
# flop_count = wload.get_complexity()[0]
# gflops = flop_count / (mean * 1e-6) / 1e9
# rprint(f"Throughput: {gflops:.2f} GFLOPS")

Comment on lines 310 to 311
# test_analysis_only=True,
# print_conflicts=True,
Contributor


Suggested change
# test_analysis_only=True,
# print_conflicts=True,

"split_mm1_c": "[[], [0]]",
}
)
txt = txt.format_map(format_values)
Contributor


Any reason to prefer this textual manipulation over generating the payload with Python and having the payload-generating function be suitably parameterized?

Contributor Author


I am fed up with using the python mlir builders. It is more work than writing plain MLIR.
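For context, the textual approach quoted above boils down to `str.format_map` over an MLIR template string; a minimal runnable sketch (the template content is illustrative, and note that literal MLIR braces must be doubled):

```python
# Named placeholders like {m} are substituted by format_map; the
# literal braces of the MLIR function body are escaped as {{ and }}.
template = (
    "func.func @mlp(%arg0: tensor<{m}x{k}xf32>) {{\n"
    "  return\n"
    "}}\n"
)
txt = template.format_map({"m": 16, "k": 8})
print(txt)
```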

Comment on lines 310 to 311
# test_analysis_only=True,
# print_conflicts=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# test_analysis_only=True,
# print_conflicts=True,

Comment on lines 359 to 373
# rprint(" Execute 2 ".center(60, "-"))
# execute(wload, verbose=1)

# rprint(" Benchmark ".center(60, "-"))
# times = benchmark(wload)
# times *= 1e6 # convert to microseconds
# compute statistics
# mean = np.mean(times)
# min = np.min(times)
# max = np.max(times)
# std = np.std(times)
# rprint(f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
# flop_count = wload.get_complexity()[0]
# gflops = flop_count / (mean * 1e-6) / 1e9
# rprint(f"Throughput: {gflops:.2f} GFLOPS")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# rprint(" Execute 2 ".center(60, "-"))
# execute(wload, verbose=1)
# rprint(" Benchmark ".center(60, "-"))
# times = benchmark(wload)
# times *= 1e6 # convert to microseconds
# compute statistics
# mean = np.mean(times)
# min = np.min(times)
# max = np.max(times)
# std = np.std(times)
# rprint(f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
# flop_count = wload.get_complexity()[0]
# gflops = flop_count / (mean * 1e-6) / 1e9
# rprint(f"Throughput: {gflops:.2f} GFLOPS")

Contributor

@tkarna tkarna left a comment


Thanks for the updates. This is a neat example!

@fschlimb fschlimb merged commit 3336c60 into llvm:main Feb 23, 2026
3 checks passed