Skip to content

[WIP][rl] enable batch-invariant mode in RL loop#2395

Open
wwwjn wants to merge 12 commits intogh/wwwjn/12/basefrom
gh/wwwjn/12/head
Open

[WIP][rl] enable batch-invariant mode in RL loop#2395
wwwjn wants to merge 12 commits intogh/wwwjn/12/basefrom
gh/wwwjn/12/head

Conversation

[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 19, 2026
ghstack-source-id: ae5abed
Pull-Request: #2395
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 19, 2026
@wwwjn wwwjn changed the title enable batch-invariant mode in RL loop [WIP][rl] enable batch-invariant mode in RL loop Feb 19, 2026
# Batch invariant mode: set NCCL determinism env vars
policy_opt = job_config.policy_optimization
if policy_opt.batch_invariant_mode:
_set_nccl_determinism_envs()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we move this to line 125? also curious does generator need to set this?

[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 20, 2026
ghstack-source-id: bcacf2b
Pull-Request: #2395
[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 21, 2026
ghstack-source-id: 585ee6c
Pull-Request: #2395
wwwjn added a commit that referenced this pull request Feb 22, 2026
…ight tying (#2410)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.13.0)
(oldest at bottom):
* #2395
* #2244
* #2221
* #2194
* #2191
* __->__ #2410


This is a alternative fix to
#2402 (comment).

Weight updating between trainer and generator is totally broken because:
It's caused by we called "reload_weights" when updating the weights. The
reload_weights has following steps:

- initialize_layerwise_reload(model): Saves the current real GPU tensors
as info.kernel_tensors, and replace all parameters with meta tensor.
- Call model.load_weights(weights_iter): This function is written by us
and calls set_model_state_dict, Internally, set_model_state_dict tries
to do param.data.copy_(loaded_weight) for each parameter. When
parameters are meta tensor, it will do "no-op". So the weights never get
updated


In this PR:

- Totally bypass reload_weights, and don't load from a file when we
update the weights
- Gets the model via
self.engine.model_executor.driver_worker.get_model()
- Iterates over model.named_parameters() to find the matching parameter
by name
- Does param.data.copy_(new_tensor) directly
[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 23, 2026
ghstack-source-id: d50af97
Pull-Request: #2395
[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 23, 2026
ghstack-source-id: 1de5a75
Pull-Request: #2395
[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 23, 2026
ghstack-source-id: 4778303
Pull-Request: #2395
[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 24, 2026
ghstack-source-id: e49afbe
Pull-Request: #2395
[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 24, 2026
ghstack-source-id: 05604f4
Pull-Request: #2395
[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 24, 2026
ghstack-source-id: 51cc080
Pull-Request: #2395
[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 24, 2026
ghstack-source-id: 1f2c104
Pull-Request: #2395
[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 24, 2026
ghstack-source-id: 6ef380d
Pull-Request: #2395
[ghstack-poisoned]
wwwjn added a commit that referenced this pull request Feb 25, 2026
ghstack-source-id: 5b576bd
Pull-Request: #2395
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants