
fix(MoE): Apply grad_placements=Partial() to router to_local for TP #2388

Closed
fatih-uzlmz wants to merge 8 commits into pytorch:main from fatih-uzlmz:fix/moe-router-tp-divergence

Conversation

@fatih-uzlmz
Contributor

Description

Fixes #2387 (MoE router replication broken w/ TP).

The Issue

When Tensor Parallelism (TP) is enabled, the MoE router weights were diverging immediately after the first step. This was caused by to_local() being called on the router weights without specifying grad_placements. As a result, the backward pass was treating the local gradients as final gradients rather than partial sums, skipping the necessary AllReduce synchronization across ranks.
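To see why the local gradients are only partial sums, here is a minimal single-process sketch (plain tensors only, no torchtitan or DTensor code; two slices of the batch stand in for two TP ranks that each see half the tokens while sharing a replicated weight):

```python
import torch

torch.manual_seed(0)

# Replicated "router" weight: every rank holds the same copy.
w = torch.randn(4, 3, requires_grad=True)

# Activations split across two "ranks" (each sees half the tokens).
x = torch.randn(8, 4)
x0, x1 = x[:4], x[4:]

# Ground truth: gradient from the full batch.
(x @ w).sum().backward()
full_grad = w.grad.clone()
w.grad = None

# What each rank computes locally: only a partial sum of the true gradient.
(x0 @ w).sum().backward()
g0 = w.grad.clone()
w.grad = None
(x1 @ w).sum().backward()
g1 = w.grad.clone()

# Neither partial gradient alone is correct; their sum (what an AllReduce
# would produce) recovers the full gradient.
print(torch.allclose(g0 + g1, full_grad))  # True
print(torch.allclose(g0, full_grad))       # False
```

If the backward treats `g0` as final (which is what happens when `to_local()` is called without `grad_placements`), each rank applies a different, wrong gradient and the replicated copies diverge after the first optimizer step.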

The Fix

Added grad_placements=(Partial(),) to the to_local() conversion in torchtitan/models/moe/moe.py. This explicitly signals to the autograd engine that the resulting local tensor gradients must be reduced (summed) across the TP mesh before the optimizer step.

Verification

  • Validated that Partial is correctly imported from torch.distributed._tensor.
  • Ran tests/unit_tests/test_compile_moe.py locally; tests passed successfully.
[screenshot: moefix]

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Feb 17, 2026
Contributor

accidental change?

Contributor Author

Thanks for the catch :)

I accidentally branched off my PR history. I've reverted train.py, so it should be clean now.

I'm also working on a unit test to verify the gradients; will update you soon.

Contributor

@wconstab wconstab left a comment

can you add a small test case showing the numerics issue / validating the fix?

@fatih-uzlmz
Contributor Author

Alright, I've gone ahead and added tests/unit_tests/test_moe_tp.py.

It mocks a distributed environment and validates that gradients flow correctly through the GroupedExperts weights when they are wrapped as DTensors. I verified locally and it passes; here is a screenshot from my terminal:

[screenshot: MoE_test_case]

@volcacius

I opened the original issue (#2387). This does not look like the right fix: it might work for TP, but I don't think it is correct for EP.

@fatih-uzlmz
Contributor Author

Thanks @volcacius for the input.

You're right: my fix enforces a Partial reduction, which is correct for the replicated case (like the TP divergence we saw in #2387), but it would be incorrect for EP, where the weights are Sharded.

I think the robust fix is to make it conditional, something like:

# Only enforce Partial grad placement if the weight is actually replicated
is_replicated = isinstance(self.w1.placements[0], Replicate)
grad_placement = (Partial(),) if is_replicated else None

w1 = self.w1.to_local(grad_placements=grad_placement)
# Repeat for w2, w3...

Does that align with what you were thinking?

@fatih-uzlmz
Contributor Author

cc: @volcacius @wconstab

I went ahead and pushed a new commit with this conditional logic.

The code now explicitly checks isinstance(self.w1.placements[0], Replicate) before applying Partial(), so it fixes the TP divergence issue while remaining safe for EP/sharded scenarios.

I also re-ran our unit test with the new logic, and everything still passes:

[screenshot: Screenshot 2026-02-19 at 5.12.09 PM]

Let me know if there's anything else needed here 👍

# Only enforce Partial grad placement if the weight is actually replicated (TP).
# If it's sharded (EP), we shouldn't reduce gradients across different experts.
is_replicated = isinstance(self.w1.placements[0], Replicate)
grad_placement = (Partial(),) if is_replicated else None
Contributor

I'm not sure this fixes anything. In EP / TP, the weights are always sharded.
For activations, https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L57 fixes the issue.
I think there's one more issue here: https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L285

@volcacius

I don't think this works. It's not the expert weights that are replicated; it's the gate weights, via NoParallel. If you look at my original issue, I suggested a different fix.

@fatih-uzlmz
Contributor Author

Okay, I think I understand what's going on now. Thank you both for the guidance.

First off, my apologies: I accidentally applied the Partial() fix to the GroupedExperts weights (which @tianyu-l rightly pointed out are always sharded, making my fix useless).

@volcacius is right about the divergence reported in #2387: it's happening because the router gate weights are replicated, and they lose their backward reduction when passing through the NoParallel wrapper during DTensor erasure.

I will revert my changes to moe.py and move the grad_placements=(Partial(),) fix directly into the NoParallel class, where the gate outputs are localized.

@tianyu-l, regarding the manual slice in expert_parallel.py L285: since that drops the DTensor backward communication, would you like me to include a fix for it in this PR as well, or should we keep this PR strictly to the NoParallel router fix?

@fatih-uzlmz
Contributor Author

I reverted the moe.py changes and moved the fix directly to the NoParallel wrapper in torchtitan/distributed/__init__.py.

Now _prepare_output_fn conditionally applies grad_placements=(Partial(),) if the output layout is Replicate, right before to_local() erases the DTensor.

I also rewrote the unit test to specifically target the NoParallel wrapper with a dummy router gate, and it verifies that gradients now flow correctly through the erasure during the backward pass. Here is a screenshot:

[screenshot: Screenshot 2026-02-19 at 10.24.11 PM]

Let me know if this looks good!

@tianyu-l
Contributor

@fatih-uzlmz

> @volcacius is right about the divergence reported in #2387: it's happening because the router gate weights are replicated, and they lose their backward reduction when passing through the NoParallel wrapper during DTensor erasure.

Sorry, it's not clear to me what the root cause is, and why it's not the one that I mentioned.
Having replicated weights with grad_placement=Replicate is fine, as long as what interacts with them is also Replicate (the TP activations). So I don't think the fix should be on the router weights.

Could you include under what parallelism you see the problem? The slicing issue I mentioned would only occur with ETP=1, TP > 1, EP > 1. I would be surprised if it happens with TP = ETP.

@fatih-uzlmz
Contributor Author

Yeah, fair point.
Replicate x Replicate should yield identical gradients across ranks naturally, so the router weights shouldn't strictly need a Partial reduction to avoid divergence.
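That "Replicate x Replicate" point can be checked with a plain single-process sketch (two clones of the same weight stand in for two TP ranks that all see the full, replicated activations; no DTensor code involved):

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, 3)
x = torch.randn(8, 4)  # replicated activations: every rank sees all tokens

grads = []
for _rank in range(2):  # each "rank" runs the identical computation
    w_rank = w.clone().requires_grad_()
    (x @ w_rank).sum().backward()
    grads.append(w_rank.grad)

# The local gradients already agree across ranks, so no Partial()
# reduction is needed to keep the replicated weights in sync.
print(torch.equal(grads[0], grads[1]))  # True
```

This is why the divergence must come from somewhere the activations stop being Replicate, not from the replicated router weights themselves.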

Since I picked this up from the issue tracker, I don't have the original reproduction config.

@volcacius, could you share your exact TP, EP, and ETP degrees?
If you were running ETP=1, TP>1, EP>1, then @tianyu-l is probably right that the manual slice is corrupting the downstream gradients.

@fatih-uzlmz
Copy link
Contributor Author

Looks like #2416 addresses the root cause via the activations.

Closing this out, thanks for the great discussion everyone :)
