
Fix weight transfer in simple_multiprocess_rl to restore bitwise parity#2402

Closed
Lucaskabela wants to merge 1 commit into pytorch:main from Lucaskabela:lucaskabela/fix_simple_multiprocess_rl

Conversation

@Lucaskabela (Contributor) commented Feb 20, 2026

Summary

Running simple_multiprocess_rl.py on main does not produce logit parity, and further shows rewards that do not change across steps. This divergence only appears after step 0, and does not occur in simple_rl.py.

In this PR, we submit a fix for this divergence, which we trace to two issues:

  1. When the model has reused (tied) embeddings, there is no separate output weight; this needs to be plumbed through in the vLLM model definition.

  2. When loading, iterate the named parameters explicitly using the supported weight_loader API. This prevents a bug where the weights are silently not updated because, according to Claude:

set_model_state_dict() bypasses vLLM's online_process_loader mechanism. During reload_weights(), initialize_layerwise_reload() puts all parameters on META device and wraps each parameter's weight_loader. When set_model_state_dict() is used, none of those wrapped weight_loaders are triggered.

With these two changes, we get bitwise parity again in all steps.
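The second fix follows vLLM's convention of dispatching each incoming weight through the parameter's attached `weight_loader`, so any wrapper installed by the reload machinery actually runs. A minimal plain-Python sketch of that dispatch pattern (the classes and the `lm_head.weight` name here are illustrative stand-ins, not the real vLLM or model APIs):

```python
class FakeParam:
    """Illustrative stand-in for a model parameter."""
    def __init__(self, data, weight_loader=None):
        self.data = data
        if weight_loader is not None:
            # The reload machinery may attach/wrap a custom loader.
            self.weight_loader = weight_loader

def default_weight_loader(param, loaded_weight):
    """Fallback loader: a plain copy into the parameter."""
    param.data = list(loaded_weight)

def load_weights(named_params, weights_iter):
    """Iterate named weights, dispatching through each param's weight_loader."""
    params = dict(named_params)
    loaded = []
    for name, weight in weights_iter:
        if name == "lm_head.weight" and name not in params:
            # Tied embeddings: there is no separate output weight to load.
            continue
        param = params[name]
        loader = getattr(param, "weight_loader", default_weight_loader)
        loader(param, weight)
        loaded.append(name)
    return loaded
```

The key point is that parameters carrying a wrapped `weight_loader` get updated through it, while the tied output weight is skipped rather than failing the name lookup.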

Test Plan

with-proxy VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py

Baseline

Step   8 | Loss: 0.2140 | Reward: +1.177
  Phase                Time    Peak Mem
  --------------------------------------
  rollout:            0.93s,   4.09 GiB  (4.3%)
  train:              5.81s,   9.58 GiB  (10.1%)
  optimizer:          0.03s,   8.43 GiB  (8.9%)
  weight_sync:        6.77s
  total:             13.51s
[2026-02-20 10:19:02] INFO simple_rl_multiprocess.py:258: [actor=<root>]   Sample:  Paris, the old city of the Roman Empire. This anciently oriented city has been ...
[2026-02-20 10:19:02] INFO generator.py:400: [actor=<root>.<torchtitan.experiments.rl.unified.actors.generator.Generator generator{'gpus': 0/1}>] os.getpid()=1563090 Generating start generate (policy v9)...
Rendering prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 1258.12it/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 48.96it/s, est. speed input: 342.78 toks/s, output: 979.38 toks/s]
[2026-02-20 10:19:03] INFO generator.py:483: [actor=<root>.<torchtitan.experiments.rl.unified.actors.generator.Generator generator{'gpus': 0/1}>] os.getpid()=1563090 Generating finish generate (policy v9)...
[2026-02-20 10:19:03] INFO trainer.py:111: [actor=<root>.<torchtitan.experiments.rl.unified.actors.trainer.Trainer trainer{'gpus': 1/2}>] os.getpid()=1562482 Trainer starts to train 9 on traj:
[2026-02-20 10:19:03] INFO trainer.py:111: [actor=<root>.<torchtitan.experiments.rl.unified.actors.trainer.Trainer trainer{'gpus': 0/2}>] os.getpid()=1562001 Trainer starts to train 9 on traj:
  ⚠ vLLM-TorchTitan logprobs differ: 20/20 tokens
    Max delta: 8.817692e-01, Avg delta: 1.767104e-01
    vLLM logprobs:     ['-0.4551636875', '-0.8292556405', '-2.7441515923', '-7.0027947426', '-0.9324634671']
    TorchTitan logprobs: ['-0.1245563552', '-0.7034843564', '-2.9534451962', '-6.1210255623', '-0.8873867393']
  ⚠ vLLM-TorchTitan logprobs differ: 20/20 tokens
    Max delta: 8.817692e-01, Avg delta: 1.767104e-01
    vLLM logprobs:     ['-0.4551636875', '-0.8292556405', '-2.7441515923', '-7.0027947426', '-0.9324634671']
    TorchTitan logprobs: ['-0.1245563552', '-0.7034843564', '-2.9534451962', '-6.1210255623', '-0.8873867393']

This PR

Step   8 | Loss: -0.0169 | Reward: +1.913
  Phase                Time    Peak Mem
  --------------------------------------
  rollout:            0.82s,   4.09 GiB  (4.3%)
  train:              5.58s,   8.15 GiB  (8.6%)
  optimizer:          0.02s,   6.69 GiB  (7.0%)
  weight_sync:        6.83s
  total:             13.24s
[2026-02-20 10:02:11] INFO simple_rl_multiprocess.py:258: [actor=<root>]   Sample:  Paris, the old city of the Roman Empire. This anciently oriented city has been ...
[2026-02-20 10:02:11] INFO generator.py:400: [actor=<root>.<torchtitan.experiments.rl.unified.actors.generator.Generator generator{'gpus': 0/1}>] os.getpid()=403563 Generating start generate (policy v9)...
Rendering prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 2665.42it/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 49.67it/s, est. speed input: 347.76 toks/s, output: 993.60 toks/s]
[2026-02-20 10:02:12] INFO generator.py:483: [actor=<root>.<torchtitan.experiments.rl.unified.actors.generator.Generator generator{'gpus': 0/1}>] os.getpid()=403563 Generating finish generate (policy v9)...
[2026-02-20 10:02:12] INFO trainer.py:111: [actor=<root>.<torchtitan.experiments.rl.unified.actors.trainer.Trainer trainer{'gpus': 1/2}>] os.getpid()=403106 Trainer starts to train 9 on traj:
[2026-02-20 10:02:12] INFO trainer.py:111: [actor=<root>.<torchtitan.experiments.rl.unified.actors.trainer.Trainer trainer{'gpus': 0/2}>] os.getpid()=402650 Trainer starts to train 9 on traj:
  ✓ vLLM-TorchTitan bitwise determinism verified: 20 tokens match exactly
  ✓ vLLM-TorchTitan bitwise determinism verified: 20 tokens match exactly

Timing/Memory comparison

Using #2398, we verify that the new weight-loading code does not degrade performance:

  Metric                 Baseline             After Changes
  ----------------------------------------------------------
  Total wall-clock       135.15s              135.99s
  Cumul. rollout           9.45s   (7.0%)       9.35s   (6.9%)
  Cumul. train            56.23s  (41.6%)      55.30s  (40.7%)
  Cumul. optimizer         0.31s   (0.2%)       0.26s   (0.2%)
  Cumul. weight_sync      69.47s  (51.4%)      71.34s  (52.5%)
  Peak mem (rollout)      4.09 GiB             4.09 GiB
  Peak mem (train)        9.58 GiB             8.15 GiB
  Peak mem (optimizer)    8.44 GiB             6.70 GiB

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 20, 2026
@Lucaskabela Lucaskabela changed the title Bring simple_multiprocess_rl back to parity Fix weight transfer in simple_multiprocess_rl to restore bitwise parity Feb 20, 2026
@Lucaskabela Lucaskabela marked this pull request as ready for review February 20, 2026 18:37
@wwwjn (Contributor) commented Feb 20, 2026

Thanks, I found this issue yesterday too after updating to the vLLM nightly version. Do you know which recent change in vLLM caused this issue?

Claude gives me the same explanation: "set_model_state_dict() bypasses vLLM's online_process_loader mechanism. During reload_weights(), initialize_layerwise_reload() puts all parameters on the META device and wraps each parameter's weight_loader. When set_model_state_dict() is used, none of those wrapped weight_loaders are triggered." But I would like to understand this issue in more detail, and how it affects loading the state_dict, before we move forward with this solution.

@wwwjn commented Feb 20, 2026

I see the issue, with the help of Claude:

It's caused by calling reload_weights when updating the weights. reload_weights has the following steps:

  1. initialize_layerwise_reload(model): saves the current real GPU tensors as info.kernel_tensors and replaces all parameters with meta tensors.
  2. model.load_weights(weights_iter): this function is written by us and calls set_model_state_dict. Internally, set_model_state_dict does param.data.copy_(loaded_weight) for each parameter. When the parameters are meta tensors, the copy is a no-op, so the weights never get updated.

I would prefer the following solution, as it's more straightforward, and would remove the shallow set_model_state_dict wrapper for now:

  • Bypass reload_weights entirely, and don't load from a file when we update the weights
  • Get the model via self.engine.model_executor.driver_worker.get_model()
  • Iterate over model.named_parameters() to find the matching parameter by name
  • Do param.data.copy_(new_tensor) directly
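The bullets above can be sketched as a direct name-matched copy loop. This is a plain-Python illustration, not the real implementation: `FakeParam` stands in for a torch parameter, and the list assignment stands in for `param.data.copy_(new_tensor)` on GPU tensors obtained from `self.engine.model_executor.driver_worker.get_model()`:

```python
class FakeParam:
    """Illustrative stand-in for a torch parameter."""
    def __init__(self, data):
        self.data = data

def update_weights_direct(named_params, new_weights):
    """Copy new tensors into matching parameters by name, bypassing
    reload_weights entirely so no parameter ever sits on the meta device."""
    updated = []
    for name, param in named_params:
        if name in new_weights:
            # Stand-in for param.data.copy_(new_weights[name]) in-place copy.
            param.data = list(new_weights[name])
            updated.append(name)
    missing = set(new_weights) - set(updated)
    if missing:
        # Fail loudly instead of silently dropping an update.
        raise KeyError(f"weights with no matching parameter: {sorted(missing)}")
    return updated
```

Failing loudly on unmatched names is a design choice worth keeping: the original bug was precisely a silent no-op update.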

I can put up a PR to fix this one; the tied-weight fix looks good to me.

wwwjn added a commit that referenced this pull request Feb 22, 2026
…ight tying (#2410)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.13.0)
(oldest at bottom):
* #2395
* #2244
* #2221
* #2194
* #2191
* __->__ #2410


This is an alternative fix to
#2402 (comment).

Weight updating between the trainer and generator is totally broken. It's caused by calling reload_weights when updating the weights. reload_weights has the following steps:

- initialize_layerwise_reload(model): saves the current real GPU tensors as info.kernel_tensors and replaces all parameters with meta tensors.
- model.load_weights(weights_iter): this function is written by us and calls set_model_state_dict. Internally, set_model_state_dict does param.data.copy_(loaded_weight) for each parameter. When the parameters are meta tensors, the copy is a no-op, so the weights never get updated.


In this PR:

- Bypass reload_weights entirely, and don't load from a file when we update the weights
- Get the model via self.engine.model_executor.driver_worker.get_model()
- Iterate over model.named_parameters() to find the matching parameter by name
- Do param.data.copy_(new_tensor) directly
@Lucaskabela (Author) commented:

Closing since this is resolved in #2410


Labels: ciflow/8gpu, CLA Signed
