[Bugfix] Force reduce in bwd for MoE router weights #2416
acisseJZhong wants to merge 7 commits into `main`
Conversation
    # so the router gate receives the correct full gradient.
    top_scores = DTensor.from_local(
        top_scores, device_mesh, (Replicate(),)
    ).to_local(grad_placements=(Partial(),))
Is doing 2 DTensor operations (`from_local`, `to_local`) really the best way here? Why not just insert one `autograd.Function` whose backward is an all_reduce?
OK. Well, short term it seems harmless to land this, other than it would have more overhead (worth confirming, but I assume it does not do any extra collectives or operations?). Anyway, we could fix them all to use a dedicated function from spmd_types once that's available.
It should not have any extra collectives other than the all-reduce in bwd. Yes, I think we could directly use the function from spmd_types once it's ready!
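For reference, the alternative suggested above — a single `autograd.Function` that is the identity in forward and all-reduces the gradient in backward — could look roughly like the sketch below. The `reduce_fn` hook is hypothetical: it stands in for `torch.distributed.all_reduce` so the example runs single-process; real code would call the collective directly.

```python
import torch


class AllReduceBwd(torch.autograd.Function):
    """Identity in forward; sums ("all-reduces") the gradient in backward."""

    @staticmethod
    def forward(ctx, x, reduce_fn):
        # reduce_fn is a stand-in for torch.distributed.all_reduce.
        ctx.reduce_fn = reduce_fn
        return x

    @staticmethod
    def backward(ctx, grad_out):
        grad = grad_out.clone()
        ctx.reduce_fn(grad)  # real code: dist.all_reduce(grad)
        return grad, None


# Pretend world_size=2 ranks hold identical partial gradients:
# the fake "all-reduce" then just doubles the local gradient in place.
def fake_all_reduce(t):
    t.mul_(2)


x = torch.ones(3, requires_grad=True)
y = AllReduceBwd.apply(x, fake_all_reduce)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```

This mirrors what the `DTensor.from_local(...).to_local(grad_placements=(Partial(),))` pattern achieves: a no-op in forward, with the gradient forced to be summed across ranks in backward.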
@acisseJZhong I'm glad @volcacius and I were on the right track with the NoParallel wrapper being the correct conceptual boundary for the reduction, but your fix handling the activations directly is a really clean solution. Really happy to have helped kick off this investigation and grateful for the opportunity to learn from the team here :) Thanks everyone once again
Solves #2387, where a user reported a bug that MoE router weights diverge across ranks due to a missing all-reduce in certain scenarios:

- `self.score_before_experts=False`, for both ETP=1 and ETP=TP
- `self.score_before_experts=True` and ETP=1

**Bug analysis**
**Case 1:**

When ETP=1 < TP, MoE uses TP2EP (`etp_mesh` is None; there is no ETP). `gate` is Replicated (via `NoParallel`), the experts' weights are Sharded (via the EP mesh), and the MoE input is Replicated.

Root cause: `ReordererSequenceParallel` is applied to split `top_scores` across TP ranks. Since each TP rank processes its local slice of tokens through the experts, `d_routed_input` is Partial and there is no forced all-reduce.

**Case 2:**
When ETP = TP, MoE uses DP2EP. `gate` is Replicated (via `NoParallel`), the experts' weights are Sharded (via ETP, which forces an all-reduce in bwd, #1878), and the MoE input is Replicated on the TP mesh.

Root cause: the gradient flowing back to `gate.weight` is Partial when `self.score_before_experts=False`, but we incorrectly marked it as Replicated, which results in wrong numerics.

**Explanation:**
When `self.score_before_experts=True`, `d_routed_input` is already all-reduced because of the fix #1878 we added to ETP.

When `self.score_before_experts=False`, the gradient of `top_scores` from `bmm` is Partial. The previous fix only forced an all-reduce on `d_routed_input`, but not on `d_top_scores`, which is wrong.

**The fix we propose, which addresses both Case 1 and Case 2:**
We avoid `ColWiseParallel`/`RowwiseParallel` for the shared expert, as they force an all-reduce on dx, while we want to keep dx Partial. So we introduced `MoEColWiseParallel`/`MoERowwiseParallel`, which remove the input and output hooks. See the reasoning here for more details.

Caveat: the graph is plotted from a Claude-generated computational graph, so please take it with a grain of salt.
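To make the failure mode concrete, here is a single-process sketch of why a replicated router weight needs its gradient summed across ranks (all names and shapes here are illustrative, not from the PR): each "rank" sees only its local token slice, so its local gradient w.r.t. the shared weight is Partial, and only the sum of the per-rank partials matches the full-batch gradient. Applying the partials directly, as the bug did, makes the replicated weights drift apart.

```python
import torch

# Two "ranks" share a replicated gate weight, but each processes only
# its local slice of the tokens.
torch.manual_seed(0)
gate = torch.randn(4, 2)
tokens = torch.randn(6, 4)
slices = tokens.chunk(2)  # rank 0 and rank 1 token slices

# Per-rank (Partial) gradients w.r.t. the gate weight.
partial_grads = []
for local in slices:
    w = gate.clone().requires_grad_(True)
    (local @ w).sum().backward()
    partial_grads.append(w.grad)

# Correct full gradient: what a single process over all tokens computes.
w = gate.clone().requires_grad_(True)
(tokens @ w).sum().backward()
full_grad = w.grad

# Without an all-reduce, each rank applies a different partial gradient
# to its replica; summing the partials recovers the full gradient.
assert not torch.allclose(partial_grads[0], partial_grads[1])
assert torch.allclose(partial_grads[0] + partial_grads[1], full_grad)
```

This is exactly what marking the gradient placement as Partial and forcing an all-reduce in backward guarantees.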
cc @volcacius @fatih-uzlmz