Add simulation backends and improve performance projection accuracy#560

Open
araina-amd wants to merge 8 commits into main from
dev/araina/simulation-backends-and-projection-improvements
Conversation

@araina-amd
Contributor

This PR adds GPU-free simulation backends (Origami for GEMMs, an analytical tile-level model for Flash Attention) and significantly improves the accuracy of the performance projection tool. Across 11 benchmarked workloads (LLaMA-2/3.1/3.3, Qwen-2.5, Mixtral, DeepSeek-V3), projected TGS is within ~10% of measured results for 9 of the 11 models.

Key Changes

1. Simulation backends (new `simulation_backends/` package)
   - Origami GEMM backend (`origami_backend.py`): wraps the Origami analytical GEMM model with MI300X/MI325X hardware profiles, a clock-frequency override, and batched-GEMM support.
   - SDPA simulator (`sdpa_simulator.py`): tile-level Flash Attention v3 model that runs Origami on a single CU. It simulates each per-workgroup GEMM (QKᵀ and PV for the forward pass; 5 GEMMs for the backward pass) and adds dQ atomic overhead additively, capturing wave quantization, LDS traffic, and pipeline effects that a global roofline model misses.
   - Factory (`factory.py`): provides the `get_gemm_simulation_backend()` and `get_sdpa_simulation_backend()` factory functions.
   - Adds `--profiling-mode simulate` for fully GPU-free performance projection.
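The factory indirection makes it easy to register additional analytical backends next to Origami. The sketch below shows the general pattern with a simple roofline model standing in as a registered backend; the class and parameter names are illustrative, not the PR's actual API.

```python
from dataclasses import dataclass


@dataclass
class GemmShape:
    """Problem size for a (possibly batched) GEMM: C[M,N] = A[M,K] @ B[K,N]."""
    M: int
    N: int
    K: int
    batch: int = 1


class RooflineGemmBackend:
    """Toy analytical backend: time = max(compute-bound, memory-bound) time."""

    def __init__(self, peak_tflops, peak_bw_gbs):
        self.peak_tflops = peak_tflops
        self.peak_bw_gbs = peak_bw_gbs

    def simulate_ms(self, shape, bytes_per_elem=2):
        flops = 2.0 * shape.batch * shape.M * shape.N * shape.K
        traffic = bytes_per_elem * shape.batch * (
            shape.M * shape.K + shape.K * shape.N + shape.M * shape.N
        )
        compute_ms = flops / (self.peak_tflops * 1e12) * 1e3
        memory_ms = traffic / (self.peak_bw_gbs * 1e9) * 1e3
        return max(compute_ms, memory_ms)


_GEMM_BACKENDS = {"roofline": RooflineGemmBackend}


def get_gemm_simulation_backend(name, **hw_params):
    """Look up a registered backend by name and construct it."""
    return _GEMM_BACKENDS[name](**hw_params)
```

With MI300X-like numbers (~1300 TF/s FP16, ~5.3 TB/s HBM), a 4096³ GEMM lands on the compute roof at roughly 0.1 ms under this model.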

2. Multi-Latent Attention (MLA) support
   - A dedicated `_simulate_mla_gemms()` in the attention profiler models the 6 forward + 12 backward GEMMs specific to DeepSeek-V3's MLA, with its LoRA-factored Q and compressed KV projections.
   - MLA-aware SDPA simulation with separate D_qk / D_v head dimensions.
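Because MLA decouples the QKᵀ head dimension from the PV head dimension, SDPA FLOP accounting can no longer use a single head_dim. A minimal sketch of the split accounting (the function name is illustrative; for DeepSeek-V3, D_qk is 192 including the RoPE part and D_v is 128):

```python
def mla_sdpa_fwd_flops(batch, heads, seq, d_qk, d_v):
    """Forward SDPA FLOPs with split head dims: scores use d_qk, output uses d_v."""
    qk_flops = 2.0 * batch * heads * seq * seq * d_qk  # S = Q @ K^T
    pv_flops = 2.0 * batch * heads * seq * seq * d_v   # O = P @ V
    return qk_flops + pv_flops
```

With D_qk=192 and D_v=128, the QKᵀ GEMM carries 1.5× the FLOPs of the PV GEMM, a ratio a single-head-dim model would miss.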

3. MoE MLP improvements
   - Two-mode grouped-GEMM model: Turbo (one batched GEMM with batch=num_experts) vs. Legacy (sequential per-expert GEMMs × num_local_experts), selected via the `enable_primus_turbo` + `use_turbo_grouped_mlp` config flags.
   - Adds router overhead, token permutation, and activation-function overhead estimation.
   - Explicit dgrad + wgrad backward GEMM simulation replaces the 2×-forward approximation.
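The Turbo-vs-Legacy distinction mostly changes how many kernel launches the model charges for. A minimal sketch under assumed numbers (the per-launch overhead and peak TFLOPS here are placeholders; the real model runs each GEMM through the Origami backend instead of a flat FLOP rate):

```python
def moe_grouped_gemm_ms(num_experts, m, n, k, turbo,
                        peak_tflops=1300.0, launch_ms=0.01):
    """Time for one MoE layer's expert GEMMs under the two-mode model."""
    # One expert's GEMM at an assumed flat FLOP rate (illustrative only).
    flop_ms = 2.0 * m * n * k / (peak_tflops * 1e12) * 1e3
    if turbo:
        # Turbo: one batched grouped GEMM, a single launch covers all experts.
        return num_experts * flop_ms + launch_ms
    # Legacy: sequential per-expert GEMMs, each paying launch overhead.
    return num_experts * (flop_ms + launch_ms)
```

With 8 local experts the legacy path pays 8 launches instead of 1, and the gap widens as the per-expert GEMMs shrink relative to launch overhead.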

4. Collective communication model fixes
   - A2A message-size correction for TP>1: with sequence parallelism, A2A operates on S/TP tokens per GPU, so `dispatch_size` is now divided by max(tp, 1).
   - TP striding in A2A: `single_shot_alltoall`, `hierarchical_alltoall`, and `pxn_alltoall` now account for `hp` (tensor parallelism) when calculating effective EP ranks per node.
   - PXN All-to-All: added support for pipelined scale-up/scale-out A2A when DeepEP is enabled.
   - FSDP overlap model: a single uniform 93% overlap factor (observed on LLaMA-3 70B) is applied to all FSDP communication (AllGather forward, AllGather recompute, ReduceScatter). The AllGather count is doubled when recompute_granularity='full'.
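The two numeric fixes in this section reduce to simple formulas. A sketch with illustrative parameter names (not the tool's actual signatures) of the corrected A2A dispatch size and the uniform-overlap FSDP cost:

```python
def a2a_dispatch_bytes(seq_len, hidden, topk, bytes_per_elem, tp):
    """A2A payload per GPU. With sequence parallelism each rank holds only
    S/TP tokens, so the dispatch size shrinks by the TP degree (this PR's fix)."""
    tokens_per_gpu = seq_len // max(tp, 1)
    return tokens_per_gpu * topk * hidden * bytes_per_elem


def exposed_fsdp_comm_ms(total_comm_ms, overlap=0.93):
    """Only the non-overlapped fraction of FSDP communication adds to step
    time; 0.93 is the uniform factor this PR adopts."""
    return total_comm_ms * (1.0 - overlap)
```

For TP=8 the dispatch size drops 8×, which is exactly the overestimate the previous model carried.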

5. Other improvements
   - A `--gpu-clock-mhz` CLI override adjusts Origami and SDPA peak TFLOPS accordingly.
   - TP AllReduce and MoE A2A overhead estimation in transformer-layer simulation.
   - EP MLP scaling uses a delta-based approach that preserves profiled layer components.
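Since analytical peak FLOPS is clock rate × FLOPs per cycle, a clock override reduces to a linear rescale. A sketch of the idea behind `--gpu-clock-mhz` (the base numbers and function name are placeholders, not the tool's hardware table):

```python
def scaled_peak_tflops(base_tflops, base_clock_mhz, override_clock_mhz=None):
    """Rescale peak TFLOPS linearly when a clock override is given."""
    if override_clock_mhz is None:
        return base_tflops
    return base_tflops * override_clock_mhz / base_clock_mhz
```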

araina-amd and others added 3 commits February 13, 2026 15:28
- Add Origami GEMM simulation backend with MI300X/MI325X hardware profiles
- Add FAv3 analytical SDPA simulator with roofline model, atomic overhead
  modeling for backward pass, and GQA/MQA support
- Add simulation mode (--profiling-mode simulate) for GPU-free performance
  projection using Origami for GEMMs and analytical model for Flash Attention
- Wire simulation backends into all module profilers (attention, dense MLP,
  MoE MLP, embedding, output layer, transformer layer)
- Add --gemm-backend and --gpu-arch CLI arguments for simulation control
- Fix FSDP communication model to double AllGather count when
  recompute_granularity='full' is enabled
…and intra-node A2A overhead

- Add per-phase FSDP overlap model (90% forward AG, 24% backward AG, 34% RS)
- Add PXN All-to-All algorithm for DeepEP with pipelined scale-up/scale-out
- Enable PXN automatically when DeepEP is detected (moe_enable_deepep or use_turbo_deepep)
- Adjust intra-node A2A overhead to 28 us per peer (from preflight measurements)
- Separate intra-node and inter-node A2A overhead modeling
… MoE overhead modeling

- Add Multi-Latent Attention (MLA) GEMM simulation with LoRA-factored Q
  and compressed KV projections (6 fwd + 12 bwd GEMMs) in attention profiler
- Add MLA-aware SDPA simulation with split D_qk/D_v head dimensions
- Replace 2x-forward backward approximation with explicit dgrad + wgrad
  GEMM simulation in attention, MLP, MoE, and output layer profilers
- Add batched GEMM support (batch param) for Turbo grouped-GEMM modeling
  vs legacy sequential per-expert execution in MoE MLP profiler
- Add router overhead, token permutation, and activation function overhead
  estimation in MoE MLP simulation
- Add TP AllReduce and MoE All-to-All communication overhead estimation
  in transformer layer simulation mode
- Switch EP MLP scaling to delta-based approach preserving profiled layer
  components (TP AR, A2A, norms) and EP-invariant routed compute model
- Add enable_primus_turbo and use_turbo_grouped_mlp config flags
…level simulation

Replace the analytical roofline SDPA simulator with an Origami-based
tile-level model that simulates each per-workgroup GEMM on a single CU.
This captures wave quantisation, LDS traffic, and pipeline effects that
the global max(compute, memory) roofline missed, eliminating the need
for empirical compute_efficiency / memory_efficiency parameters.
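Wave quantization, one of the effects the tile-level model in this commit captures, is just a ceiling on the tile count. A self-contained illustration (the helper name is made up; 304 is the MI300X CU count):

```python
import math


def wave_quantized_tile_slots(num_tiles, num_cus):
    """Workgroups execute in waves of num_cus tiles; a partial last wave still
    occupies a full wave, which a global max(compute, memory) roofline misses."""
    waves = math.ceil(num_tiles / num_cus)
    return waves * num_cus
```

For example, 310 tiles on 304 CUs rounds up to 2 full waves (608 tile-slots), so that kernel runs at barely half its nominal utilization.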
…model

- A2A operates on S/TP tokens per GPU due to sequence parallelism.
  Divide dispatch_size by max(tp, 1) in both _estimate_ep_communication_overhead
  and calculate_collective_communication_time.
- Account for TP striding (hp) in single_shot_alltoall, hierarchical_alltoall,
  and pxn_alltoall: effective EP ranks per node = node_size / hp.
Replace the three separate per-phase overlap percentages (FWD_AG=90%,
BWD_AG=24%, RS=34%) with a single uniform 93% overlap applied to all
FSDP communication (AllGather fwd, AllGather recompute, ReduceScatter).

The 93% figure was observed on an actual LLaMA-3 70B run.
Fix formatting issues flagged by CI code-lint job.
@araina-amd force-pushed the dev/araina/simulation-backends-and-projection-improvements branch from 8aa9fbb to c73dbd8 on February 21, 2026 02:14