
[HybridEP] Support hybridEP for GB200 with NVL72#2207

Open
elfiegg wants to merge 2 commits intopytorch:mainfrom
elfiegg:hybrid

Conversation

@elfiegg
Contributor

@elfiegg elfiegg commented Jan 8, 2026

Co-authored with @vivekgoe, and a big shout-out to Tong Liu (@Autumn1998) for HybridEP support!

Please click here for Design doc.

Summary

This PR integrates HybridEP, specifically optimized for non-NVL8 systems.

Key Changes

Kernel Fusion: Fused the communication and permutation steps into a single operation to reduce overhead and memory pressure.

Non-NVL8 functionality: handles EP all-to-all on systems beyond standard NVL8 (up to NVL72) by setting env vars: export NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN=$EP_PARALLEL_SIZE and export USE_MNNVL=1

Performance Impact

Click to view Model Configuration Details
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True; \
export NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN=64; \
export USE_MNNVL=1; \
python -m $TRAIN_FILE \
    --job.config_file /torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml \
    --training.steps=130 \
    --training.dataset_path=$DATASET_PATH \
    --profiling.no-enable_profiling \
    --comm.init_timeout_seconds=1000 \
    --comm.train_timeout_seconds=300 \
    --profiling.save_traces_folder $MOE_BACKEND \
    --parallelism.data_parallel_shard_degree=-1 \
    --parallelism.expert_parallel_degree=64 \
    --parallelism.pipeline_parallel_degree=1 \
    --training.local_batch_size=8 \
    --compile.enable \
    --compile.components=loss \
    --parallelism.expert_parallel_comm_backend=$MOE_BACKEND
| Category | Optimization | TPS | Speedup |
| --- | --- | --- | --- |
| Baseline | None | 102.86 | 1.0x |
| Communication | Reduce Pipeline Parallelism Degree to 1 (with full activation checkpointing) + EP=64 | 269.11 | 2.61x |
| Communication | Enable HybridEP | 352.18 | 1.31x |
| Attention | Update SDPA to cuDNN for SM100 (from CUTLASS fmha SM80, already upstreamed) | 576.32 | 1.64x |
| GEMM | Use MXFP8 precision for GeMM (torch.compile to fuse quantization kernels) | 743.55 | 1.29x |
| GEMM | Reduce D2H sync for grouped_gemm with forced balancing | 829.72 | 1.12x |
| GEMM | Use TE's multi-stream MXFP8 cuBLAS grouped-gemm kernel | 1067.54 | 1.29x |
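As a quick sanity check on the table above (not part of the PR), each Speedup entry is a row's TPS divided by the previous row's, and the cumulative gain over the baseline multiplies out to roughly 10.4x:

```python
# TPS column from the table above; each step speedup is row[i] / row[i-1].
tps = [102.86, 269.11, 352.18, 576.32, 743.55, 829.72, 1067.54]

step_speedups = [round(b / a, 2) for a, b in zip(tps, tps[1:])]
total_speedup = tps[-1] / tps[0]  # cumulative gain over the baseline, ~10.4x
```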

Accuracy Test

Tested on DeepseekV3-16B:
(screenshot of accuracy results omitted)

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 8, 2026
@elfiegg elfiegg force-pushed the hybrid branch 2 times, most recently from e4472fe to 0ec5d05 Compare January 13, 2026 00:34
@elfiegg elfiegg changed the title [Draft][HybridEP] Support hybridEP for GB200 with NVL72 [HybridEP] Support hybridEP for GB200 with NVL72 Jan 13, 2026
@elfiegg elfiegg marked this pull request as ready for review January 16, 2026 00:03
Contributor

@tianyu-l tianyu-l left a comment

Thanks! This looks like an interesting optimization. I have several naive questions:

  1. Could you educate us on what HybridEP is? Please share literature / resources on this if possible.
  2. Given the complexity of this PR, does it make sense to have a mini design doc / chart? I personally find it hard to follow, possibly due to my missing context. Nevertheless, I believe there is room for improvement in terms of code organization.
  3. There seems to be undocumented use of env vars like NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN and USE_MNNVL. Could you define what they are and detail their usage?

cc @shuhuayu @yuankaichen-amd

"""
Number of SMs used by the HybridEP dispatch API.
Only used when expert_parallel_comm_backend is "hybridep".
This is configured by models behind the scene and not exposed to users.
Contributor

If they're not supposed to be configured by users, can we fix them for now?


# Global cache for dispatch handles, keyed by cache_id
# SAC saves the cache_id tensor; we use it to retrieve the non-tensor handle
_backend_mode: Literal["deepep", "hybridep"] = "deepep"
Contributor

If this will be configured by users, we probably shouldn't set a default here.


# Mask out zero-score tokens
selected_experts_indices = selected_experts_indices.masked_fill(top_scores == 0, -1)
if _backend_mode == "hybridep":
Contributor

Having such if-else branches in multiple places seems to suggest that we should introduce a separate class, e.g. HybridEPExpertParallel, from an object-oriented perspective.
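A minimal sketch of the refactor suggested here (class and method names are illustrative, not the PR's actual code): the backend is chosen once at construction time, so call sites stop branching on a module-level flag.

```python
class DeepEPExpertParallel:
    """Baseline backend; dispatch delegates to DeepEP kernels (stubbed here)."""

    def dispatch(self, tokens: list) -> str:
        return f"deepep.dispatch({len(tokens)} tokens)"


class HybridEPExpertParallel(DeepEPExpertParallel):
    """HybridEP variant; overrides only the comm-specific steps."""

    def dispatch(self, tokens: list) -> str:
        return f"hybridep.dispatch({len(tokens)} tokens)"


def make_expert_parallel(backend: str) -> DeepEPExpertParallel:
    # Single selection point instead of `if _backend_mode == ...` everywhere
    return HybridEPExpertParallel() if backend == "hybridep" else DeepEPExpertParallel()
```

Call sites then invoke `ep.dispatch(...)` without knowing which backend is active.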

@tianyu-l tianyu-l requested a review from shuhuayu January 19, 2026 10:01
@yuankaichen-amd

yuankaichen-amd commented Jan 19, 2026

If I understand correctly, Hybrid-EP is referring to: https://github.com/deepseek-ai/DeepEP/blob/hybrid-ep/Hybrid-EP_Implementation.md

It improves on standard DeepEP by using NVIDIA hardware features (e.g. TMA), particularly for NVL72 use cases. It is great work, but it seems Hybrid-EP is not yet part of the standard DeepEP library (I could only find it in a separate branch), and its current interfaces are quite different from DeepEP's.

It would be ideal if Hybrid-EP's interfaces could align with DeepEP's. If that is not possible, a separate file (e.g. hybrid_ep.py) and a separate class as @tianyu-l suggests would be preferable. Also, NVIDIA-specific variables, such as hybridep_num_sms_dispatch_api, should not be included in job_config.py. I think the user can configure those through environment variables.

@shuhuayu shuhuayu self-assigned this Jan 21, 2026
@elfiegg elfiegg force-pushed the hybrid branch 3 times, most recently from f4e7790 to 546f6af Compare January 22, 2026 04:14
@elfiegg
Contributor Author

elfiegg commented Jan 22, 2026

Hello! Design doc is here, please leave thoughts/comments: https://docs.google.com/document/d/1i8zlu-3S2psDztKc9hHuyRAu2tsskDRBmMl3egDo7Vo/edit?usp=sharing

I have also made hybridep.py a separate backend file in this PR.

_backend_mode: Literal["deepep", "hybridep"] = "deepep"


def configure_backend(
Contributor

Instead of having a separate configure_backend to patch from outside DeepEPExpertParallel, I think a cleaner way is to pass comm_backend to DeepEPExpertParallel constructor, and dynamically choose dispatch (or combine, similarly) backend here https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L347

Comment on lines 440 to 441
# Save handle now to survive activation checkpointing recompute
ctx.saved_handle = _get_cached_handle(ctx.cache_id_int)
Contributor

The idea of stashing on a global cache seems fine. If we care about compile support, having the handle, a nonproxyable object, as input/output is a pretty fundamental issue anyway, so there's unlikely to be anything better. Something we might eventually need to consider is properly marking these custom ops as having side effects, so that the compiler can't reorder them arbitrarily. Maybe that is not in scope for this PR though.

Some comments on the impl:

Custom op outputs must be tensors to flow through autograd

We don't actually care about backwarding through the cache_id output, so could it just be a python int instead of having to create / .item() a cpu tensor?

ctx.saved_handle where it survives activation checkpointing recompute
Cache pop is lenient: AC recompute may create new cache_ids, missing pops are expected and safe

I would expect AC recompute to not create new cache_ids, because you should save the cache_id for backward and, prior to creating a new cache_id, first check whether an entry already exists for the saved cache_id?

Handle saved in setup_context: The handle is immediately copied to ctx.saved_handle where it survives activation checkpointing recompute

You might've already changed this in your latest update, but I would've assumed that there would be a global cache from id: int -> handle, so there's no need to store it on the ctx.

Contributor Author

@soulitzer Thank you for illustrating the design principles! Responses to the suggestions:

We don't actually care about backwarding through the cache_id output, so could it just be a python int instead of having to create / .item() a cpu tensor?

My understanding is that a custom op requires tensor outputs; so if we switch to an int, the cache_id will need to live outside the custom op (be an input to it). Totally doable.

I would expect AC recompute to not create new cache_ids because you should save the cache_id for backward, and, prior to trying to create a new cache_id you should first check if the existing entry already exists for the saved cache_id?

My understanding is that the reason to create new cache_ids in AC recompute is: unless we can retrieve the current recompute pass via some persistent mechanism (e.g. querying the model layer and the position of the current all-to-all within the layer), we need to regenerate a cache_id anyway, or else the cache_id for the AC pass would drift due to inconsistent hashing. Maybe I'm missing some info here; please advise.

You might've already changed this in your latest update, but I would've assumed that there would be a global cache from id: int -> handle, so there's no need to store it on the ctx.

Makes sense. Originally it was an intermediate implementation where I tried reusing the cache_id for AC recompute (so the cache entry would be overwritten with a new handle, and the original handle had to be saved on ctx). At that time I used the input_buffer data pointer to derive the cache_id, which turned out to be unreliable. But I agree: with a different cache_id for each pass in the current implementation, saving it on ctx is redundant anyway.
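The id-to-handle cache pattern under discussion can be sketched in plain Python (names are hypothetical, not the PR's actual code): the custom op emits only tensors plus an integer cache_id, while the non-tensor dispatch handle lives in a module-level dict.

```python
import itertools

# Global cache: cache_id -> non-tensor dispatch handle
_HANDLE_CACHE: dict[int, object] = {}
_next_id = itertools.count()


def stash_handle(handle: object) -> int:
    """Store a handle and return an int id that can flow alongside autograd."""
    cache_id = next(_next_id)
    _HANDLE_CACHE[cache_id] = handle
    return cache_id


def pop_handle(cache_id: int):
    # Lenient pop: AC recompute may have re-registered the handle under a
    # new id, so a missing entry is expected and safe.
    return _HANDLE_CACHE.pop(cache_id, None)
```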

Contributor

@soulitzer soulitzer Feb 3, 2026

The idea of stashing on global cache seems fine. If we care about compile support, having the handle, a nonproxyable object as input/output, is a pretty fundamental issue anyway, so there's unlikely anything better.

Thanks for your responses!
I actually want to take back my earlier comment - I'm not sure we should use the cache at all anymore. I didn't realize it existed earlier as it is new and still being actively worked on, but I think we should just try to make it possible to explicitly take and return the handle as input and output via the new opaque type APIs. (This should be possible if we can treat handle as a constant and never mutate it). See example usage in: https://gist.github.com/soulitzer/6765b99dfce77ece17192e008d290ca2
The example in that gist works in eager, but compile requires an additional patch because the partitioner today does not support saving these opaque objects. Still investigating how to land that.

Contributor Author

@soulitzer incorporated the suggestion and the interface now looks much cleaner! Let me know if it looks good to you

Contributor

Thanks for the update! It looks good to me (for eager), but also tagging @angelayi who knows more about opaque types.

My notes for compile (Angela correct me if I'm wrong):

  • register_opaque_type(typ="value") type is probably not what we want here, because handle secretly holds data like tensors that is not reconstructable from a str (I guess you could in theory decompose the handle into a bunch of tensors and a struct, and actually have that struct be the "value" type, but maybe it's not best to couple directly with its internal representation?)
  • register_opaque_type(typ="reference") type is probably what you want, even though we don't actually need to mutate things here, but it is not officially supported to create reference types in the graph - probably because inductor fallback kernels don't support non-tensors as outputs, but this doesn't seem fundamental, so we should just fix this?

Comment on lines 53 to 54
We use CPU tensors for cache_id to avoid GPU-CPU synchronization when
retrieving the ID value in setup_context.
Contributor

I might lack context for HybridEP in general, so would love to learn more. My intuition for dropless EP in general is the following -- please correct me if I'm wrong.

  1. The EP buffer will be used in two ways: (A) to send/receive tokens in a single MoE layer, and (B) to save all-to-all results to avoid recomputation during backward, at least partially save.
  2. Usually we can allow (A) to be the maximum possible. However, we can't allow (B) to be the maximum, because all MoE layers added up together will cause OOM for sure.
  3. (B) is the essential reason why we need D2H sync. I've seen works to determine how much to save/recompute depending on the size -- if not too much, just save; if too much, save some and recompute the rest, etc.

With HybridEP, I wonder what's the intuition and where does the reasoning break so that we no longer need D2H sync? Thanks!

Contributor Author

Very good topic to discuss! My understanding - Let's denote N = total tokens per EP rank, K = topK

  1. The need for D2H sync is mostly because the dynamic output sizes and offsets (tokens per expert, num_tokens_per_rank, etc.) are data-dependent - the host needs to wait for all tokens to arrive to know the exact buffer size, and to read a small GPU tensor for token counts via .item(), which forces a device-to-host sync. HybridEP eliminates the D2H sync because:
    a). HybridEP pre‑allocates a global all-to-all buffer sized for the worst case: max_tokens_to_receive = N * K * capacity_factor per rank, and uses that for all layers. This is memory‑heavy for dropless setups (where the worst case is max_tokens_to_receive = N * EP_group_size), but the key is that it's static and shared across layers, so shapes are known without a D2H sync. On GB200 we can usually afford it, though load-balanced routing certainly works better.
    b). HybridEP introduces num_permuted_tokens to reserve the permutation output buffer for expert activations. This is per-layer and can't be shared globally across layers - and this is where memory is actually heavily consumed, often many multiples of the all-to-all buffer. The worst-case output of permutation is N * EP_size * min(K, num_local_expert), whereas the all-to-all buffer is worst case N * EP_size.

  2. SAC computation - essentially orthogonal to the D2H sync; but it'd be great to leave the user control over when to save the dispatch/combine results.
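With hypothetical numbers (all values assumed for illustration, not measured), the buffer sizes described above compare like this:

```python
# Hypothetical sizes for one EP rank
N = 4096                # tokens per EP rank
K = 8                   # top-k
EP = 64                 # EP group size
num_local_expert = 4    # experts hosted on this rank
capacity_factor = 2.0   # assumed

# Pre-allocated all-to-all receive buffer (static, shared across layers)
max_tokens_to_receive = int(N * K * capacity_factor)

# Dropless worst case for the same buffer
dropless_worst = N * EP

# Per-layer permutation output buffer, worst case
permute_worst = N * EP * min(K, num_local_expert)
```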

Contributor Author

@elfiegg elfiegg Jan 27, 2026

cc @Tong Liu for visibility. Please feel free to correct.

Contributor

Thanks.

a). HybridEP pre‑allocates a global all-to-all buffer sized for the worst case

This part I can get. We allocate a central buffer that's (1) large enough, and (2) shared by all layers in fwd & bwd. It's large but not that large.

b). HybridEP introduces num_permuted_tokens for reserving permutation output buffer for experts activations, this is per layer and can't be shared globally across layers
The worst case output of permutation is N * EP_size * min(K, num_local_expert), where the all-to-all buffer is worst case N * EP_size.

This part I didn't get.

  • If we don't do recomputation, then the worst case memory overhead must have a factor of num_layers, which could be gigantic and there is no way we can store them for bwd computation. (Recall that we didn't do D2H sync so we cannot know the actual size.)
  • If we always recompute, that means we don't have memory pressure at all, but we need to re-do all the all-to-all's in fwd, so could be bad for throughput.

Since you put torch.ops.hybridep.dispatch.default in _op_sac_save_list, it sounds like the former, i.e. we would effectively be storing huge buffers.

Please correct me, thanks!

Contributor Author

Your understanding looks good to me. One addition: num_permuted_tokens is optional. If it’s not provided, HybridEP has to infer the output size at runtime, which triggers a D2H sync. The original point above was that there are two distinct buffers we care about if we want to be completely D2H‑free and CUDA‑graph‑friendly:

  1. The all‑to‑all receive buffer (tokens after dispatch).
  2. The permuted expert buffer (tokens expanded by local experts for grouped GEMM).

HybridEP is designed with CUDA graph compatibility in mind: with a dense routing map and a known num_permuted_tokens, buffer sizes are fully determined on GPU with zero CPU interaction. DeepEP, in contrast, relies on pinned host metadata and periodic CPU checks, which break fully graph‑captured execution.

If we size for the worst case under dropless routing, then we have two options:

  1. Always provision buffers for the upper bound N * EP_size * min(K, num_local_expert), and always recompute instead of saving activations, or
  2. Accept D2H sync to learn the actual token counts and only allocate what we really need.

The first option avoids D2H but is memory-expensive and, as you mentioned, throughput-unfriendly (though dropless itself is already throughput‑unfriendly due to load imbalance and GPU jitter?). The second option avoids recompute but pays a D2H tax for every layer.

If we instead optimize for the average case, we can:

  1. Estimate a realistic ceiling on total tokens per rank (e.g., via a capacity factor) and pass that to HybridEP as well as num_permuted_tokens.
    This ceiling is typically much smaller than the worst case N * EP_size * min(K, num_local_expert).
  2. In the ideal balanced case, the effective permuted tokens on a rank are on the order of N * min(K, num_local_expert); with D2H we’d also size close to this.
  3. We then pay a small memory tax to stay CUDA‑graphable and completely D2H‑free.

I expect this question will come up often. I'll add this explanation (and the worst‑case vs average‑case trade‑off) to the design doc
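The worst-case vs average-case trade-off above can be made concrete with hypothetical numbers (all values assumed for illustration):

```python
N = 4096                # tokens per EP rank (hypothetical)
K = 8                   # top-k
EP = 64                 # EP group size
num_local_expert = 4
capacity_factor = 1.5   # assumed realistic ceiling multiplier

worst_case = N * EP * min(K, num_local_expert)   # dropless upper bound
balanced = N * min(K, num_local_expert)          # ideal balanced case
ceiling = int(balanced * capacity_factor)        # provisioned num_permuted_tokens
memory_tax = ceiling / balanced                  # overhead to stay D2H-free
```

The ceiling is dramatically smaller than the worst case, at the cost of a modest memory tax over the perfectly balanced size.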

Contributor Author

  1. Yes - I haven't populated the num_permuted_tokens field yet, but we can define an environment variable so the user can configure it (via an estimate or a prior experiment). There is still a D2H sync: the profiler shows a cudaStreamSynchronize where the CPU waits for the GPU. The read num_permuted_tokens is then used to allocate the output buffer for the permute kernel inside the hybridEP torch op API.
(profiler screenshot omitted)
  2. If the reserved buffer isn't large enough for all tokens, we'll end up corrupting memory, which can show up as illegal memory access errors. This is because a proper bounds check would require reading the buffer and a CPU-side counting pass. The same constraint applies to the other GPU kernels.

  3. For now I'm using HYBRIDEP_CAPACITY_FACTOR for all-to-all buffer sizing; I can rename it to HYBRIDEP_BUFFER_SIZE_RATIO if that's clearer. As mentioned earlier, we can also define another env var like NUM_TOKENS_PER_RANK_FOR_EXPERT (or similar) to control num_permuted_tokens.

  4. I agree it's better to split the phases. We can either keep the D2H sync in the earlier stage until the load_balance_coeff via auxiliary bias kicks in, or reserve large enough memory and then switch to smaller buffer sizes. I can follow up on this.

Contributor

Thanks, makes a lot of sense!

I haven't populated num_permuted_tokens field yet, but we can define an environment var so user can configure it (via estimate or prior experiment).

I personally would prefer a config option, to make configuration less scattered (among config and envvar), if that makes sense to you.

Contributor

@elfiegg Thanks for the explanations! Learned a lot. Following the discussions here, it seems to me the D2H sync in hybridep is similar to the low-latency mode in deepep, which also uses a pre-allocated buffer to avoid D2H sync and be compatible with CUDA graphs: https://github.com/deepseek-ai/DeepEP/blob/1a0a8bda09d627e67c787795aa2d984bd63dde27/csrc/deep_ep.cpp#L1015C9-L1015C29. In the current implementation of deepep, we use the normal mode, so there is still a D2H sync here: https://github.com/deepseek-ai/DeepEP/blob/1a0a8bda09d627e67c787795aa2d984bd63dde27/csrc/deep_ep.cpp#L394


"If the reserved buffer isn’t large enough for all tokens, we’ll end up corrupting memory, which can show up as illegal memory access errors."

This only shows the IMA error in older versions. In the latest hybrid-ep, it silently drops the tokens, and the device-scalar overflow_flag in the returned handle is set to true.

Contributor Author

@shuhuayu yeah, the low-latency mode seems to be a later development driven by inference use cases; IIUC it won't be performant for the throughput-oriented cases that training cares about.
Here are some perf comparisons for DeepEP vs. HybridEP (perf table: click here), but note that this may be based on an earlier commit of DeepEP, so the stats might not be up to date.

@Autumn1998 thank you for the info. This is very helpful

@elfiegg
Contributor Author

elfiegg commented Jan 29, 2026

Changed the env vars to configurable fields with doc strings; it's ready for review.

Comment on lines 17 to 86
from typing import Any, Literal, Optional, Tuple

import torch
from torch.distributed import ProcessGroup


def dispatch_tokens(
    hidden_states: torch.Tensor,
    selected_experts_indices: torch.Tensor,
    top_scores: torch.Tensor,
    num_local_experts: int,
    num_experts: int,
    group: ProcessGroup,
    score_before_experts: bool = True,
    backend: Literal["deepep", "hybridep"] = "deepep",
    # HybridEP-specific (ignored for DeepEP)
    num_permuted_tokens: Optional[int] = None,
    capacity_factor: float = 1.0,
    num_sms_dispatch: int = 16,
    num_sms_combine: int = 16,
    pad_multiple: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
    """Dispatch tokens to experts via specified backend.

    Returns: (permuted_hidden, tokens_per_expert, state)
    """
    if backend == "hybridep":
        from . import hybridep

        return hybridep.dispatch_tokens(
            hidden_states=hidden_states,
            selected_experts_indices=selected_experts_indices,
            top_scores=top_scores,
            num_local_experts=num_local_experts,
            num_experts=num_experts,
            group=group,
            score_before_experts=score_before_experts,
            num_permuted_tokens=num_permuted_tokens,
            capacity_factor=capacity_factor,
            num_sms_dispatch=num_sms_dispatch,
            num_sms_combine=num_sms_combine,
            pad_multiple=pad_multiple,
        )
    else:
        from .deepep import dispatch_tokens as _dispatch

        return _dispatch(
            hidden_states=hidden_states,
            selected_experts_indices=selected_experts_indices,
            top_scores=top_scores,
            num_local_experts=num_local_experts,
            num_experts=num_experts,
            group=group,
            score_before_experts=score_before_experts,
        )


def combine_tokens(
    hidden_states: torch.Tensor,
    state: Any,
    backend: Literal["deepep", "hybridep"] = "deepep",
) -> torch.Tensor:
    """Combine expert outputs via specified backend."""
    if backend == "hybridep":
        from . import hybridep

        return hybridep.combine_tokens(hidden_states, state)
    else:
        from .deepep import combine_tokens as _combine

        return _combine(hidden_states, state)


__all__ = ["dispatch_tokens", "combine_tokens"]
Contributor

I think we can simplify this by selecting a module instead of selecting functions, and then use ep_backend.dispatch and ep_backend.combine. We can use something like:

Suggested change
from typing import Literal


def get_ep_backend(backend: Literal["deepep", "hybridep"] = "deepep"):
    """Return the backend module for EP communication."""
    if backend == "hybridep":
        from . import hybridep

        return hybridep
    else:
        from . import deepep

        return deepep


__all__ = ["get_ep_backend"]

Contributor

@tianyu-l tianyu-l Feb 2, 2026

I feel having the very thin wrappers dispatch_tokens and combine_tokens is unnecessary -- they are just function selectors, which we could do in their call sites, i.e. DeepEPExpertParallel.

In particular, the leakage of "HybridEP-specific (ignored for DeepEP)" args can be avoided.

Contributor Author

Since both reviewers' philosophies are similar - get rid of redundant code and make a cleaner interface - I removed this thin layer (since that creates the least redundancy) and incorporated the functions into DeepEPExpertParallel.

Only effective when expert_parallel_comm_backend="hybridep".
"""

capacity_factor: float = 1.0
Contributor

I got the idea, but from the name it's still very hard to know what they are about.

Could you educate us more about these two fields, capacity_factor and num_permuted_tokens, and have more detailed documentation?
E.g. for capacity_factor, do we not always want to set this to max possible? What happens if limit is surpassed?

Btw I somehow feel there could be clearer names, e.g. pre_dispatch_capacity_factor, post_dispatch_num_tokens, etc. We can discuss naming after their meanings become clear to me.

Contributor Author

Done adding more comments and renaming

If None, uses blocking mode with D2H for sizing.
"""

pad_multiple: int | None = None
Contributor

we are removing padding here #2255
so this field can be removed

Contributor Author

This is for MXFP8 - since the kernel does comm + permute, the output needs to be aligned to a multiple of 32. I have removed this from the config and set it internally in hybridep when MXFP8 is detected.
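The 32-alignment mentioned here is a simple round-up; a sketch (helper name is hypothetical, not the PR's code), assuming MXFP8 scaling blocks span 32 elements:

```python
def pad_to_multiple(n: int, multiple: int = 32) -> int:
    """Round n up to the nearest multiple of `multiple` (ceil-divide then scale)."""
    return -(-n // multiple) * multiple
```

For example, a 100-token buffer would be padded to 128, while an already-aligned 96 stays at 96.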

"num_sms_combine": num_sms_combine,
}

return _buffer
Contributor

the return seems not used?

Contributor Author

Removed

Comment on lines 288 to 289
num_sms_dispatch: int = 16,
num_sms_combine: int = 16,
Contributor

If it's always 16 in this PR, can we hardcode them in this function for now, instead of passing around?

Contributor Author

Done

Comment on lines 588 to 593
hybridep_capacity_factor=hybridep_capacity_factor,
hybridep_num_permuted_tokens=hybridep_num_permuted_tokens,
hybridep_pad_multiple=hybridep_pad_multiple,
# Model-specific settings (set by caller)
hybridep_num_sms_dispatch=hybridep_num_sms_dispatch,
hybridep_num_sms_combine=hybridep_num_sms_combine,
Contributor

passing all of them seems unnecessary -- see comments elsewhere

I think at least we can get rid of hybridep_pad_multiple, hybridep_num_sms_dispatch, and hybridep_num_sms_combine

For the other two, hybridep_capacity_factor and hybridep_num_permuted_tokens, if they are absolutely needed, we can pass job_config.parallelism.hybridep around.

Contributor Author

Done

Comment on lines 691 to 696
# CUTLASS grouped_mm handles 0-token experts fine at runtime,
# but torch.compile's meta registration doesn't handle zero-size
# tensors (strides (0,0) error). Skip compilation in that case.
# TODO: remove once PyTorch fixes _meta_grouped_mm_common.
if x.shape[0] == 0:
return x
Contributor

@xmfan could you help review this? We recently removed padding on the input tokens to experts.

Member

looks fine to add. let me take a look at the grouped mm meta fix

Member

@xmfan xmfan Feb 12, 2026

@elfiegg Actually I tried to repro this standalone in eager, and it looks like this also causes issues with eager backward? https://gist.github.com/xmfan/f99876f07bf5fa023dc912c3d7db4e0f.

It's true that the cutlass grouped_mm forward can handle 0-token experts fine, but the backward implementation cannot. So yes, we can move the 0 check to the backward meta implementation, but you'll still need this branching for 0 token for eager/compiled bwd.

Contributor

@ngimel please help confirm.
It seems bf16 grouped_mm cannot handle 0-token (across all experts) for backward. Is it true? Do we plan to support?


To size correctly, consider: num_tokens * top_k * capacity_factor / ep_degree,
accounting for potential load imbalance across experts.
capacity_factor: Buffer multiplier (>= top_k, <= EP group size)
Contributor

why >= top_k? I thought it should be = 1 for load balanced case.

also why <= EP group size?
The total number of tokens is dp_degree (= ep_degree) * num_local_send_tokens * top_k. If num_local_experts >= top_k, it's possible that worst case all tokens goes to one EP rank, in which case the capacity factor needs to be num_local_send_tokens * top_k (the current default).

But if num_local_experts < top_k, it's impossible that all tokens go to one rank.

Contributor Author

Apologies for the confusion; I realize this may be frustrating to read. My understanding of HybridEP is still evolving, so please bear with me. In the process, it can happen that I confuse both the AI and others...

I intended to put a boundary of [1.0, EP_group_size], but Claude kept thinking it should be [top_k, EP_group_size] and adding back top_k-related calculations, possibly due to my previous prompts. (I should have done a more thorough review, though. No shame in using AI lol)

Here is a HybridEP implementation doc that @Autumn1998 put together on Saturday.

The reason the range is [1.0, EP_group_size] is because, all-to-all output is:

recv_x:        [num_recv_tokens, hidden_dim]  # Main activation
recv_indices:  [num_recv_tokens, top_k]       # Small integer tensor
recv_scores:   [num_recv_tokens, top_k]       # Small float tensor  

and num_recv_tokens = local_batch x seq_length x factor.
In a mostly balanced setting, factor = 1.0; in the worst case, factor = EP_group_size (when all other ranks route their tokens to the current rank).

However, I've recently learnt that HybridEP has already sized internally for the worst case, meaning it will reserve a big enough buffer to max_num_of_tokens = max_num_of_tokens_per_rank(num_recv_tokens) × EP_group_size, see: buffer allocation

And MegatronLM uses moe_expert_capacity_factor: the capacity factor for each expert; None means no tokens will be dropped. The default is None. Their typical range seems to be (0, 1]. I have aligned the implementation with that pattern.

As for num_permuted_tokens: it makes sense that it can be calculated via capacity_factor, so I added a boolean to enable non-blocking mode and use num_permuted_tokens = local_batch x seq_length x EP_group_size x min(top_k, num_local_experts) x moe_expert_capacity_factor. This also aligns with MegatronLM: code pointer

Hope it makes sense to you!
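As a rough illustration of the receive-buffer sizing arithmetic described above (function and variable names are hypothetical, not the HybridEP API):

```python
def recv_buffer_rows(local_batch: int, seq_length: int, factor: float) -> int:
    # num_recv_tokens = local_batch * seq_length * factor, where
    # factor is 1.0 for balanced routing and EP_group_size in the
    # worst case (every other rank routes its tokens to this rank).
    return int(local_batch * seq_length * factor)

balanced = recv_buffer_rows(8, 4096, 1.0)  # balanced routing
worst = recv_buffer_rows(8, 4096, 64)      # worst case with EP_group_size = 64
```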

Contributor

Makes sense, thanks! I feel this work shares some spirit with my previous attempt on integrating NVSHMEM into EP #1569, but of course is way more mature.

dispatch. If the buffer is too small, it causes illegal memory access (IMA).

In balanced routing each rank receives num_tokens × top_k tokens
(the EP_degree cancels out), so the minimum safe value equals top_k.
Contributor

The comment is confusing.
The default buffer, without a multiplier, is already num_tokens * top_k, so why does "the minimum safe value equal top_k"? I think the minimum safe value is 1 if tokens are balanced and num_local_experts > top_k.

If num_local_experts < top_k, the minimum safe value should be smaller. But maybe we should change the default buffer size in that case.

Only effective when expert_parallel_comm_backend="hybridep".
"""

receive_tokens_ratio: float = 1.0
Contributor

It seems you only changed the naming here, but keeps capacity_factor in other places like in HybridEP class.

Now that I understand what this field means, I think capacity_factor is the accurate name. So I think we can keep that.


Recommendation: 1.5 (50% headroom for load imbalance).
With auxiliary-loss-free load balancing, routing stabilizes quickly
and top_k × 1.0-1.5 is typically sufficient.
Contributor

let's not mix top_k and without top_k in the capacity factor

num_experts: Total experts across all ranks
group: EP ProcessGroup
score_before_experts: Apply scores before expert computation
num_permuted_tokens: Pre-allocated output buffer size for grouped_mm.
Contributor

Assuming that capacity_factor gives us the buffer size to receive all-to-all results on a EP rank, that should also be the input size to grouped mm?

Is the reason we still need this field of num_permuted_tokens because we'd like to save activations and not recompute them in backward? Because I think if we always recompute, we can just reuse the buffer created with capacity factor?

Earlier you said

a). HybridEP pre‑allocates a global all-to-all buffer sized for the worst case: max_tokens_to_receive = N * K * capacity_factor per rank, and uses that for all layers. This is memory‑heavy for dropless setups (in which for the worst case, max_tokens_to_receive = N * EP_group_size), but the key is that it’s static and shared across layers, so shapes are known without D2H sync, on GB200 we usually can afford it, but load balanced routing for sure works better.

Instead of max_tokens_to_receive = N * EP_group_size, it should be max_tokens_to_receive = N * K * EP_group_size, assuming N is the local_tokens per DP rank, and top_k <= EP_group_size?

b). HybridEP introduces num_permuted_tokens for reserving permutation output buffer for experts activations, this is per layer and can't be shared globally across layers - and this is actually where memory is heavily consumed, and most of time, many multiples of all-to-all buffer. The worst case output of permutation is N * EP_size * min(K, num_local_expert), where the all-to-all buffer is worst case N * EP_size.

When you say

many multiples of all-to-all buffer.

does it mean all layers added together being multiples of all-to-all buffer? For each single layer, why we need anything above the size of all-to-all buffer?

where the all-to-all buffer is worst case N * EP_size.

The all-to-all buffer is worst case N * K * EP_size?

The worst case output of permutation is N * EP_size * min(K, num_local_expert)

"permuted" seems suggesting pre-permutation size (received tokens) is different from post-permutation size (permuted tokens), why permutation would change buffer size?

Contributor

@tianyu-l tianyu-l left a comment

Thank you! Some more clarifying questions.

(Also I apologize that we may have to have you rebase onto the refactor in #2386 after it's merged).

@@ -526,27 +522,25 @@
Contributor

IIRC this code was added by @garrett361 in #1974 for deterministic computation.

Is this change trying to address a separate problem? because HybridEP shouldn't use this implementation.

Contributor Author

Thanks for catching - definitely an incorrect code check-in. Over the past few weeks, Torch nightly had strange allocator behavior that caused even very small models to OOM. At the time, I suspected I had introduced issues in the codebase, so I started prompting and exploring potential memory optimizations. After quite a bit of debugging, I realized it was an issue with Torch itself, and reinstalling a newer Torch version resolved the problem. Unfortunately, some of the experimental code changes were mistakenly checked in during that process. (In the AI era, we really do need to be extra careful about what gets upstreamed!)
Really appreciate you flagging it though Tianyu!

from .kernels import generate_permute_indices

TOKEN_GROUP_ALIGN_SIZE_M = 8
MXFP8_TOKEN_GROUP_ALIGNMENT_SIZE_M = 32
Contributor

available here https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/__init__.py#L18

also the util functions should be in quantization/mx.py

Contributor Author

Done

topk_weights: torch.Tensor,
num_experts: int,
num_permuted_tokens: Optional[int] = None,
moe_expert_capacity_factor: Optional[float] = None,
Contributor

I feel num_permuted_tokens and moe_expert_capacity_factor has some redundant info. E.g. in this dispatch function, moe_expert_capacity_factor is not used. Can it be removed from the function signature?

Contributor Author

makes sense

raise AssertionError("HybridEP FP8 dispatch not yet supported")

HybridEPBuffer = _require_hybridep()
max_tokens = num_tokens
Contributor

This is coming from hidden_states, which is the number of tokens before token dispatch. I feel comparing _buffer.config.max_num_of_tokens_per_rank < max_tokens doesn't make sense?

Contributor Author

This simply checks whether the buffer allocated to hold max_num_of_tokens_per_rank for all-to-all is smaller than what the current max_num_of_tokens_per_rank (tokens before dispatch) could possibly use during the next all-to-all; if so, the buffer is resized. Renamed to be clearer.

Only effective when expert_parallel_comm_backend="hybridep".
"""

moe_expert_capacity_factor: float | None = None
Contributor

Question 1: Is this field only used when enable_non_blocking is True? Or would we drop tokens even if we are OK with getting actual buffer usage with D2H sync?

Related question: Is token dropping by this field run-to-run deterministic?

Question 2: You mentioned that HybridEP will anyway allocate max_num_of_tokens_per_rank amount of buffer to receive tokens -- is it right that this is the buffer shared by dispatch / combine and all layers, as we previously discussed?
If so, then is it correct that the purpose of moe_expert_capacity_factor mainly to help

  • AC memory (if dispatching results are saved)
  • load balancing in terms of computation across different ranks

Contributor Author

  1. it will be used in all-to-all buffer reservation as well to potentially reduce the memory consumption, so it'll be used regardless of enable_non_blocking
  2. token dropping is run-to-run deterministic
  3. yes buffer will be reserved once for all-layers. so moe_expert_capacity_factor helps further reduce the memory usage when it's known to be more balanced routing

Contributor

Here's my latest understanding, please correct me if anything is wrong:

There are two modes, blocking vs. non-blocking. And there are two relevant sizes: the all-to-all buffer allocation (shared by layers) vs. num_permuted_token for dispatching results (each layer has its own result).

Under blocking mode:

  • We will know the actual size of all-to-all dispatch, and num_permuted_token. One possible thing to do is to allocate a buffer of exactly the needed size in an MoE layer, but since we are sharing the buffer across all layers, it's better if we allocate a large enough size so that we don't always reallocate. (Although, I'm seeing that HybridEPBuffer has an option use_shared_buffer).
  • num_permuted_token is passed by D2H sync, so we don't need to worry about wasting memory if we really want dropless.
  • If we set moe_expert_capacity_factor to allocate less than what is actually needed (per the D2H-synced num_permuted_token), tokens will be dropped silently without raising an exception.

Under non-blocking mode:

  • We don't know the actual size of all-to-all dispatch, or num_permuted_token. All we do is to allocate the same size for both, determined by pre-dispatch size, ep degree, and moe_expert_capacity_factor.
  • The only difference is that the buffer is shared across layers, and num_permuted_token is allocated for storing outputs per layer (if not recomputed by AC).

Could you add more docstring to the HybridEP fields?



topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
num_permuted_tokens: Optional[int] = None,
Contributor

could you explain one more time what "permutation" here means?

Contributor Author

meaning HybridEP does all-to-all + local token permutation (transforming [num_tokens, hidden_dim] + [num_tokens, top_k] into [num_tokens * top_k (concatenated), hidden_dim]) all in one;
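The permutation step described here can be sketched in plain Python (illustrative only, not the fused HybridEP kernel; names are hypothetical): each token is replicated once per selected expert, and the copies are grouped by destination expert so each expert sees a contiguous slab.

```python
def permute_tokens(tokens, topk_idx):
    # tokens: num_tokens rows; topk_idx: per-token list of top_k expert ids.
    pairs = []
    for t, experts in enumerate(topk_idx):
        for e in experts:
            pairs.append((e, t))
    pairs.sort(key=lambda p: p[0])  # stable sort groups copies by expert id
    permuted = [tokens[t] for _, t in pairs]       # [num_tokens * top_k, hidden]
    row_to_token = [t for _, t in pairs]           # mapping used by combine
    return permuted, row_to_token

tokens = [[1.0], [2.0], [3.0]]
topk_idx = [[0, 1], [1, 0], [0, 1]]  # each token picks top_k = 2 experts
permuted, mapping = permute_tokens(tokens, topk_idx)
```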

Add HybridEP as a separate backend optimized for GB200/NVLink72 systems.

Key changes:
- Add hybridep.py: Self-contained HybridEP implementation with TMA optimization
- Update __init__.py: Unified interface with configure_backend() for backend selection
- Update parallelize.py (DeepSeekV3, Llama4): Support hybridep backend with proper SAC registration
- Update moe.py: build_moe() supports 'hybridep' backend
- Update job_config.py: Add 'hybridep' option, SM config via env vars
- Update expert_parallel.py: Support hybridep dispatch/combine
- Add moe/utils.py: Padding utility for MXFP8

HybridEP uses dense routing format (vs DeepEP sparse) for:
- TMA instruction optimization
- Fused permute kernel support
- Zero CPU-GPU sync for CUDA Graph compatibility

SM configuration via environment variables:
- HYBRIDEP_NUM_SMS_DISPATCH (default: 16)
- HYBRIDEP_NUM_SMS_COMBINE (default: 16)

Co-authored-by: Cursor <cursoragent@cursor.com>
@elfiegg
Contributor Author

elfiegg commented Feb 24, 2026

BTW here is the up-to-date summary of the optimizations we discovered on GB200 for the DeepSeek 671B model. With HybridEP + its D2H-sync-free feature + MXFP8, we are at 829.72 TPS, which is about 8.07x the baseline recipe (PP=8, EP=32, SAC).

| Category | Optimization | TPS | Speedup |
| --- | --- | --- | --- |
| Baseline | None | 102.86 | 1.0x |
| Communication | Reduce Pipeline Parallelism Degree to 1 (with full activation checkpointing) + EP=64 | 269.11 | 2.61x |
| Communication | Enable HybridEP | 352.18 | 1.31x |
| Attention | Update SDPA to cuDNN for SM100 (from CUTLASS fmha SM80, already upstreamed) | 576.32 | 1.64x |
| GEMM | Use MXFP8 precision for GeMM (torch.compile to fuse quantization kernels) | 743.55 | 1.29x |
| GEMM | Reduce D2H sync for grouped_gemm with forced balancing | 829.72 | 1.12x |
| GEMM | Use TE's multi-stream MXFP8 cuBLAS grouped-gemm kernel | 1067.54 | 1.29x |

@tianyu-l
Contributor

BTW here is the up-to-date summary of the optimizations we discovered on GB200 Deepseek 671B model. With HybridEP + its D2H sync free feature + MXFP8 we are at 829.72 TFLOPS which is about 8.07x of the baseline recipe (PP=8, EP=32, SAC)

Nice! To clarify

  • baseline is NCCL all-to-all, not DeepEP, right? (Is DeepEP relevant at all on GB200?)
  • Are you using 256 GPUs in total? If so, is the baseline PP8 FSDP32 (dense) EP32 (sparse)?

It's also impressive that

TE's multi-stream MXFP8 cuBLAS grouped-gemm kernel

cc @danielvegamyhre

scaling_factor=None,
num_of_experts_per_rank=num_local_experts,
pad_multiple=pad_multiple,
num_permuted_tokens=num_permuted_tokens,
Contributor

  • what happens if num_permuted_tokens is None? Do we require it's in blocking mode?
  • what happens if num_permuted_tokens is given and we are in blocking mode -- is it ignored?

Can we add an assertion here?

Contributor Author

@elfiegg elfiegg Feb 26, 2026

If num_permuted_tokens is None and non-blocking is enabled, HybridEP will hit an assertion error:
https://github.com/deepseek-ai/DeepEP/blob/hybrid-ep/deep_ep/hybrid_ep_buffer.py#L447-L448. Otherwise, num_permuted_tokens should be None so that we read it from the D2H sync.

If num_permuted_tokens is provided (currently it is computed internally, and only when non_blocking=True) while running in blocking mode, then num_permuted_tokens takes priority. Setting it to -1 forces it to fall back to the host-synchronized value; otherwise, the host-synced value is simply ignored:
https://github.com/deepseek-ai/DeepEP/blob/hybrid-ep/csrc/hybrid_ep/executor/executor.cu#L158-L170

Contributor Author

@elfiegg elfiegg Feb 26, 2026

We can probably add one more assertion for case 1 so we catch that earlier. Assertion 2 might not be necessary, since we decide when num_permuted_tokens is provided (only when non_blocking=True).
But in the current implementation, we actually set num_permuted_tokens to None when non_blocking isn't enabled, so the assertion for case 1 isn't necessary either.

Contributor

Setting it to -1 forces it to fall back to the host-synchronized value

Is this blocking-mode only behavior or for both blocking and non-blocking? From the code pointer, it seems blocking-mode only.

For blocking mode, setting it to -1 and None seems to have the same effect?


@danielvegamyhre
Contributor

danielvegamyhre commented Feb 25, 2026

@elfiegg this is awesome - question for you: from the PR description it sounds like you tested MXFP8 compatibility in a torchtitan fork that uses different kernels for MXFP8 quantization, GEMM, Grouped GEMM, is that right?

If so, can you test if it works with torchtitan's existing/default MXFP8 implementation? Your changes look like they should work with it, and this would be great. HybridEP's native handling of token group size padding to nearest multiple of 32 for mxfp8 grouped gemms will allow us to remove the non-optimal triton kernel based token group padding, which currently runs after the all2all, directly before the grouped gemm, incurring extra copy overhead.

Here is a command you can use as reference to test torchtitan's native MXFP8 implementation with:

(make sure to install torchao nightly build for cuda 12.8+):

  • pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu128
CONFIG_FILE=/home/$USER/torchtitan/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ./run_train.sh --metrics.log_freq=10 \
--training.steps=100  \
--parallelism.data_parallel_shard_degree=8 \
--parallelism.expert_parallel_degree=8 \
--parallelism.tensor_parallel_degree=1 \
--parallelism.expert_tensor_parallel_degree=1 \
--training.seq_len=8192 \
--activation_checkpoint.mode=full \
--model.print_after_conversion \
--training.local_batch_size=12 \
--quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="triton" \
--quantize.linear.mx.mxfp8_dim1_cast_kernel_choice="cuda" \
--quantize.grouped_mm.mx.fqns="experts" \
--compile.enable --compile.components="model,loss" --debug.moe_force_load_balance \
--model.converters="quantize.grouped_mm.mx,quantize.linear.mx"

Also, can you share how the padding implementation works in HybridEP? I looked through the design doc and did not see these details.

I'm guessing maybe HybridEP exchanges group size metadata across EP ranks, then each EP rank calculates the rounded up token group size for each local expert and allocates a padded buffer for incoming tokens accordingly, is that right?

@elfiegg
Contributor Author

elfiegg commented Feb 26, 2026

@tianyu-l Thanks for the detailed analysis! Your understanding is mostly right - a few corrections and clarifications below.
Two separate buffer concepts:
There are indeed two relevant buffers, and they behave the same in both blocking and non-blocking modes:

  1. All-to-all communication buffer - allocated once via get_buffer(), shared across all MoE layers (the global _buffer). Within a single layer, the dispatch and combine phases share the same memory (use_shared_buffer=True; this is dispatch/combine sharing, not cross-layer sharing, which happens via the global _buffer instance). HybridEP sizes it for the worst-case receive: max_num_of_tokens_per_rank × ranks_per_node × nodes.
  2. num_permuted_tokens - the output capacity for the fused permute kernel, computed per dispatch call (per layer). This is where token dropping actually happens: tokens whose permuted offset exceeds this limit are silently dropped and overflow_flag is set on the GPU.

Blocking mode:

  • num_permuted_tokens is None => converted to -1 in C++. After cudaStreamSynchronize, tokens_per_expert is read from pinned CPU memory and the exact num_permuted_tokens is computed on the host. So no tokens are dropped - it's always dropless.
  • If moe_expert_capacity_factor is provided, it only affects the initial get_buffer() allocation size. However, HybridEP's update_template_config auto-grows max_num_of_tokens_per_rank to max(actual_token_count, stored_max) on every dispatch, so the buffer grows back to full size on the first forward pass. Net effect: moe_expert_capacity_factor has no lasting impact in blocking mode.

Non-blocking mode:

  • No D2H sync is allowed, so num_permuted_tokens must be pre-computed upfront. We estimate it as num_tokens × ep_size × min(num_local_experts, top_k) × capacity_factor, aligned for MXFP8. HybridEP asserts num_permuted_tokens >= 0.
    capacity_factor=1.0 = worst-case sizing, no drops, highest memory. Values < 1.0 reduce memory; safe in practice with forced load balancing (aux-loss / sinkhorn) that keeps token distribution roughly uniform.
  • Same auto-grow applies to the all-to-all buffer (it always ends up at full size regardless of capacity_factor).

What's the same in both modes:
The all-to-all communication buffer is shared across all layers in both modes.
num_permuted_tokens is per-dispatch-call (per layer) in both modes.
use_shared_buffer means dispatch/combine share intra-node memory in both modes.

I've updated the docstrings on hybridep_expert_capacity_factor and hybridep_non_blocking in configs.py to document all of this, and added inline comments in the dispatch code explaining the num_permuted_tokens / non_blocking interaction.
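The non-blocking estimate described above can be sketched as (names are hypothetical, not the actual torchtitan helper):

```python
def estimate_num_permuted_tokens(num_tokens: int, ep_size: int,
                                 num_local_experts: int, top_k: int,
                                 capacity_factor: float, align: int = 32) -> int:
    # Upper bound for non-blocking mode:
    # num_tokens * ep_size * min(num_local_experts, top_k) * capacity_factor,
    # rounded up to the MXFP8 alignment of 32.
    est = int(num_tokens * ep_size * min(num_local_experts, top_k) * capacity_factor)
    return ((est + align - 1) // align) * align
```

With capacity_factor=1.0 this is worst-case sizing (no drops, highest memory); values below 1.0 trade memory for possible silent drops under imbalance.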

- HYBRIDEP_NUM_SMS_COMBINE (default: 16)
"""

hybridep_expert_capacity_factor: float | None = None
Contributor

From torchtitan perspective, the field seems only meaningful in non-blocking mode.
Does it make sense to either:

  • add an assertion that when hybridep_non_blocking is False, asserting hybridep_expert_capacity_factor is None (after reordering the two args, non_blocking should be first, cf second)
  • unify the two into one config hybridep_non_blocking_capacity_factor: When None, fall back to blocking mode.

# LICENSE file in the root directory of this source tree.

"""
HybridEP: Expert Parallel Communication for GB200 NVLink72 Systems.
Contributor

Does it also work on GB300? cc @shuhuayu

Contributor

Yes, it works on B200, GB200, GB300. Megatron has benchmarked deepseek training with hybridep here: https://docs.nvidia.com/nemo/megatron-bridge/latest/performance-summary.html
