
[rl] bypass reload_weights by manually copy weights per-param, fix weight tying#2410

Merged
wwwjn merged 4 commits into main from gh/wwwjn/13/head
Feb 22, 2026
@wwwjn wwwjn commented Feb 20, 2026

Stack from ghstack (oldest at bottom):

This is an alternative fix to #2402 (comment).

Weight updates between the trainer and the generator are currently broken because we call reload_weights when updating the weights. reload_weights performs the following steps:

  • initialize_layerwise_reload(model): Saves the current real GPU tensors as info.kernel_tensors and replaces all parameters with meta tensors.
  • model.load_weights(weights_iter): This function is written by us and calls set_model_state_dict. Internally, set_model_state_dict tries to do param.data.copy_(loaded_weight) for each parameter. When the parameters are meta tensors, the copy is a no-op, so the weights never get updated.
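The failure mode described above can be reproduced in isolation. This is a minimal sketch on plain tensors, not the vLLM reload path; the names real_param, meta_param, and loaded_weight are made up for illustration. It only shows that an in-place copy into a meta tensor completes without error and stores nothing.

```python
import torch

# Hypothetical stand-ins: one parameter backed by real storage, one on the meta device.
real_param = torch.zeros(3)
meta_param = torch.empty(3, device="meta")
loaded_weight = torch.ones(3)

# The copy into real storage overwrites the values in place.
real_param.copy_(loaded_weight)

# The copy into a meta tensor does not raise, but there is no storage to write to,
# so the tensor stays a shape-only placeholder.
meta_param.copy_(loaded_weight)

real_updated = bool(torch.equal(real_param, loaded_weight))
meta_still_meta = meta_param.is_meta
```

Because the meta copy is silent, nothing in the update path signals that the weights were dropped.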

In this PR:

  • Bypasses reload_weights entirely and does not load from a file when updating the weights
  • Gets the model via self.engine.model_executor.driver_worker.get_model()
  • Iterates over model.named_parameters() to find the matching parameter by name
  • Does param.data.copy_(new_tensor) directly
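The per-parameter copy listed above could look roughly like the sketch below. It assumes the model has already been obtained (a toy nn.Linear stands in for it here), and the helper name copy_weights_in_place and the new_weights dict are hypothetical, not names from this PR.

```python
import torch
from torch import nn


def copy_weights_in_place(model: nn.Module, new_weights: dict[str, torch.Tensor]) -> list[str]:
    """Copy tensors into the matching parameters in place; return the updated names."""
    updated = []
    with torch.no_grad():
        for name, param in model.named_parameters():
            new_tensor = new_weights.get(name)
            if new_tensor is None:
                continue  # no incoming tensor for this parameter name
            # In-place copy keeps the existing storage, so anything sharing it sees the update.
            param.data.copy_(new_tensor)
            updated.append(name)
    return updated


# Toy stand-in for the generator model (an assumption, not the actual vLLM model).
generator = nn.Linear(4, 2)
trainer_weights = {
    "weight": torch.full((2, 4), 0.5),
    "bias": torch.zeros(2),
}
updated_names = copy_weights_in_place(generator, trainer_weights)
```

Returning the list of updated names makes it easy to check that every expected parameter was matched, which is the kind of silent miss the reload_weights path suffered from.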

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 20, 2026
@wwwjn wwwjn changed the title bypass reload_weights by manually copy weights per-param [rl] bypass reload_weights by manually copy weights per-param Feb 20, 2026
wwwjn added a commit that referenced this pull request Feb 20, 2026
@wwwjn wwwjn changed the title [rl] bypass reload_weights by manually copy weights per-param [rl] bypass reload_weights by manually copy weights per-param, fix weight tying Feb 21, 2026
@wwwjn wwwjn changed the base branch from gh/wwwjn/13/base to main February 22, 2026 23:17
@wwwjn wwwjn merged commit bc9f3ff into main Feb 22, 2026
13 of 22 checks passed

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

2 participants