
Remove unnecessary token padding for MoE in BF16 mode#2255

Open
rakkit wants to merge 5 commits into pytorch:main from rakkit:moe_fast_path

Conversation

@rakkit (Contributor) commented Jan 20, 2026

Fix for #2225.

Context: the Solar Open-102B report points out that in BF16 mode, expert parallel did unnecessary token padding (note: this also affects the non-EP case).
This PR sets TOKEN_GROUP_ALIGN_SIZE_M=1 by default.
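For context on what the default change does, here is a minimal sketch (the function name and token counts are illustrative, not from the PR) of the extra rows that alignment 8 pads in, compared to alignment 1:

```python
# Hypothetical illustration (not torchtitan code): how many rows padding adds
# when each expert's token group is rounded up to a multiple of
# TOKEN_GROUP_ALIGN_SIZE_M.
def padded_total(tokens_per_expert, align):
    # ceil-divide each group count by `align`, then scale back up
    return sum(-(-t // align) * align for t in tokens_per_expert)

tokens = [1023, 1, 514, 510]    # 2048 real tokens across 4 experts
print(padded_total(tokens, 8))  # 2064 -> 16 wasted padded rows
print(padded_total(tokens, 1))  # 2048 -> no padding
```

With align=1 every group rounds up to itself, so the padding path degenerates to a no-op.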

Test:

Original implementation (with TOKEN_GROUP_ALIGN_SIZE_M=8):

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
[screenshot]
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
[screenshot]

And with this PR:

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
[screenshot]
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
[screenshot]

Another update completely removes TOKEN_GROUP_ALIGN_SIZE_M:

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10

[screenshot]

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
[screenshot]

meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Jan 20, 2026

TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
TOKEN_GROUP_ALIGN_SIZE_M = 1
Contributor

This fix is "soft", in the sense that the padding code path still exists for bf16.

I wonder whether it's viable to go one step further -- remove all padding logic for bf16 and move padding logic to quantized paths only. cc @danielvegamyhre

@danielvegamyhre (Contributor) commented Jan 27, 2026

Either way is fine. The 8-token alignment is needed if we want to use TMA in any kernels operating on each token group (8 elems * 2 bytes per elem = 16-byte alignment). However, if we are only doing that in the low-precision code path, then there's no reason to pad here.

Feel free to remove bf16 padding entirely.
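The byte arithmetic above can be sketched as follows (a minimal illustration; the 16-byte TMA requirement is taken from the comment, and the function name is made up):

```python
# Sketch (assumption: TMA needs 16-byte-aligned group starts, per the comment
# above): the tokens-per-group alignment implied by each element size.
def group_align_m(elem_bytes, byte_align=16):
    return byte_align // elem_bytes

print(group_align_m(2))  # bf16: 8 tokens  (8 * 2 B = 16 B)
print(group_align_m(1))  # fp8: 16 tokens  (16 * 1 B = 16 B)
```

This matches the ValidTokenGroupAlignmentSize = Literal[8, 16, 32] options shown in the diff above.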


For bf16, 8-token alignment is not needed anywhere; see:

import torch
import torch.nn.functional as F
x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16).requires_grad_(True)
w = torch.randn(2, 4096, 7168, device="cuda", dtype=torch.bfloat16).requires_grad_(True)
# odd offsets
offs = torch.tensor([1023, 2048], device="cuda", dtype=torch.int32)
out = F.grouped_mm(x, w, offs=offs)
gO = torch.rand_like(out)
out.backward(gO)
# check that gradients are computed
print(x.grad.sum(), w.grad.sum())

Contributor

@danielvegamyhre
Right now we are mixing padding and permutation into one kernel. Since bf16 doesn't require padding, I wonder if it makes sense to move padding into the quantization kernel? The argument is that the kernel itself should be general and not require the user to do padding from outside.

@rakkit (Contributor Author)

Sure, if we agree to move the padding logic to the quant paths, then I will refactor to remove TOKEN_GROUP_ALIGN_SIZE_M in torchtitan.

Contributor

@tianyu-l we have a version of the permutation and pad/fill kernel in torchao now, used in the MXFP8 EP primitives. It is not fused with quantization though. To clarify, are you asking if we can delete the permute+pad kernel from torchtitan and replace it with fused permute+pad+quantize kernel in torchao?

Contributor

@danielvegamyhre My request is that we remove padding from torchtitan entirely, while keeping correctness.

In the past we had the permute+pad kernel to avoid a D2H sync on the padding part. Now that we no longer need padding for bf16, I'd hope we can remove the kernel altogether, but that requires torchao to handle padding.

Contributor

That is doable

@danielvegamyhre (Contributor)

The Solar Open-102B technical report is very interesting, @rakkit, thanks for sharing it!

@rakkit (Contributor Author) commented Jan 31, 2026

Thanks @danielvegamyhre @tianyu-l.

Made a refactor, and the alignment logic is completely removed now. (The PR message is also updated with the test runs.)

I think we still need the generate_permute_indices kernel for the EP permute, but it runs in no-padding mode by default now.

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x.shape
x = x[permuted_indices, :]
x = torch.index_select(x, 0, permuted_indices)
Contributor

Curious: what's the reason for this change?

@rakkit (Contributor Author)

I reverted the changes; I thought they ended up in the same kernel.

max_len = x.shape[0]

with torch.no_grad():
    (permuted_indices, num_tokens_per_expert, _offsets,) = generate_permute_indices(
Contributor

i think we still need generate_permute_indices kernels for EP permute, but it runs on no-padding mode by default now

Do we? IIUC we needed generate_permute_indices to avoid a D2H sync (in getting the pad sizes). Now that we don't do padding, can we just use PyTorch code to do the permutation? Please correct me if I'm wrong.
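A pure-PyTorch permutation along these lines could look like this (a sketch assuming tokens carry an expert-id tensor; the function name is illustrative, not torchtitan's API):

```python
import torch

# Hypothetical pure-PyTorch permutation (illustrative, not torchtitan code):
# a stable sort by expert id makes each expert's tokens contiguous, entirely
# on-device, with no padding and no D2H sync.
def permute_indices(expert_ids: torch.Tensor) -> torch.Tensor:
    # stable=True keeps the original token order within each expert group
    return torch.argsort(expert_ids, stable=True)

expert_ids = torch.tensor([2, 0, 1, 0, 2, 1])
print(permute_indices(expert_ids))  # tensor([1, 3, 2, 5, 0, 4])
```

Gathering rows with these indices (e.g. torch.index_select) then groups tokens by expert without any alignment rounding.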

@rakkit (Contributor Author)

I added a PyTorch version of generate_permute_indices above.

@tianyu-l (Contributor) commented Feb 8, 2026

@danielvegamyhre let us know when the mxfp8 kernels support padding.

@danielvegamyhre (Contributor)

@rakkit I have not closely reviewed the PR yet, but as an intermediate solution, can we just set TOKEN_GROUP_ALIGN_SIZE_M=1 for bf16 until there is a performant solution for automatic padding in MXFP8? I'm prototyping some solutions to this, but it's harder than it sounds to not kill performance. The TLDR is that it needs to be fused into the all-to-all dispatch, which will take time, maybe ~weeks.

@rakkit (Contributor Author) commented Feb 24, 2026

Hi @danielvegamyhre, I am not sure if I understand correctly.

Do you mean: set TOKEN_GROUP_ALIGN_SIZE_M=1, and we let

  • BF16 without EP do no padding,
  • BF16 with EP use the current padding kernels (with size=1),
  • (MX)FP8 use the current padding kernels?

Then that is the initial commit of this PR; you can check it here, so I can check out to that commit?

@danielvegamyhre (Contributor) commented Feb 25, 2026

  • BF16, with EP, using current padding kernels (with size=1)

On the bf16 path, if the alignment is 1, we can early exit. I am just asking to keep the padding path as it exists and not delete it for now, so mxfp8 can continue to use it until we land a proper solution, if that is ok with everyone.

Additional context: at the moment, it looks like it will either use HybridEP's native handling for this, or alternatively, if that doesn't work out, extend an experimental MXFP8 all-to-all implementation we have in torchao to also handle the padding.

Once that solution is tested and landed, we can delete the current non-optimal padding logic entirely.

@rakkit (Contributor Author) commented Feb 25, 2026

@danielvegamyhre

on the bf16 path if alignment is 1, we can early exit.

The kernel can do both padding and permutation.

What I mean is: for BF16, we early exit on the non-EP path, and use the current padding kernels (with size=1) for the permute on the EP path. And that's the initial commit of this PR.

@pianpwk commented Feb 25, 2026

@rakkit sorry, would you mind also looking into removing the for-loop-path padding/unpadding as part of this PR? This would be an alternative fix for #2399. Possibly conditional on TOKEN_GROUP_ALIGN_SIZE_M=1, depending on how that settles:

def _run_experts_for_loop(

This patch should work, and on just one test, it gives me equivalent numerics before/after:

diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py
--- a/torchtitan/models/moe/moe.py
+++ b/torchtitan/models/moe/moe.py
@@ -81,12 +81,8 @@ def _run_experts_for_loop(
     # NOTE: this would incur a synchronization between device and host
     num_tokens_per_expert_list = num_tokens_per_expert.tolist()

-    # side-effect code due to the usage of generate_permute_indices
-    num_padding = x.shape[0] - sum(num_tokens_per_expert_list)
-
     # a tuple of tensors indexed by experts
     # each with shape (tokens_per_expert(varying), dim)
     x_splits = torch.split(
-        x[: sum(num_tokens_per_expert_list)],
+        x,
         split_size_or_sections=num_tokens_per_expert_list,
         dim=0,
     )
@@ -100,8 +96,6 @@ def _run_experts_for_loop(
         out_experts_splits.append(h)
     out = torch.cat(out_experts_splits, dim=0)

-    # side-effect code due to the usage of generate_permute_indices
-    out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-
     return out

I think the for-loop path only activates on < SM90, though.
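As a sanity check of the reasoning behind this patch, here is a minimal sketch (hypothetical shapes, with a stand-in for the per-expert computation): when x has no padded rows, its row count equals sum(num_tokens_per_expert), so the truncation and the zero-row re-append the patch deletes are both no-ops.

```python
import torch

# Quick numeric check in the spirit of the patch above (hypothetical shapes):
# with no padding, x's row count equals sum(counts), so x[: sum(counts)]
# would be the identity and no zero rows need re-appending afterwards.
counts = [3, 0, 5]                     # one expert may receive zero tokens
x = torch.randn(sum(counts), 16)
splits = torch.split(x, counts, dim=0)
out = torch.cat([s * 2.0 for s in splits], dim=0)  # stand-in per-expert op
assert torch.equal(out, x * 2.0)
```

Note torch.split accepts zero-sized sections, so experts that receive no tokens are handled without special-casing.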


Labels: CLA Signed (managed by the Meta Open Source bot), high priority

Projects: Status: Todo

6 participants