[rl] refactor save and load model weights using DCP #2221
wwwjn wants to merge 21 commits into gh/wwwjn/6/base
Conversation
[ghstack-poisoned]
I'm wondering whether we should refactor the TorchTitan checkpointer so that it can be directly used in this case. While the current PR works, if TorchTitan migrates to a new checkpoint library, other use cases will need the same updates as well. This is future work, not blocking this PR.
    job_config=self.parallel_config,
)

# Initial load model weights from HuggingFace checkpoint path
Do we really need this complication? IIUC we can do:
- trainer loads checkpoint (can be an HF one thanks to torchtitan sd adapters)
- generator only gets weights from the trainer, even the initial ones
- generator never needs to know about HF checkpoints or worry about the conversion
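The flow proposed above can be sketched as follows. This is a minimal illustration with hypothetical `trainer` / `generator` objects; the method names here are assumptions for the sketch, not the PR's actual API.

```python
def sync_weights(trainer, generator):
    """Push weights from the trainer to the generator (sketch only)."""
    # The trainer owns checkpoint loading; HF conversion would be handled
    # on the trainer side by torchtitan's state-dict adapters.
    state = trainer.get_state_dict()
    # The generator only ever receives weights from the trainer, including
    # the initial ones, and never touches HF checkpoints directly.
    generator.load_weights_from_state_dict(state)
```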
This is to make sure it maintains the same expectation whenever / wherever you initialize a LLMEngine(model_path=<checkpointer_folder>, hf_override={"model": "TorchtitanVLLMModel"}). My "expectation" for this call is that the underlying model is initialized with the weights from the checkpoint folder.

- vLLM achieves this "expectation" by calling TorchtitanVLLMWrapper.load_weights() when initializing the LLMEngine.
- We achieve this "expectation" by: 1) making TorchtitanVLLMWrapper.load_weights() a no-op, and 2) calling _initial_load_weights in the Wrapper's __init__.

This "expectation" would help inference as well. The way you described would work perfectly in RL, but people might not notice that the model weights are never initialized during inference. wdyt?
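The pattern described in this comment can be sketched as below. The class and method names (TorchtitanVLLMWrapper, _initial_load_weights) follow the comment, but the internals are stand-ins, not the PR's real implementation.

```python
class TorchtitanVLLMWrapper:
    """Sketch of the no-op load_weights pattern (internals are stand-ins)."""

    def __init__(self, model_path):
        self.weights = {}
        self.initialized = False
        # Eagerly load real weights in __init__, so that constructing the
        # engine always yields an initialized model regardless of how
        # vLLM drives load_weights() afterwards.
        self._initial_load_weights(model_path)

    def _initial_load_weights(self, model_path):
        # Stand-in for reading the DCP checkpoint folder.
        self.weights["dummy"] = f"loaded from {model_path}"
        self.initialized = True

    def load_weights(self, weights_iter):
        # Intentionally a no-op: vLLM still calls this internally during
        # LLMEngine init, but the real load already happened in __init__.
        pass
```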
return self.load_weights_from_state_dict(torchtitan_state_dict)
def load_weights(self, weights_iter):
Do we still need this function? Can it be deleted, or overridden to raise NotImplementedError?
This function will still be called internally by vLLM.
…ight tying (#2410)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.13.0) (oldest at bottom):
* #2395
* #2244
* #2221
* #2194
* #2191
* __->__ #2410

This is an alternative fix to #2402 (comment). Weight updating between the trainer and the generator is totally broken because we called reload_weights when updating the weights. reload_weights has the following steps:

- initialize_layerwise_reload(model): saves the current real GPU tensors as info.kernel_tensors, and replaces all parameters with meta tensors.
- Call model.load_weights(weights_iter): this function is written by us and calls set_model_state_dict. Internally, set_model_state_dict tries to do param.data.copy_(loaded_weight) for each parameter. When the parameters are meta tensors, this is a no-op, so the weights never get updated.

In this PR:
- Totally bypass reload_weights, and don't load from a file when updating the weights.
- Get the model via self.engine.model_executor.driver_worker.get_model().
- Iterate over model.named_parameters() to find the matching parameter by name.
- Do param.data.copy_(new_tensor) directly.
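The direct-copy update described above can be sketched as follows. The get_model() accessor path comes from the PR description; the function name and the simplified name-matching here are assumptions, and real code would also need to handle sharding and dtype/device placement.

```python
import torch

def update_generator_weights(model, new_state):
    """Copy new tensors into a live model in place, by parameter name.

    Sketch of the PR's approach: bypass reload_weights entirely and
    mutate the existing (non-meta) parameter storage, so the update is
    visible to the running engine.
    """
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name, new_tensor in new_state.items():
            if name in params:
                # In-place copy into the existing storage; unmatched
                # names are skipped in this simplified sketch.
                params[name].data.copy_(new_tensor)
```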
Stack from ghstack (oldest at bottom):
What's new in this PR