
[rl] refactor save and load model weights using DCP#2221

Open
wwwjn wants to merge 21 commits into gh/wwwjn/6/base from gh/wwwjn/6/head

Conversation


@wwwjn wwwjn commented Jan 13, 2026

Stack from ghstack (oldest at bottom):

What's new in this PR

  • Directly pass weights as tensors (plain tensors) from the trainer to the generator.
  • Remove the burden of writing to and reading from files.
  • Supported setup: the trainer supports DDP; the generator only supports TP=1 (no DTensor on either side yet).

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 13, 2026
wwwjn added a commit that referenced this pull request Jan 13, 2026
ghstack-source-id: bcd9f5e
Pull Request resolved: #2221
@wwwjn wwwjn changed the title refactor save and load model weights using DCP [WIP] refactor save and load model weights using DCP Jan 13, 2026
wwwjn added a commit that referenced this pull request Jan 13, 2026
ghstack-source-id: b7642b4
Pull Request resolved: #2221
wwwjn added a commit that referenced this pull request Jan 13, 2026
ghstack-source-id: b7642b4
Pull Request resolved: #2221
wwwjn added a commit that referenced this pull request Jan 14, 2026
ghstack-source-id: 87a29dc
Pull Request resolved: #2221
@wwwjn wwwjn changed the title [WIP] refactor save and load model weights using DCP [rl] refactor save and load model weights using DCP Jan 14, 2026

fegin commented Jan 14, 2026

I'm wondering whether we should refactor the TorchTitan checkpointer so that it can be used directly in this case. While the current PR works, if TorchTitan migrates to a new checkpoint library, other use cases will need the same updates as well. This is future work, not blocking this PR.

@wwwjn wwwjn changed the title [rl] refactor save and load model weights using DCP [WIP][rl] refactor save and load model weights using DCP Jan 30, 2026
@wwwjn wwwjn changed the title [WIP][rl] refactor save and load model weights using DCP [rl] refactor save and load model weights using DCP Feb 14, 2026
job_config=self.parallel_config,
)

# Initial load model weights from HuggingFace checkpoint path

Do we really need this complication?
IIUC we can do:

  • the trainer loads the checkpoint (which can be an HF one, thanks to the torchtitan state-dict adapters)
  • the generator only gets weights from the trainer, even the initial ones
  • the generator never needs to know about HF checkpoints or worry about the conversion

@wwwjn wwwjn Feb 25, 2026


This is to make sure we maintain the same expectation whenever/wherever you initialize an LLMEngine(model_path=<checkpointer_folder>, hf_override={"model": "TorchtitanVLLMModel"}). My "expectation" for this call is that the underlying model is initialized with the weights from the checkpoint folder.

  • vLLM achieves this "expectation" by calling TorchtitanVLLMWrapper.load_weights() when initializing the LLMEngine.
  • We achieve this "expectation" by 1) making TorchtitanVLLMWrapper.load_weights() a no-op, and 2) calling _initial_load_weights in the wrapper's __init__.

This "expectation" would help inference as well.

The approach you described would work perfectly for RL, but people might not notice that the model weights are uninitialized during plain inference. wdyt?
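The wrapper pattern described in this thread can be sketched roughly as below. This is a hypothetical, simplified illustration: the class and method names (TorchtitanVLLMWrapper, _initial_load_weights, load_weights) come from the discussion, but the body here uses a plain torch.load as a stand-in for the PR's actual HuggingFace/DCP checkpoint logic.

```python
import tempfile

import torch
import torch.nn as nn


class TorchtitanVLLMWrapper(nn.Module):
    """Sketch: load weights eagerly in __init__, make load_weights() a no-op."""

    def __init__(self, model: nn.Module, checkpoint_path: str):
        super().__init__()
        self.model = model
        # Load once here, so initializing an LLMEngine with this model
        # always yields weights from the checkpoint folder, even for
        # inference-only users who never push weights from a trainer.
        self._initial_load_weights(checkpoint_path)

    def _initial_load_weights(self, checkpoint_path: str) -> None:
        # The PR reads a HuggingFace/DCP checkpoint; a plain torch.load
        # stands in for that logic in this sketch.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.model.load_state_dict(state_dict)

    def load_weights(self, weights_iter):
        # Intentionally a no-op: vLLM still calls this internally during
        # engine initialization, but the real load already happened above.
        return
```

The design choice this illustrates: the eager load in __init__ keeps the "model is initialized from the checkpoint folder" expectation intact regardless of whether vLLM's own load_weights path runs.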


return self.load_weights_from_state_dict(torchtitan_state_dict)

def load_weights(self, weights_iter):

Do we still need this function? Can it be deleted or overridden with "NotImplementedError"?


This function will still be called internally by vLLM.

wwwjn added a commit that referenced this pull request Feb 22, 2026
…ight tying (#2410)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.13.0)
(oldest at bottom):
* #2395
* #2244
* #2221
* #2194
* #2191
* __->__ #2410


This is an alternative fix to #2402 (comment).

Weight updating between the trainer and generator is totally broken because we called reload_weights when updating the weights. reload_weights has the following steps:

- initialize_layerwise_reload(model): saves the current real GPU tensors
as info.kernel_tensors, and replaces all parameters with meta tensors.
- model.load_weights(weights_iter): this function is written by us and
calls set_model_state_dict. Internally, set_model_state_dict tries to do
param.data.copy_(loaded_weight) for each parameter. When the parameters
are meta tensors, this copy is a no-op, so the weights never get updated.
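The no-op described above can be demonstrated in isolation. This is an assumed minimal repro, not code from the PR: meta tensors are shape-only placeholders with no storage, so copying real data into them silently discards it.

```python
import torch

# A real tensor the trainer wants to load into the model.
loaded_weight = torch.ones(4)

# After the layerwise-reload step, parameters live on the meta device:
# shape-only placeholders with no storage behind them.
param = torch.nn.Parameter(torch.empty(4, device="meta"))

# copy_ into a meta tensor succeeds but moves no data, so the update
# silently vanishes -- this is the bug described above.
param.data.copy_(loaded_weight)
assert param.is_meta  # still on the meta device, no real values landed
```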


In this PR:

- Totally bypass reload_weights, and don't load from a file when
updating the weights
- Get the model via
self.engine.model_executor.driver_worker.get_model()
- Iterate over model.named_parameters() to find the matching parameter
by name
- Do param.data.copy_(new_tensor) directly
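The steps above can be sketched as a small helper. This is a hypothetical sketch, assuming the engine exposes the driver worker via engine.model_executor.driver_worker.get_model() as described in the PR; the function name push_weights is made up for illustration.

```python
import torch
import torch.nn as nn


def push_weights(engine, new_state_dict):
    """Copy new tensors directly into the generator's live parameters,
    bypassing reload_weights and any file-based checkpoint path."""
    model = engine.model_executor.driver_worker.get_model()
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name, new_tensor in new_state_dict.items():
            param = params.get(name)
            if param is None:
                continue  # skip names absent on this side (e.g. tied weights)
            param.data.copy_(new_tensor)
```

Because the copy targets the real (non-meta) parameters held by the running engine, the update actually lands, unlike the reload_weights path described above.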

Labels

ciflow/8gpu · CLA Signed (This label is managed by the Meta Open Source bot.)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants