Fix weight transfer in simple_multiprocess_rl to restore bitwise parity #2402
Lucaskabela wants to merge 1 commit into pytorch:main
Conversation
Thanks, I found this issue yesterday too after updating to the vLLM nightly. Do you know which recent change in vLLM caused it? Claude gives me the same explanation: "set_model_state_dict() bypasses vLLM's online_process_loader mechanism. During reload_weights(), initialize_layerwise_reload() puts all parameters on the META device and wraps each parameter's weight_loader. When set_model_state_dict() is used, none of those wrapped weight_loaders are triggered." But I would want to learn more about this issue and how it affects loading a state_dict before we move ahead with this solution.
I see the issue with the help of Claude: it's caused by calling "reload_weights" when updating the weights. The
I would prefer the following solution, as it's more straightforward, and will remove the
I can put up a PR to fix this one; the tied_weight fix looks good to me
…ight tying (#2410)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.13.0) (oldest at bottom): * #2395 * #2244 * #2221 * #2194 * #2191 * __->__ #2410

This is an alternative fix to #2402 (comment). Weight updating between trainer and generator is totally broken because we call "reload_weights" when updating the weights. reload_weights has the following steps:

- initialize_layerwise_reload(model): saves the current real GPU tensors as info.kernel_tensors, and replaces all parameters with meta tensors.
- Call model.load_weights(weights_iter): this function is written by us and calls set_model_state_dict. Internally, set_model_state_dict tries to do param.data.copy_(loaded_weight) for each parameter. When the parameters are meta tensors, the copy is a no-op, so the weights never get updated.

In this PR:

- Totally bypass reload_weights, and don't load from a file when we update the weights
- Get the model via self.engine.model_executor.driver_worker.get_model()
- Iterate over model.named_parameters() to find the matching parameter by name
- Do param.data.copy_(new_tensor) directly
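The direct-update steps above can be sketched as follows. This is a minimal illustration using a toy module in place of the real vLLM engine: the accessor `self.engine.model_executor.driver_worker.get_model()` comes from the PR description, but the helper name `update_weights_in_place` and the toy model are assumptions for the example.

```python
import torch
import torch.nn as nn

def update_weights_in_place(model: nn.Module, new_weights: dict) -> None:
    """Copy new tensors into matching parameters by name, bypassing reload_weights."""
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name, tensor in new_weights.items():
            if name in params:
                # Direct copy into the live parameter tensor -- no meta-device
                # round trip, so the copy is never silently a no-op.
                params[name].data.copy_(tensor)

# Toy stand-in for the model returned by driver_worker.get_model()
model = nn.Linear(4, 2, bias=False)
update_weights_in_place(model, {"weight": torch.ones(2, 4)})
assert torch.equal(model.weight, torch.ones(2, 4))
```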
Closing since this is resolved in #2410
Summary
Running simple_multiprocess_rl.py on main does not have logit parity, and further shows rewards that do not change across steps. This divergence only appears after step 0, and does not happen in simple_rl.py. In this PR, we submit a fix for this divergence, which we identify as stemming from two issues:
When the model has reused (tied) embeddings, there is no separate output weight, and this needs to be plumbed through in the vLLM model definition
When loading, iterate the named params explicitly using the supported weight_loader API; according to Claude, this prevents the bug of weights not being updated
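A minimal sketch of the weight_loader-based loading pattern the second fix describes. `default_weight_loader` here mirrors the vLLM helper of the same name but is reimplemented so the example is self-contained; the toy model and parameter names are illustrative.

```python
import torch
import torch.nn as nn

def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Fallback loader: plain copy into the parameter."""
    param.data.copy_(loaded_weight)

def load_weights(model: nn.Module, weights) -> None:
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        param = params_dict[name]
        # Respect any wrapper installed on this parameter (e.g. a
        # layerwise-reload wrapper that rematerializes meta tensors).
        # set_model_state_dict would skip such wrappers entirely.
        loader = getattr(param, "weight_loader", default_weight_loader)
        loader(param, loaded_weight)

model = nn.Linear(3, 3, bias=False)
load_weights(model, [("weight", torch.full((3, 3), 2.0))])
assert torch.equal(model.weight, torch.full((3, 3), 2.0))
```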
With these two changes, we restore bitwise parity in all steps.
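The first issue can be illustrated with a toy module that ties its output head to its embedding, as many language models do: the tied head contributes no separate output weight, so the model definition must carry the tying through rather than expect a distinct output-weight entry. All names below are illustrative, not from the real model.

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab: int = 8, dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        # Tie the output head to the embedding table.
        self.lm_head.weight = self.embed.weight

model = TinyLM()
# The tied parameter is one tensor shared by both modules...
assert model.lm_head.weight.data_ptr() == model.embed.weight.data_ptr()
# ...and named_parameters() (with default deduplication) yields it only once,
# so there is no separate output weight to load.
assert len(dict(model.named_parameters())) == 1
```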
Test Plan
Baseline
This PR
Timing/Memory comparison
Using #2398, we verify that this different weight-loading code does not degrade performance.