Conversation
torchtitan/distributed/utils.py
Outdated
if isinstance(non_ep_grads_total_norm, DTensor):
    non_ep_grads_total_norm = non_ep_grads_total_norm.full_tensor()
# Group non-EP grads by mesh to handle mixed meshes (e.g., VLM models
# where vision encoder params are on (fsdp) mesh while decoder params
If the decoder has FSDP 2, TP 2 -- what FSDP degree does the vision encoder have, 2 or 4?
It is 2. Both the decoder and the visual encoder share the same FSDP mesh and TP mesh.
So on the TP mesh, the encoder is doing replicated computation? If so, why would you need special treatment for grad clipping here?
Thanks for the question. I think this is similar to the separate grouping of tensors on the EP mesh and non-EP meshes. In our implementation, when TP is enabled, some tensors (e.g., mlp.weight) are on the (fsdp, tp) mesh and some (e.g., pos_embed.weight) are on the (fsdp,) mesh, so we need to reduce them separately before summing up.
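Roughly something like this (a sketch only; `total_norm_over_mixed_meshes` and the exact grouping are illustrative, not the actual torchtitan helpers):

```python
from collections import defaultdict

import torch
from torch.distributed.tensor import DTensor


def total_norm_over_mixed_meshes(grads, norm_type: float = 2.0) -> torch.Tensor:
    # Group gradients by device mesh so each partial norm is reduced over
    # the correct set of ranks before the groups are combined.
    groups = defaultdict(list)
    for g in grads:
        mesh = g.device_mesh if isinstance(g, DTensor) else None
        groups[mesh].append(g)

    partial_norms = []
    for mesh_grads in groups.values():
        norm = torch.nn.utils.get_total_norm(mesh_grads, norm_type=norm_type)
        if isinstance(norm, DTensor):
            # full_tensor() reduces the sharded partial norm over its own mesh.
            norm = norm.full_tensor()
        partial_norms.append(norm)

    # Combine the per-mesh norms into one global norm.
    return torch.linalg.vector_norm(torch.stack(partial_norms), ord=norm_type)
```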
I believe you need to put them on the TP mesh (as Replicate), instead of just letting them hang there. Otherwise DCP won't be able to figure out how to save/load.
Thanks for the suggestion! Used NoParallel to wrap the norms and patch_embed, which simplified things a lot, and we no longer need the extra logic to handle mixed non-EP meshes.
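Roughly like this (the NoParallel import path and the module names in the plan are from memory and may differ from the actual change):

```python
from torch.distributed.tensor.parallel import parallelize_module

# Import path is an assumption; NoParallel lives under torchtitan.distributed.
from torchtitan.distributed import NoParallel


def apply_non_tp_layers(encoder, tp_mesh):
    # Norms and patch_embed are not sharded over TP. Wrapping them with
    # NoParallel registers their params on the TP mesh as Replicate, so DCP
    # knows how to save/load them next to the TP-sharded weights, and grad
    # clipping sees a single (fsdp, tp) mesh for all non-EP params.
    parallelize_module(
        encoder,
        tp_mesh,
        {
            "patch_embed": NoParallel(),
            "norm": NoParallel(),
        },
    )
```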
self.enable_weight_tying = orig_weight_tying

if self.enable_weight_tying:
    if self.tok_embeddings is not None and self.output is not None:
Just curious: what happens when we use PP -- do we stop weight tying? Is this algorithmically correct?
I think of weight tying as an extreme memory-saving technique, only used for very small models (e.g. Qwen3 0.6B) that don't need PP for training. So it should be OK that they are just disjoint; we need to make sure NotImplementedError is thrown properly though.
Good catch. This is a workaround for when we use PP on a model with enable_weight_tying set to true: if we do not temporarily disable weight tying, it hits an assertion error when training from scratch. If we load from a checkpoint, this workaround is still problematic and I am still working on it. Weight tying is only used on small models; for Qwen3-VL it is only used for the 2B model. We are considering removing this feature, i.e., erroring out when we apply PP on a model with weight tying. What do you think?
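A toy version of how the workaround in the diff above reads (the module names follow the diff; everything else here is illustrative):

```python
import torch.nn as nn


class TinyDecoder(nn.Module):
    def __init__(self, vocab_size: int, dim: int, enable_weight_tying: bool):
        super().__init__()
        self.enable_weight_tying = enable_weight_tying
        # Under PP, one of these is set to None on stages that don't own it.
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)

    def maybe_tie_weights(self):
        # Only tie when both modules live on this PP stage; on the first/last
        # stages of a split model one of them is None, so tying is skipped.
        if self.enable_weight_tying:
            if self.tok_embeddings is not None and self.output is not None:
                self.output.weight = self.tok_embeddings.weight
```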
I see. I didn't realize weight tying was strictly a memory-saving technique; I assumed it also reduced the number of learnable parameters in a way that is meaningful to convergence.
Weight tying also introduces some complexity in initialization: if we init the output layer the same way as the token embedding, the initial errors would be huge. So to make weight tying work well, we need to add extra initialization logic.
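For example, something along these lines (the 1/sqrt(dim) std is illustrative, not the actual Qwen3-VL recipe):

```python
import math

import torch.nn as nn


def init_tied_embeddings(tok_embeddings: nn.Embedding, output: nn.Linear, dim: int):
    # With weight tying, output.weight and tok_embeddings.weight are the same
    # tensor, so a single init has to work for both roles at once.
    output.weight = tok_embeddings.weight
    # A plain embedding-style init (e.g. std=1.0) would make the logits of the
    # tied output head far too large at step 0, so use a smaller std instead.
    nn.init.normal_(tok_embeddings.weight, mean=0.0, std=1.0 / math.sqrt(dim))
```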
Add Qwen3-VL model support. To be refactored for the new config system. Currently supports: decoder (FSDP, EP, TP, PP) and visual encoder (FSDP, TP), to be adjusted based on needs and testing. Tested training on 2B, 8B, and 30B-A3B models, both from scratch and loading from an HF checkpoint.