
[NOT4REVIEW] Make RL model compilable: ops rewrite + per-sample backward #2394

Draft
Lucaskabela wants to merge 3 commits into pytorch:main from Lucaskabela:lucaskabela/rl-compilable

Conversation

@Lucaskabela (Contributor)

Rewrites batch-invariant ops as torch.library.custom_op (rms_norm, silu_and_mul, flash_attn) so they are opaque to Dynamo/AOT autograd. Adds aten dispatch overrides for matmul/linear backward to use vLLM's deterministic kernels.

Refactors compute_policy_gradient_loss_vllm to use per-sample gradient accumulation: each sample's forward is immediately followed by backward, keeping only one set of activations in memory at a time. This is a prerequisite for torch.compile since the compiled graph processes one sample at a time with fixed-shape inputs.
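The per-sample pattern can be sketched as follows (an assumed simplification; `compute_policy_gradient_loss_vllm` itself is more involved). Each sample's backward runs before the next sample's forward, so at most one sample's activations are alive, and gradients accumulate into `.grad`:

```python
import torch

# Minimal sketch of per-sample gradient accumulation. The function name and
# signature are illustrative, not the PR's API.
def per_sample_backward(model, samples, loss_fn, loss_scale):
    total_loss = 0.0
    for sample in samples:
        loss = loss_fn(model(sample)) * loss_scale  # forward for one sample
        loss.backward()  # frees this sample's activations immediately
        total_loss += loss.detach().item()
    return total_loss
```

Because each `backward()` accumulates into existing `.grad` buffers, `zero_grad()` must run once before the loop rather than after the loss, which matches the trainer.py change below.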

Changes:

  • batch_invariant_backward.py: custom ops rewrite
  • models/attention.py: custom op for flash_attn
  • simple_rl.py: per-sample backward, loss_scale param, timing metrics
  • trainer.py: move zero_grad before loss, remove loss.backward()
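
The aten dispatch override in batch_invariant_backward.py might look roughly like this sketch. The replacement kernel here is a naive broadcast-sum matmul (an assumption for illustration, with a fixed reduction order standing in for a batch-invariant kernel); the PR routes to vLLM's deterministic kernels instead:

```python
import torch

# Hedged sketch of overriding an aten op's kernel via torch.library.Library.
_lib = torch.library.Library("aten", "IMPL")

def _deterministic_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Written with broadcasting so it never re-enters aten::mm; the explicit
    # sum fixes the reduction order.
    return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(dim=1)

# Route aten::mm on CPU to the replacement kernel (CUDA in the real code).
_lib.impl("mm", _deterministic_mm, "CPU")
```

After registration, any `torch.mm` call on that dispatch key goes through the override, including the mm calls generated by matmul/linear backward.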

Authored with Claude

meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Feb 18, 2026
@Lucaskabela force-pushed the lucaskabela/rl-compilable branch 2 times, most recently from 57ffef9 to e10c3b7, on February 19, 2026 at 19:17
@Lucaskabela force-pushed the lucaskabela/rl-compilable branch from e10c3b7 to 575a33d on February 23, 2026 at 16:27

Labels: ciflow/8gpu, CLA Signed