Add simulation backends and improve performance projection accuracy #560

Open · araina-amd wants to merge 8 commits into main
Conversation
- Add Origami GEMM simulation backend with MI300X/MI325X hardware profiles
- Add FAv3 analytical SDPA simulator with roofline model, atomic overhead modeling for backward pass, and GQA/MQA support
- Add simulation mode (--profiling-mode simulate) for GPU-free performance projection using Origami for GEMMs and an analytical model for Flash Attention
- Wire simulation backends into all module profilers (attention, dense MLP, MoE MLP, embedding, output layer, transformer layer)
- Add --gemm-backend and --gpu-arch CLI arguments for simulation control
- Fix FSDP communication model to double the AllGather count when recompute_granularity='full' is enabled
…and intra-node A2A overhead
- Add per-phase FSDP overlap model (90% forward AG, 24% backward AG, 34% RS)
- Add PXN All-to-All algorithm for DeepEP with pipelined scale-up/scale-out
- Enable PXN automatically when DeepEP is detected (moe_enable_deepep or use_turbo_deepep)
- Adjust intra-node A2A overhead to 28 us per peer (from preflight measurements)
- Separate intra-node and inter-node A2A overhead modeling
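The per-peer intra-node overhead described in this commit can be sketched as a fixed launch/synchronization cost per same-node peer, kept separate from inter-node traffic. This is an illustrative model, not the PR's actual code; the function and default node size are assumptions, while the 28 us constant comes from the commit message.

```python
INTRA_NODE_A2A_OVERHEAD_US = 28.0  # per-peer overhead from preflight measurements

def a2a_overhead_us(ep_size: int, node_size: int = 8) -> float:
    """Fixed intra-node overhead for one All-to-All.

    Only same-node peers contribute the 28 us term; inter-node
    overhead is modeled separately (not shown in this sketch).
    """
    intra_peers = min(ep_size, node_size) - 1  # EP peers on the same node
    return INTRA_NODE_A2A_OVERHEAD_US * max(intra_peers, 0)
```

With EP=8 on an 8-GPU node this charges 7 × 28 us; larger EP degrees add inter-node cost instead of more intra-node peers.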
…p estimation, and --gpu-clock-mhz override
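The --gpu-clock-mhz override mentioned above plausibly rescales peak throughput linearly with clock frequency, since achievable FLOP/s at fixed CU count scale with clock. A minimal sketch, with illustrative names and numbers (not the PR's hardware profiles):

```python
def scaled_peak_tflops(base_tflops: float, base_clock_mhz: float,
                       clock_mhz: float) -> float:
    """Rescale a profile's peak TFLOPS for an overridden GPU clock."""
    return base_tflops * (clock_mhz / base_clock_mhz)
```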
```python
    Returns:
        SimulationResult with forward_time_ms and backward_time_ms.
    """
    fwd_time = 0.0
    bwd_time = 0.0
```
```python
                f" ({num_local_experts} local experts, M={M}, H={H}, F={F})")

    expert_fwd_ms = 0.0
    expert_bwd_ms = 0.0
```
```python
            "fsdp_ag_multiplier", 1 + recomp_ratio
        )
        fwd_ag_total = total_fsdp_ag / ag_multiplier_val
        bwd_ag_total = total_fsdp_ag - fwd_ag_total
    else:
        fwd_ag_total = total_fsdp_ag
        bwd_ag_total = 0.0
```
```python
# See LICENSE for license information.
###############################################################################

import math
```
… MoE overhead modeling
- Add Multi-Latent Attention (MLA) GEMM simulation with LoRA-factored Q and compressed KV projections (6 fwd + 12 bwd GEMMs) in attention profiler
- Add MLA-aware SDPA simulation with split D_qk/D_v head dimensions
- Replace 2x-forward backward approximation with explicit dgrad + wgrad GEMM simulation in attention, MLP, MoE, and output layer profilers
- Add batched GEMM support (batch param) for Turbo grouped-GEMM modeling vs legacy sequential per-expert execution in MoE MLP profiler
- Add router overhead, token permutation, and activation function overhead estimation in MoE MLP simulation
- Add TP AllReduce and MoE All-to-All communication overhead estimation in transformer layer simulation mode
- Switch EP MLP scaling to delta-based approach preserving profiled layer components (TP AR, A2A, norms) and EP-invariant routed compute model
- Add enable_primus_turbo and use_turbo_grouped_mlp config flags
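The dgrad + wgrad replacement in this commit matters because, for a linear layer Y = X @ W, the two backward GEMMs have different shapes than the forward GEMM, so their simulated times need not equal 2× forward even though the FLOP counts match. A sketch of the standard backprop shapes (function names are illustrative):

```python
def linear_gemm_shapes(M: int, K: int, N: int):
    """(M, N, K) GEMM shapes for Y = X @ W with X (M x K), W (K x N)."""
    return {
        "fwd":   (M, N, K),  # Y  = X @ W
        "dgrad": (M, K, N),  # dX = dY @ W^T
        "wgrad": (K, N, M),  # dW = X^T @ dY
    }

def gemm_flops(shape):
    m, n, k = shape
    return 2 * m * n * k
```

All three GEMMs have identical FLOPs, but a shape-aware simulator can still assign them different times, which the old 2×forward approximation could not.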
…level simulation
Replace the analytical roofline SDPA simulator with an Origami-based tile-level model that simulates each per-workgroup GEMM on a single CU. This captures wave quantization, LDS traffic, and pipeline effects that the global max(compute, memory) roofline missed, eliminating the need for empirical compute_efficiency / memory_efficiency parameters.
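For reference, the global roofline model being replaced has roughly this form: the kernel time is the larger of compute time and memory time, each scaled by an empirical efficiency fudge factor. The sketch below uses illustrative names and default efficiencies; it is the pattern the commit removes, not the PR's exact code.

```python
def roofline_time_ms(flops: float, bytes_moved: float,
                     peak_tflops: float, peak_gbps: float,
                     compute_eff: float = 0.7, mem_eff: float = 0.8) -> float:
    """Global max(compute, memory) roofline with empirical efficiency factors."""
    compute_ms = flops / (peak_tflops * 1e12 * compute_eff) * 1e3
    memory_ms = bytes_moved / (peak_gbps * 1e9 * mem_eff) * 1e3
    return max(compute_ms, memory_ms)
```

The tile-level replacement needs no such efficiency parameters because per-CU simulation exposes the effects the fudge factors were papering over.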
…model
- A2A operates on S/TP tokens per GPU due to sequence parallelism. Divide dispatch_size by max(tp, 1) in both _estimate_ep_communication_overhead and calculate_collective_communication_time.
- Account for TP striding (hp) in single_shot_alltoall, hierarchical_alltoall, and pxn_alltoall: effective EP ranks per node = node_size / hp.
Replace the three separate per-phase overlap percentages (FWD_AG=90%, BWD_AG=24%, RS=34%) with a single uniform 93% overlap applied to all FSDP communication (AllGather fwd, AllGather recompute, ReduceScatter). The 93% figure was observed on an actual llama3-70b run.
Fix formatting issues flagged by CI code-lint job.
This PR adds GPU-free simulation backends (Origami for GEMMs, analytical tile-level model for Flash Attention) and significantly improves the accuracy of the performance projection tool. Across 11 benchmarked workloads (LLaMA-2/3.1/3.3, Qwen-2.5, Mixtral, DeepSeek-V3), the projected TGS is within ~10% of measured results for 9/11 models.
Key Changes
Simulation Backends (new simulation_backends/ package)
Origami GEMM Backend (origami_backend.py): Wraps the Origami analytical GEMM model with MI300X/MI325X hardware profiles, clock frequency override, and support for batched GEMM.
SDPA Simulator (sdpa_simulator.py): Tile-level Flash Attention v3 model using Origami on a single CU. Simulates per-workgroup GEMMs (QKᵀ and PV for forward; 5 GEMMs for backward) and models dQ atomic overhead as an additive term. Captures wave quantization, LDS traffic, and pipeline effects that a global roofline model misses.
Factory (factory.py): Provides get_gemm_simulation_backend() and get_sdpa_simulation_backend() factory functions.
Adds --profiling-mode simulate for fully GPU-free performance projection.
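The per-workgroup GEMMs the SDPA simulator feeds to Origami can be sketched as follows for the forward pass: each workgroup processes a (Br × Bc) tile, running S = Q·Kᵀ then O += P·V. The tile sizes and helper names below are illustrative assumptions, not the PR's actual constants.

```python
def fa_forward_tile_gemms(head_dim: int, br: int = 128, bc: int = 128):
    """(M, N, K) shapes of the two GEMMs one workgroup runs per KV tile."""
    qk = (br, bc, head_dim)   # S = Q (Br x d) @ K^T (d x Bc)
    pv = (br, head_dim, bc)   # O += P (Br x Bc) @ V (Bc x d)
    return [qk, pv]

def tile_flops(shapes):
    """Total FLOPs across a workgroup's per-tile GEMMs."""
    return sum(2 * m * n * k for (m, n, k) in shapes)
```

Simulating each of these small GEMMs on one CU is what lets the model see wave quantization and LDS traffic that a whole-GPU roofline averages away.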
Multi-Latent Attention (MLA) Support
Dedicated _simulate_mla_gemms() in the attention profiler models the 6 forward + 12 backward GEMMs specific to DeepSeek-V3's MLA with LoRA-factored Q and compressed KV projections.
MLA-aware SDPA simulation with separate D_qk / D_v head dimensions.
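The 6 forward / 12 backward count above follows from each forward projection needing one dgrad and one wgrad GEMM in backward. The enumeration below is an assumption about how the PR breaks down DeepSeek-V3's MLA (LoRA-factored Q, compressed KV); the names are illustrative.

```python
MLA_FORWARD_GEMMS = [
    "q_down",    # Q LoRA down-projection
    "q_up",      # Q LoRA up-projection
    "kv_down",   # joint KV compression to the latent
    "k_up",      # K up-projection from the compressed latent
    "v_up",      # V up-projection from the compressed latent
    "out_proj",  # attention output projection
]

BWD_GEMMS_PER_FWD = 2  # one dgrad + one wgrad per forward GEMM
```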
MoE MLP Improvements
Two-mode grouped GEMM model: Turbo (batched GEMM via batch=num_experts) vs Legacy (sequential per-expert GEMM × num_local_experts), selected via enable_primus_turbo + use_turbo_grouped_mlp config flags.
Adds router overhead, token permutation, and activation function overhead estimation.
Explicit dgrad + wgrad backward GEMM simulation replaces the 2×forward approximation.
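The Turbo-vs-Legacy distinction above can be sketched as one batched kernel amortizing launch overhead across experts versus a per-expert launch in the sequential path. The cost model and the overhead constant are illustrative assumptions, not the PR's numbers.

```python
LAUNCH_OVERHEAD_MS = 0.01  # illustrative fixed cost per kernel launch

def grouped_mlp_time_ms(per_expert_gemm_ms: float, num_local_experts: int,
                        turbo: bool) -> float:
    """Turbo: one batched GEMM (batch=num_experts); Legacy: sequential GEMMs."""
    compute = per_expert_gemm_ms * num_local_experts
    launches = 1 if turbo else num_local_experts
    return compute + LAUNCH_OVERHEAD_MS * launches
```

At equal GEMM throughput the batched path wins by the saved launches; a real model would also let Origami price the batched GEMM differently from the per-expert ones.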
Collective Communication Model Fixes
A2A message size correction for TP>1: A2A operates on S/TP tokens per GPU due to sequence parallelism; dispatch_size is now divided by max(tp, 1).
TP-striding in A2A: single_shot_alltoall, hierarchical_alltoall, and pxn_alltoall now account for hp (tensor parallelism) when calculating effective EP ranks per node.
PXN All-to-All: Added support for pipelined scale-up/scale-out A2A when DeepEP is enabled.
FSDP overlap model: Single uniform 93% overlap factor (observed on LLaMA-3 70B) applied to all FSDP communication (AllGather forward, AllGather recompute, ReduceScatter). Doubles AllGather count when recompute_granularity='full'.
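The FSDP model above reduces to a simple exposed-communication formula: 93% of every FSDP collective hides behind compute, and full-granularity recompute doubles the AllGather volume (weights are re-gathered for the recomputed forward). The constants come from the PR text; the function itself is an illustrative sketch.

```python
FSDP_OVERLAP = 0.93  # uniform overlap factor observed on LLaMA-3 70B

def exposed_fsdp_ms(ag_ms: float, rs_ms: float,
                    full_recompute: bool) -> float:
    """Exposed (non-overlapped) FSDP communication time."""
    ag_total = ag_ms * (2.0 if full_recompute else 1.0)  # AG doubled on full recompute
    return (ag_total + rs_ms) * (1.0 - FSDP_OVERLAP)
```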
Other Improvements
--gpu-clock-mhz CLI override adjusts Origami and SDPA peak TFLOPS accordingly.
TP AllReduce and MoE A2A overhead estimation in transformer layer simulation.
EP MLP scaling uses delta-based approach preserving profiled layer components.
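The delta-based EP scaling above can be sketched as: take the profiled layer time, subtract the routed-expert compute at the profiled configuration, and add back the routed compute at the target configuration, leaving the preserved components (TP AR, A2A, norms) untouched. Names are illustrative.

```python
def scaled_layer_ms(profiled_layer_ms: float,
                    profiled_routed_ms: float,
                    target_routed_ms: float) -> float:
    """Delta-based rescaling: swap only the routed-expert component."""
    other_components = profiled_layer_ms - profiled_routed_ms  # TP AR, A2A, norms
    return other_components + target_routed_ms
```

This avoids rescaling the whole profiled layer, which would wrongly stretch the EP-invariant communication and normalization terms.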