[Bugfix] Force reduce in bwd for MoE router weights #2416

Open
acisseJZhong wants to merge 7 commits into main from force_reduce_in_bwd

Conversation

@acisseJZhong
Contributor

@acisseJZhong acisseJZhong commented Feb 21, 2026

Solves #2387, where a user reported a bug that MoE router weights differ on each rank due to a missing all-reduce under certain scenarios:

  • self.score_before_experts=False for both ETP=1 and ETP=TP
  • self.score_before_experts=True and ETP=1

Bug analysis

Case 1:
When ETP=1 < TP, MoE uses TP2EP (etp_mesh is None; there is no ETP), the gate is Replicated (via NoParallel), the experts' weights are Sharded (via the EP mesh), and the MoE input is Replicated.

Root Cause: ReordererSequenceParallel is applied to split top_scores across TP ranks. Since each TP rank processes its local slice of tokens through the experts, d_routed_input is Partial and there is no forced all-reduce.
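To see why a missing all-reduce makes replicated router weights drift apart, here is a minimal, framework-free sketch (toy numbers, hypothetical scalar gate weight; not torchtitan code): when tokens are sharded across ranks but a weight is replicated, each rank's local backward only produces a partial gradient, and the true gradient is the sum of the partials.

```python
# Toy sketch: tokens sharded across two "ranks", one replicated scalar
# gate weight w. With forward score = w * token and loss = sum(scores),
# d(loss)/dw on each rank sums only over that rank's local tokens.
tokens = [1.0, 2.0, 3.0, 4.0]
rank0, rank1 = tokens[:2], tokens[2:]

dw_rank0 = sum(rank0)  # local (Partial) gradient on rank 0 -> 3.0
dw_rank1 = sum(rank1)  # local (Partial) gradient on rank 1 -> 7.0

# Without an all-reduce, each rank applies a different dw and the
# replicated weights diverge; the correct gradient is the sum.
dw_full = sum(tokens)
assert dw_rank0 != dw_rank1            # ranks would diverge
assert dw_rank0 + dw_rank1 == dw_full  # all-reduce restores correctness
```

This is exactly the Partial placement in DTensor terms: each rank holds one addend of the true tensor, and a sum (all-reduce) is required before the value can be treated as Replicated.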

Case 2:
When ETP = TP, MoE uses DP2EP, the gate is Replicated (via NoParallel), the experts' weights are Sharded (via ETP, which forces an all-reduce in bwd, #1878), and the MoE input is Replicated on the TP mesh.

Root Cause: The gradient flowing back to gate.weight is Partial when self.score_before_experts=False, which we incorrectly marked as Replicated, resulting in wrong numerics.

Explanation:
When self.score_before_experts=True, d_routed_input is already all-reduced because of the fix we added to ETP in #1878.

routed_input = routed_input * top_scores_experts_sorted
routed_output = self.experts(routed_input, ...) 
out_experts = routed_output_unsorted.sum(dim=1)    

When self.score_before_experts=False, the gradient of top_scores from the bmm is Partial. The previous all-reduce fix only forced an all-reduce on d_routed_input, but not on d_top_scores, which is wrong.

routed_output = self.experts(routed_input, ...)
out_experts = bmm(top_scores, routed_output_unsorted) 
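As a toy illustration of why d_top_scores comes out Partial (simplified to a plain dot product with made-up numbers; not the actual torchtitan shapes): the backward of out = scores @ R contracts over the hidden dimension of R, so when that dimension is sharded across ranks, each rank can only compute its local partial sum.

```python
# Toy sketch: d_scores = dout @ R.T sums over the hidden dim of the
# expert output R. Shard that dim across two "ranks" and each rank
# only produces a Partial d_scores from its local shard.
R = [1.0, 2.0, 3.0, 4.0]      # one expert output row, hidden dim = 4
R0, R1 = R[:2], R[2:]         # rank-local shards of the hidden dim
dout = [1.0, 1.0, 1.0, 1.0]   # upstream gradient, ones for simplicity

d_score_rank0 = sum(r * g for r, g in zip(R0, dout[:2]))  # 3.0
d_score_rank1 = sum(r * g for r, g in zip(R1, dout[2:]))  # 7.0

# The true d_score sums over the whole hidden dim: the local values are
# Partial and must be all-reduced before they reach the gate weight.
d_score_full = sum(r * g for r, g in zip(R, dout))
assert d_score_rank0 + d_score_rank1 == d_score_full
```

With score_before_experts=True this Partial gradient flows through the experts' input, which #1878 already all-reduces; with score_before_experts=False it flows straight to the gate, where nothing reduced it before this fix.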

The fix we propose, which addresses both Case 1 and Case 2:

  • We revert the fix in fix MoE TP backward #1878 to keep the Partial dx flowing back, and delay the all-reduce to the boundary of the MoE layer.
  • We choose not to use ColwiseParallel/RowwiseParallel for the shared expert, as they force an all-reduce of dx, while we want to keep dx Partial. So we introduced MoEColWiseParallel/MoERowwiseParallel, which remove the input and output hooks. See the reasoning here for more details.
  • We annotate the gradient of the router as Partial in bwd by adding a field to NoParallel(), so that DTensor will handle the all-reduce.

Caveat: the graph was plotted from a Claude-generated computational graph, so please take it with a grain of salt.

cc @volcacius @fatih-uzlmz

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 21, 2026
# so the router gate receives the correct full gradient.
top_scores = DTensor.from_local(
    top_scores, device_mesh, (Replicate(),)
).to_local(grad_placements=(Partial(),))
Contributor

is doing 2 DTensor operations (to, from) really the best way here? why not just insert one autograd.function where the backwards is an all_reduce?

Contributor Author

@acisseJZhong acisseJZhong Feb 23, 2026


yeah, I guess we could also insert a customized function where fwd is a no-op and bwd does an all_reduce? I was following the similar fix in #1878
cc @tianyu-l if you know why we might prefer this over customized function
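For reference, a minimal single-process sketch of the customized-function alternative being discussed (hypothetical name ForceReduceBwd, not part of this PR; it falls back to a no-op when torch.distributed is not initialized):

```python
import torch
import torch.distributed as dist


class ForceReduceBwd(torch.autograd.Function):
    """Identity in forward; all-reduce the gradient in backward."""

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)  # no-op forward

    @staticmethod
    def backward(ctx, grad_out):
        if dist.is_available() and dist.is_initialized():
            grad_out = grad_out.contiguous()
            dist.all_reduce(grad_out)  # sum Partial grads across ranks
        return grad_out


# Single-process smoke test: forward is identity, backward passes through.
x = torch.ones(3, requires_grad=True)
ForceReduceBwd.apply(x).sum().backward()
assert torch.equal(x.grad, torch.ones(3))
```

Compared to the DTensor from_local/to_local round trip, this issues the same single collective in backward but skips the DTensor dispatch machinery in forward.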

Contributor

ok, well, short term it seems harmless to land this, other than it would have more overhead (worth confirming, but I assume it does not do any extra collectives or operations?). anyway, we could fix them all to use a dedicated function from spmd_types once that's available.

Contributor Author

it should not have any extra collectives other than the all-reduce in bwd. yes, I think we could directly use the function from spmd_types once it's ready!

@fatih-uzlmz
Contributor

@acisseJZhong
Thanks for the ping and for putting together this great deep dive; it's fascinating to see how the bug splits into two distinct cases based on the ETP/TP config.

I'm glad @volcacius and I were on the right track with the NoParallel wrapper being the correct conceptual boundary for the reduction, but your fix handling the activations directly is a really clean solution.

Really happy to have helped kick off this investigation and grateful for the opportunity to learn from the team here :)

Thanks everyone once again


Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.
