Conversation
For CI, I think you will want to:
- Add an optional Python dependency group (see lines 18 to 22 in d32d1cf).
- Add a … (see lines 23 to 24 in d32d1cf).
- Add a test analogous to lighthouse/examples/llama/test_llama3.py (lines 1 to 2 in d32d1cf).

If the CI machines need certain (system) libraries, you will want to modify the following file: lighthouse/.github/workflows/examples.yml (line 29 in d32d1cf). Note that if Ubuntu packages exist, this is easy; if not, this becomes quite a bit more involved.
@rolfmorel @tkarna @rengolin this is ready for review. Once llvm/llvm-project#180962 has been merged, we should update the MLIR dep and also run the 2d-grid case in CI.
Pull request overview
This PR integrates MPI-based distributed computing capabilities into the MLIR infrastructure, enabling multi-rank execution of neural network operations. It adds support for running a Multi-Layer Perceptron (MLP) across multiple MPI processes with different sharding strategies.
Changes:
- Updated MLIR and torch-mlir dependencies to newer versions
- Added MPI runtime dependencies (mpi4py, mpich) as optional extras
- Implemented a distributed MLP example with weight-stationary sharding strategies
- Extended lit test infrastructure to detect and run MPI-enabled tests
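The lit-based detection of optional packages described above can be sketched roughly as follows (helper names here are hypothetical; the actual `lit.cfg.py` in this PR may differ):

```python
# Sketch of optional-package detection for a lit config (hypothetical names;
# the actual lit.cfg.py in this PR may differ).
import importlib.util

def has_package(name: str) -> bool:
    """Return True if the Python package can be imported."""
    return importlib.util.find_spec(name) is not None

available_features = set()
for pkg in ("mpi4py",):
    if has_package(pkg):
        # Tests can then gate themselves with `# REQUIRES: mpi4py`.
        available_features.add(pkg)
```

With this in place, a test file starting with `# REQUIRES: mpi4py` would only run when the feature was detected.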
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| pyproject.toml | Updated dependency versions and added runtime_mpi optional dependency group |
| lit.cfg.py | Extended test detection to support MPI packages and added VIRTUAL_ENV substitution |
| lighthouse/workload/runner.py | Modified shared library path resolution to support absolute paths |
| examples/mlp-mpi/mlp_weight_stationary.mlir | New MLIR template for distributed MLP with sharding annotations |
| examples/mlp-mpi/mlp-mpi.py | New Python implementation of distributed MLP workload with MPI execution |
| examples/mlp-mpi/README.md | Documentation for running the MPI-based MLP example |
| .github/workflows/examples.yml | Added CI workflow step for MPI-enabled examples |
Comments suppressed due to low confidence (1)
examples/mlp-mpi/mlp-mpi.py:1
- The hardcoded library name 'libmpi.so.12' is brittle and will fail on systems with different MPI versions or implementations (e.g., libmpi.so.40 for newer MPICH, or different naming for OpenMPI). Consider using a more flexible approach such as detecting the library at runtime or using a configurable variable.
# REQUIRES: mpi4py
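One possible flexible approach along the lines the comment suggests (an illustrative sketch only; the `MPI_SHARED_LIB` environment variable is an assumption, not part of this PR):

```python
# Sketch: locate the MPI shared library at runtime instead of hardcoding
# "libmpi.so.12". The MPI_SHARED_LIB override name is hypothetical.
import ctypes.util
import os

def find_mpi_library() -> str:
    # Allow an explicit override first, e.g. for unusual installs.
    override = os.environ.get("MPI_SHARED_LIB")
    if override:
        return override
    # ctypes.util.find_library resolves the platform-specific name,
    # e.g. libmpi.so.12, libmpi.so.40, or libmpi.dylib.
    found = ctypes.util.find_library("mpi")
    if found is None:
        raise RuntimeError("Could not locate an MPI shared library")
    return found
```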
examples/mlp-mpi/mlp-mpi.py
    self.M * self.N * self.K * 2 + self.M * self.K * 4
    ) # matmuls + sigmoid
The FLOP count calculation is incorrect. The MLP performs two matmuls: A@B (M×K by K×N = 2MNK FLOPs) and result@C (M×N by N×K = 2MNK FLOPs), totaling 4MNK FLOPs. The sigmoid operation on M×N elements requires approximately 5MN FLOPs (not 4MK). The formula should be '4 * self.M * self.N * self.K + 5 * self.M * self.N'.
Suggested change:

    - self.M * self.N * self.K * 2 + self.M * self.K * 4
    - ) # matmuls + sigmoid
    + 4 * self.M * self.N * self.K + 5 * self.M * self.N
    + ) # 2 matmuls (4MNK) + sigmoid (≈5MN)
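As a cross-check of the corrected count, the formula can be written out as a small stand-alone sketch (the function name is illustrative, not from the PR):

```python
def mlp_flops(M: int, N: int, K: int) -> int:
    # A (MxK) @ B (KxN): M*N outputs, each needing K multiplies + K adds.
    mm1 = 2 * M * N * K
    # (MxN) @ C (NxK): M*K outputs, each needing N multiplies + N adds.
    mm2 = 2 * M * K * N
    # Sigmoid on the MxN intermediate: roughly 5 FLOPs per element.
    sigmoid = 5 * M * N
    return mm1 + mm2 + sigmoid  # = 4*M*N*K + 5*M*N
```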
examples/mlp-mpi/mlp-mpi.py
    # rprint(" Execute 2 ".center(60, "-"))
    # execute(wload, verbose=1)

    # rprint(" Benchmark ".center(60, "-"))
    # times = benchmark(wload)
    # times *= 1e6  # convert to microseconds
    # compute statistics
    # mean = np.mean(times)
    # min = np.min(times)
    # max = np.max(times)
    # std = np.std(times)
    # rprint(f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
    # flop_count = wload.get_complexity()[0]
    # gflops = flop_count / (mean * 1e-6) / 1e9
    # rprint(f"Throughput: {gflops:.2f} GFLOPS")
Suggested change: remove this block of commented-out code.
examples/mlp-mpi/mlp-mpi.py
    # test_analysis_only=True,
    # print_conflicts=True,
Suggested change: remove these commented-out options.
            "split_mm1_c": "[[], [0]]",
        }
    )
    txt = txt.format_map(format_values)
Any reason to prefer this textual manipulation over generating the payload with Python and having the payload-generating function be suitably parameterized?
I am fed up with using the python mlir builders. It is more work than writing plain MLIR.
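For reference, the textual-template approach under discussion can be sketched as follows (a heavily simplified stand-in; the real template is examples/mlp-mpi/mlp_weight_stationary.mlir, and apart from `split_mm1_c` the placeholder names here are assumptions):

```python
# Simplified sketch of filling an MLIR template with str.format_map.
# Literal MLIR braces must be doubled ({{ }}) so format_map leaves them alone.
template = """\
func.func @mlp(%arg0: tensor<{M}x{K}xf32>) {{
  // sharding annotation placeholder: {split_mm1_c}
  return
}}
"""

format_values = {
    "M": 128,
    "K": 256,
    "split_mm1_c": "[[], [0]]",
}
txt = template.format_map(format_values)
```

This keeps the payload readable as plain MLIR while still being parameterized, which is the trade-off argued for above.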
tkarna left a comment
Thanks for the updates. This is a neat example!
This demonstrates the distributed sharding infrastructure of MLIR using a single MLP, lowering sharding annotations to MPI. It currently uses the lower end of the pipeline, as sharding propagation is not yet fit for this. Two different sharding policies can be used by selecting a 1d or a 2d device grid.
Note: some fixes needed to make this work have not yet landed on MLIR main.
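The weight-stationary idea behind a 1d device grid can be illustrated with a plain NumPy stand-in (this is only an analogy for what the generated per-rank MPI code computes, not the PR's implementation):

```python
# Sketch of weight-stationary 1-D sharding: each "rank" holds one column
# block of the weight matrix B and computes its slice of A @ B, while the
# activation A is replicated. Shapes here are arbitrary small examples.
import numpy as np

rng = np.random.default_rng(0)
M, K, N, ranks = 4, 6, 8, 2
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))

# Column-wise shards of B, one per rank; the weights stay put ("stationary").
shards = np.split(B, ranks, axis=1)
partials = [A @ shard for shard in shards]

# Concatenating the per-rank results reproduces the unsharded matmul.
assert np.allclose(np.concatenate(partials, axis=1), A @ B)
```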
@rolfmorel @tkarna @rengolin Is there any CI? If so, what is the recommended way of integrating this?