Remove unnecessary token padding for MoE in BF16 mode #2255
rakkit wants to merge 5 commits into pytorch:main from
Conversation
torchtitan/models/moe/utils.py (outdated)

```diff
-TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
+TOKEN_GROUP_ALIGN_SIZE_M = 1
```
This fix is "soft", in the sense that the padding code path still exists for bf16.
I wonder whether it's viable to go one step further -- remove all padding logic for bf16 and move padding logic to quantized paths only. cc @danielvegamyhre
Either way is fine. The 8-token alignment is needed if we want to use TMA in any kernels operating on each token group (8 tokens * 2 bytes per elem = 16-byte alignment). However, if we are only doing that in the low-precision code path, then there's no reason to pad.
Feel free to remove bf16 padding entirely.
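The alignment arithmetic in the comment above can be sketched as a tiny helper (hypothetical, not torchtitan code): TMA wants 16-byte-aligned token-group boundaries, so the minimum token-group alignment is 16 bytes divided by the element size.

```python
# Hypothetical helper (not part of torchtitan) illustrating the arithmetic
# above: TMA needs 16-byte-aligned token-group boundaries, so the minimum
# token alignment is 16 bytes / element size.
TMA_ALIGN_BYTES = 16

def min_token_group_align(elem_size_bytes: int) -> int:
    # element size must divide the TMA alignment evenly
    assert TMA_ALIGN_BYTES % elem_size_bytes == 0
    return TMA_ALIGN_BYTES // elem_size_bytes

print(min_token_group_align(2))  # bf16 (2 bytes/elem): 8 tokens
print(min_token_group_align(1))  # fp8 (1 byte/elem): 16 tokens
```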
For bf16, 8-token alignment is not needed anywhere; see:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16).requires_grad_(True)
w = torch.randn(2, 4096, 7168, device="cuda", dtype=torch.bfloat16).requires_grad_(True)
# odd offsets
offs = torch.tensor([1023, 2048], device="cuda", dtype=torch.int32)
out = F.grouped_mm(x, w, offs=offs)
gO = torch.rand_like(out)
out.backward(gO)
# check that gradients are computed
print(x.grad.sum(), w.grad.sum())
```
@danielvegamyhre
Right now we are mixing padding and permutation into one kernel. Since bf16 doesn't require padding, I wonder if it makes sense to move padding into the quantization kernel? The argument is that the kernel itself should be general and not require the user to do padding from outside.
Sure, if we agree to move the padding logic to the quant paths, then I will refactor to remove TOKEN_GROUP_ALIGN_SIZE_M in torchtitan.
@tianyu-l we have a version of the permutation and pad/fill kernel in torchao now, used in the MXFP8 EP primitives. It is not fused with quantization though. To clarify, are you asking if we can delete the permute+pad kernel from torchtitan and replace it with fused permute+pad+quantize kernel in torchao?
@danielvegamyhre My request is that we remove padding from torchtitan entirely, while keeping correctness.
In the past we had the permute+pad kernel to avoid a D2H sync on the padding part. Now that we no longer need padding for bf16, I'd hope to remove the kernel altogether, but that requires torchao to handle padding.
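As a sketch of what "torchao handles padding" would mean (illustrative pure Python, not the torchao API): each expert's token count is rounded up to the alignment the quantized grouped-GEMM kernel needs before packing token groups.

```python
# Illustrative sketch, not the torchao API: round each expert's token count
# up to a multiple of the alignment a quantized grouped-GEMM kernel needs.
ALIGN = 8  # e.g. 8 tokens for bf16 TMA, 16 for fp8

def pad_group_sizes(num_tokens_per_expert, align=ALIGN):
    # round n up to the next multiple of align
    return [(n + align - 1) // align * align for n in num_tokens_per_expert]

print(pad_group_sizes([1023, 1025, 2048, 7]))  # [1024, 1032, 2048, 8]
```

Note that computing these padded sizes on the host is exactly the D2H sync the fused kernel was written to avoid, which is why the request is for torchao to own this step inside its quantized path.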
The Solar Open-102B technical report is very interesting @rakkit, thanks for sharing it!
thanks @danielvegamyhre @tianyu-l, made a refactor and the alignment logic is completely removed now (PR message also updated for test run). i think we still need generate_permute_indices kernels for EP permute, but it runs in no-padding mode by default now.
torchtitan/models/moe/utils.py (outdated)

```diff
 x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
 input_shape = x.shape
-x = x[permuted_indices, :]
+x = torch.index_select(x, 0, permuted_indices)
```
curious, what's the reason for this change?
i reverted the changes; i thought they'd end up in the same kernel.
torchtitan/models/moe/utils.py (outdated)

```python
max_len = x.shape[0]
with torch.no_grad():
    (
        permuted_indices,
        num_tokens_per_expert,
        _offsets,
    ) = generate_permute_indices(
```
> i think we still need `generate_permute_indices` kernels for EP permute, but it runs in no-padding mode by default now

Do we? IIUC we needed generate_permute_indices to avoid a D2H sync (in getting pad sizes). Now that we don't do padding, can we just use PyTorch code to do the permutation? Please correct me if I'm wrong.
i added a pytorch version of generate_permute_indices above.
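A no-padding permutation can indeed be expressed in a few lines of plain PyTorch. A minimal sketch (the function name and signature here are illustrative, not the actual helper added in the PR): group tokens by their assigned expert with a stable sort, no custom kernel needed.

```python
import torch

# Illustrative pure-PyTorch permutation for the no-padding case (names are
# hypothetical, not the PR's actual helper): a stable argsort groups token
# indices by expert while preserving the original order within each group.
def permute_indices_pytorch(expert_ids: torch.Tensor, num_experts: int):
    permuted_indices = torch.argsort(expert_ids, stable=True)
    num_tokens_per_expert = torch.bincount(expert_ids, minlength=num_experts)
    return permuted_indices, num_tokens_per_expert

expert_ids = torch.tensor([2, 0, 1, 0, 2, 1])
idx, counts = permute_indices_pytorch(expert_ids, num_experts=3)
print(idx.tolist())     # [1, 3, 2, 5, 0, 4]
print(counts.tolist())  # [2, 2, 2]
```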
@danielvegamyhre let us know when the mxfp8 kernels support padding.
@rakkit i have not closely reviewed the PR yet, but as an intermediate solution can we just set TOKEN_GROUP_ALIGNMENT_SIZE=1 for bf16 until there is a performant solution for automatic padding in MXFP8? I'm prototyping some solutions to this, but it's harder than it sounds to not kill performance. The TLDR is it needs to be fused into the all-to-all dispatch, which will take time, maybe ~weeks.
hi @danielvegamyhre, i am not sure if i understand it correctly. do you mean set TOKEN_GROUP_ALIGNMENT_SIZE=1, and we let
then it's the initial commits of this PR; you can check them here, so should i check out to those commits?
on the bf16 path, if alignment is 1, we can early exit. i am just asking to keep the padding path as it exists, and not delete it for now, so mxfp8 can continue to use it until we land a proper solution, if that is ok with everyone. additional extra context: at the moment, it looks like it will be either using HybridEP's native handling for this, or alternatively, if that doesn't work out, extending an experimental MXFP8 all-to-all implementation we have in torchao to also handle the padding. once that solution is tested and landed, we can delete the current non-optimal padding logic entirely.
the kernel can do padding and permute. what i mean is, for BF16 we early exit on the non-EP path, and use the current padding kernels (with size=1) for permute on the EP path. and that's the initial commit of this PR.
@rakkit sorry, would you mind also looking into removing the for-loop path padding/unpadding as part of this PR? This would be an alternative fix for #2399. Possibly conditional on
This patch should work, and on just one test, it gives me equivalent numerics before/after:

```diff
diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py
--- a/torchtitan/models/moe/moe.py
+++ b/torchtitan/models/moe/moe.py
@@ -81,12 +81,8 @@ def _run_experts_for_loop(
     # NOTE: this would incur a synchronization between device and host
     num_tokens_per_expert_list = num_tokens_per_expert.tolist()

-    # side-effect code due to the usage of generate_permute_indices
-    num_padding = x.shape[0] - sum(num_tokens_per_expert_list)
-
     # a tuple of tensors indexed by experts
     # each with shape (tokens_per_expert(varying), dim)
     x_splits = torch.split(
-        x[: sum(num_tokens_per_expert_list)],
+        x,
         split_size_or_sections=num_tokens_per_expert_list,
         dim=0,
     )
@@ -100,8 +96,6 @@ def _run_experts_for_loop(
         out_experts_splits.append(h)
     out = torch.cat(out_experts_splits, dim=0)

-    # side-effect code due to the usage of generate_permute_indices
-    out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-
     return out
```

I think the for-loop path only activates on < SM90 though.
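The effect of the patch above can be seen in a minimal standalone sketch (illustrative shapes and a stand-in for the expert computation, not the actual torchtitan module): with no padding rows appended by the permute step, `torch.split` can consume the routed tokens directly, with nothing to strip before the split or re-append after the concat.

```python
import torch

# Minimal sketch of the for-loop expert path with padding removed
# (illustrative, not the actual torchtitan module).
num_tokens_per_expert = [3, 1, 2]
x = torch.arange(6, dtype=torch.float32).unsqueeze(1)  # 6 routed tokens, dim=1

# split per expert and run each expert on its own slice of tokens
x_splits = torch.split(x, num_tokens_per_expert, dim=0)
out = torch.cat([s * 2.0 for s in x_splits], dim=0)  # `* 2.0` stands in for each expert

print(out.shape)                # torch.Size([6, 1])
print(out.squeeze(1).tolist())  # [0.0, 2.0, 4.0, 6.0, 8.0, 10.0]
```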
fix of #2225

Context: the Solar Open-102B technical report points out that in BF16 mode, expert parallel does unnecessary token padding (ps. this also applies to the non-EP case).

This PR sets `TOKEN_GROUP_ALIGN_SIZE_M = 1` by default. `indices_padding_wrapper_permute` takes `TOKEN_GROUP_ALIGN_SIZE_M = 1` and `padded_max_len = x.shape[0]`, which avoids any padding.

Test:

Original implementation (with `TOKEN_GROUP_ALIGN_SIZE_M = 8`):

```shell
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
```

And with this PR:

```shell
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
```

Another update to completely remove `TOKEN_GROUP_ALIGN_SIZE_M`:

```shell
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
```