
[Qwen3VL] add qwen3 vl#2409

Draft
shuhuayu wants to merge 19 commits into pytorch:main from shuhuayu:modeldev

Conversation

shuhuayu (Contributor) commented Feb 20, 2026

Add Qwen3 VL model support. To be refactored for the new config system. Currently supports: decoder (FSDP, EP, TP, PP) and vision encoder (FSDP, TP); to be adjusted based on needs and testing. Tested training on the 2B, 8B, and 30B-A3B models, both from scratch and loading from an HF checkpoint.

meta-cla bot added the CLA Signed label Feb 20, 2026
shuhuayu requested review from SherlockNoMad, fegin, tianyu-l, wconstab, and wwwjn, and removed the request for SherlockNoMad February 20, 2026 20:59
shuhuayu marked this pull request as a draft February 20, 2026 21:00
```python
if isinstance(non_ep_grads_total_norm, DTensor):
    non_ep_grads_total_norm = non_ep_grads_total_norm.full_tensor()
# Group non-EP grads by mesh to handle mixed meshes (e.g., VLM models
# where vision encoder params are on (fsdp) mesh while decoder params
```
Contributor:

If the decoder has FSDP 2 and TP 2, what FSDP degree does the vision encoder have, 2 or 4?

shuhuayu (Author):

It is 2. Both the decoder and the vision encoder share the same FSDP mesh and TP mesh.

Contributor:

So on the TP mesh, the encoder is doing replicated computation? If so, why would you need special treatment for grad clipping here?

shuhuayu (Author):

Thanks for the question. I think this is similar to the separate grouping of tensors on the EP mesh and the non-EP meshes. In our implementation, when TP is enabled, some tensors (e.g., mlp.weight) are on the (fsdp, tp) mesh while others (e.g., pos_embed.weight) are on the (fsdp,) mesh, so we need to reduce each group separately before summing up.
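The per-mesh grouping described in this comment can be sketched as follows. This is a hypothetical, torch-free illustration (the function name `clip_grad_norm_mixed_meshes` and the dict-of-lists input are invented for the sketch): each mesh group's squared partial norm stands in for a reduction performed within that mesh, and only the scalar partials are combined into the global norm.

```python
import math

def clip_grad_norm_mixed_meshes(grads_by_mesh, max_norm):
    """Hypothetical sketch of grad clipping over mixed meshes.

    grads_by_mesh maps a mesh key, e.g. ("fsdp", "tp") or ("fsdp",),
    to a flat list of gradient values. In the real DTensor code each
    group's partial norm would be reduced over its own mesh (e.g. via
    full_tensor()); here the per-group sum of squares plays that role.
    """
    partial_sq = {}
    for mesh, grads in grads_by_mesh.items():
        # per-mesh reduction: L2 norm squared within one mesh group
        partial_sq[mesh] = sum(g * g for g in grads)
    # combine the scalar partials into one global norm
    total_norm = math.sqrt(sum(partial_sq.values()))
    clip_coef = min(1.0, max_norm / (total_norm + 1e-6))
    return total_norm, clip_coef
```

The key point is that summing the squared partial norms is only valid after each group has been fully reduced over its own mesh; mixing unreduced shards from different meshes would double- or under-count.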

Contributor:

I believe you need to put them on the TP mesh (as Replicate) instead of just letting them hang there. Otherwise DCP won't be able to figure out how to save/load.

shuhuayu (Author):

Thanks for the suggestion! I used NoParallel to wrap the norms and patch_embed, which simplified things a lot, and we no longer need the extra logic to handle mixed non-EP meshes.
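The mesh-uniformity requirement behind this exchange can be illustrated with a small torch-free check. All names here are hypothetical sketches, not the actual torchtitan API: in the real code this role is played by NoParallel, which gives a parameter a Replicate placement on the TP mesh so every parameter ends up on the same (fsdp, tp) mesh.

```python
def assert_uniform_mesh(param_meshes: dict) -> None:
    """Hypothetical checkpoint-side check: DCP-style save/load is
    straightforward only when every parameter lives on the same mesh."""
    meshes = set(param_meshes.values())
    if len(meshes) > 1:
        raise ValueError(f"mixed meshes: {sorted(meshes)}")

def wrap_no_parallel(mesh: tuple) -> tuple:
    """Hypothetical NoParallel-like wrap: add the tp dim (conceptually
    as Replicate) so the param joins the (fsdp, tp) mesh without doing
    any sharded computation on it."""
    return mesh if "tp" in mesh else mesh + ("tp",)
```

With pos_embed.weight stuck on ("fsdp",) the check fails; after the NoParallel-style wrap, all parameters share ("fsdp", "tp") and both DCP and grad clipping see a single mesh.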

```python
self.enable_weight_tying = orig_weight_tying

if self.enable_weight_tying:
    if self.tok_embeddings is not None and self.output is not None:
```
Contributor:

Just curious: what happens when we use PP, do we stop weight tying? Is this algorithmically correct?

Contributor:

I think of weight tying as an extreme memory-saving technique, only used for very small models (e.g., Qwen3 0.6B) that don't need PP for training. So it should be OK that the two are simply disjoint; we do need to make sure NotImplementedError is thrown properly, though.

shuhuayu (Author):

Good catch. This is a workaround for using PP on a model with enable_weight_tying set to true: if we do not temporarily disable weight tying, we hit an assertion error when training from scratch. When loading from a checkpoint, this workaround is still problematic, and I am still working on it. Weight tying is used on small models; for Qwen3-VL, it is only used for the 2B model. We are considering removing this combination entirely, i.e., erroring out when PP is applied to a model with weight tying. What do you think?
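The error-out option proposed here could look like the following minimal sketch (the function name and parameters are assumptions, not the actual torchtitan config API):

```python
def validate_weight_tying(enable_weight_tying: bool, pp_degree: int) -> None:
    """Hypothetical config check: weight tying targets small models that
    don't need pipeline parallelism, so rather than silently untying the
    weights across PP stages, fail loudly when both are requested."""
    if enable_weight_tying and pp_degree > 1:
        raise NotImplementedError(
            "weight tying is not supported with pipeline parallelism; "
            "disable one of the two"
        )
```

This matches the reviewer's earlier point that the two features can stay disjoint as long as the unsupported combination raises NotImplementedError instead of silently training with untied weights.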

Contributor:

I see. I didn't realize weight tying was strictly a memory-saving technique; I assumed it also reduced the number of learnable parameters in a way that is meaningful for convergence.

shuhuayu (Author):

Weight tying also introduces some complexity in initialization: if we initialize the output layer the same way as the token embedding, the initial errors would be huge. So to make weight tying work well, we need to add extra initialization logic.
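For illustration, here is a torch-free sketch of what tying means (the class and attribute names are hypothetical, loosely mirroring the tok_embeddings/output names in the diff): the output projection aliases the embedding matrix, so there is one set of values to learn, and any initialization applied to one side is seen by the other. This is why the output layer cannot simply be re-initialized with a separate scheme after tying.

```python
class TinyTiedLM:
    """Hypothetical sketch of weight tying using plain Python lists."""

    def __init__(self, vocab_size: int, dim: int):
        # token-embedding matrix: one row per vocabulary entry
        self.tok_embeddings = [[0.02] * dim for _ in range(vocab_size)]
        # tying: the output projection reuses the SAME object, not a
        # copy, so an update to either is visible through both names
        self.output = self.tok_embeddings

    def tied(self) -> bool:
        return self.output is self.tok_embeddings
```

Under PP, tok_embeddings and output typically live on different pipeline stages, so this single shared object no longer exists, which is the root of the incompatibility discussed above.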


Labels

ciflow/8gpu, CLA Signed

3 participants