fix(MoE): Apply grad_placements=Partial() to router to_local for TP #2388
fatih-uzlmz wants to merge 8 commits into pytorch:main
Conversation
torchtitan/train.py
Outdated
Thanks for the catch :)
I accidentally branched off my PR history; I just reverted train.py, so it should be clean now.
I am also currently working on the unit test to verify the gradients, will update you soon.
wconstab
left a comment
can you add a small test case showing the numerics issue / validating the fix?
I opened the original issue here (#2387); this does not look like the right fix. It might work for TP, but I don't think it is correct for EP.
Thanks @volcacius for the input. You're right, my fix enforces the `Partial()` grad placement unconditionally, even when the weights are sharded. I think the robust fix is to make it conditional on whether the weight placement is actually `Replicate`.
Does that align with what you were thinking?
cc: @volcacius @wconstab I went ahead and pushed a new commit with this conditional logic; the code now explicitly checks whether the weight is replicated before applying the `Partial()` grad placement. I also re-ran our unit test with the new logic to ensure everything is still functioning properly, and everything passes.
Let me know if there's anything else needed here 👍
torchtitan/models/moe/moe.py
Outdated
# Only enforce Partial grad placement if the weight is actually replicated (TP).
# If it's sharded (EP), we shouldn't reduce gradients across different experts.
is_replicated = isinstance(self.w1.placements[0], Replicate)
grad_placement = (Partial(),) if is_replicated else None
I'm not sure if it fixes anything. In EP / TP, the weights are always sharded.
For activations, https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L57 fixes the issue.
I think there's one more issue here https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L285
I don't think this works. It's not the expert weights that are replicated, it's the gate weights through NoParallel. If you look at my original issue, I suggested a different fix.
Okay, I think I understand what's going on now, thank you both for the guidance. First off, my apologies: I accidentally applied the fix to the expert weights rather than the router. @volcacius is right about the divergence reported in #2387; it's happening because the router gate weights are replicated, and they are losing their backward reduction when passing through the `to_local()` conversion. I will revert my changes to moe.py and move the fix to the NoParallel router path. @tianyu-l, for the manual slice in expert_parallel.py L285: since that drops the DTensor backward communication, would you like me to include a fix for that in this PR as well, or should we isolate this PR strictly to the NoParallel router fix?
Sorry, it's not clear to me what the root cause is, and why it's not the one that I mentioned. Could you include under what parallelism you see the problem? The slicing one I mentioned would only occur with ETP=1, TP > 1, EP > 1. I would be surprised if it happens with TP = ETP.
Yeah, fair point. Since I picked this up from the issue tracker, I don't have the original reproduction config. @volcacius, could you share your exact TP, EP, and ETP degrees?
Looks like #2416 addresses the root cause via the activations. Closing this out, thanks for the great discussion everyone :) |



Description
Fixes #2387 (MoE router replication broken w/ TP).
The Issue
When Tensor Parallelism (TP) is enabled, the MoE router weights were diverging immediately after the first step. This was caused by `to_local()` being called on the router weights without specifying `grad_placements`. As a result, the backward pass was treating the local gradients as final gradients rather than partial sums, skipping the necessary AllReduce synchronization across ranks.

The Fix
Added `grad_placements=(Partial(),)` to the `to_local()` conversion in `torchtitan/models/moe/moe.py`. This explicitly signals to the autograd engine that the resulting local tensor gradients must be reduced (summed) across the TP mesh before the optimizer step.

Verification
- Verified `Partial` is correctly imported from `torch.distributed._tensor`.
- Ran `tests/unit_tests/test_compile_moe.py` locally; tests passed successfully.