[rl] refactor torchtitan model registry in vllm #2194

Open
wwwjn wants to merge 23 commits into gh/wwwjn/5/base from gh/wwwjn/5/head

Conversation

[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Jan 2, 2026
ghstack-source-id: 557ecd0
Pull Request resolved: #2194
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 2, 2026
@wwwjn wwwjn closed this Jan 2, 2026
wwwjn added a commit that referenced this pull request Jan 2, 2026
ghstack-source-id: 557ecd0
Pull Request resolved: #2194
@wwwjn wwwjn reopened this Jan 2, 2026
@wwwjn wwwjn changed the title from "refactor model registry" to "[rl] refactor model registry" on Jan 2, 2026
wwwjn added a commit that referenced this pull request Jan 13, 2026
ghstack-source-id: 557ecd0
Pull Request resolved: #2194
@wwwjn wwwjn changed the title from "[rl] refactor model registry" to "[rl] refactor torchtitan model registry in vllm" on Feb 19, 2026
```python
# model_flavor during registration, because we cannot pass the torchtitan
# job_config through the LLM() API
model_flavor="0.6B",

from torchtitan.experiments.rl.unified.infra.parallelism_utils import (
    create_parallel_dims_from_vllm_config,
)
```
Contributor


I hope we can put all torchtitan-vllm glue code in one file or folder, and carefully document why we need each class/method. This one sounds like one of them.

Contributor Author


Nice catch, refactored this part

wwwjn added a commit that referenced this pull request Feb 22, 2026
…ight tying (#2410)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.13.0)
(oldest at bottom):
* #2395
* #2244
* #2221
* #2194
* #2191
* __->__ #2410


This is an alternative fix to
#2402 (comment).

Weight updating between the trainer and the generator is completely broken because we call `reload_weights` when updating the weights. `reload_weights` has the following steps:

- `initialize_layerwise_reload(model)`: saves the current real GPU tensors as `info.kernel_tensors` and replaces all parameters with meta tensors.
- `model.load_weights(weights_iter)`: this function is written by us and calls `set_model_state_dict`. Internally, `set_model_state_dict` tries to do `param.data.copy_(loaded_weight)` for each parameter. When the parameters are meta tensors, the copy is a no-op, so the weights never get updated.
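The silent no-op can be reproduced in isolation (a minimal sketch, assuming standard PyTorch meta-device semantics; not code from this PR):

```python
import torch

# A parameter that has been swapped to the meta device, as the broken
# reload path does before loading.
param = torch.nn.Parameter(torch.empty(2, 2, device="meta"))

# Trying to load real weights into it: copy_ into a meta tensor raises no
# error but writes nothing, because meta tensors carry no storage.
param.data.copy_(torch.ones(2, 2))

# The parameter is still a meta tensor; the "loaded" values went nowhere.
print(param.is_meta)
```

This is why the load path appears to succeed while the generator keeps serving stale weights.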


In this PR:

- Completely bypass `reload_weights`, and don't load from a file when updating the weights.
- Get the model via `self.engine.model_executor.driver_worker.get_model()`.
- Iterate over `model.named_parameters()` to find the matching parameter by name.
- Do `param.data.copy_(new_tensor)` directly.
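The direct-copy path above can be sketched as follows. This is a hedged illustration, not the PR's actual code: `update_weights_inplace` and the `new_weights` dict are hypothetical names; only the attribute chain `engine.model_executor.driver_worker.get_model()` comes from the PR description.

```python
import torch


def update_weights_inplace(
    model: torch.nn.Module, new_weights: dict[str, torch.Tensor]
) -> int:
    """Copy new tensors into matching parameters by name; return the count updated."""
    updated = 0
    # Build a name -> parameter lookup so each weight is matched by name.
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name, new_tensor in new_weights.items():
            param = params.get(name)
            if param is None:
                continue  # no matching parameter for this weight
            # Real (non-meta) parameters, so this copy actually lands.
            param.data.copy_(new_tensor)
            updated += 1
    return updated


# Hypothetical usage with a vLLM engine handle (attribute chain per the PR text):
# model = engine.model_executor.driver_worker.get_model()
# update_weights_inplace(model, new_weights)
```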

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.


3 participants