[DRAFT] Wire through RL compilation: compile_utils, graph_utils, config #5

Draft
Lucaskabela wants to merge 4 commits into lucaskabela/rl-compilable from lucaskabela/rl-compile-wiring
Conversation

@Lucaskabela (Owner)

Adds compilation support to the RL training loop:

  • graph_utils.py: return gm from export_joint, add validate_dtensor param, use gm.named_parameters(remove_duplicate=False) to handle tied weights
  • compile_utils.py: new file using joint_graph_builder with validate_dtensor=False, RLCompiledModule wrapper with input padding
  • simple_rl.py: add compilation config vars and compile_rl_model call
  • trainer.py: add compile params, compile before DDP
  • simple_rl_multiprocess.py: add compilation config and fw+bw timing
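The RLCompiledModule wrapper with input padding could look roughly like the following. This is a minimal sketch, not the PR's implementation: the constructor arguments, `pad_id`, and the right-padding/strip logic are all assumptions; the point is that padding every input to one fixed sequence length means the compiled graph only ever sees a single shape, so it never recompiles.

```python
import torch
import torch.nn.functional as F


class RLCompiledModule(torch.nn.Module):
    """Hypothetical sketch: wrap a compiled forward fn and right-pad
    variable-length inputs to a fixed length so the compiled graph
    always receives fixed-shape inputs (no shape-driven recompiles)."""

    def __init__(self, compiled_fn, max_seq_len: int, pad_id: int = 0):
        super().__init__()
        self.compiled_fn = compiled_fn  # assumed signature: (input_ids) -> logits
        self.max_seq_len = max_seq_len
        self.pad_id = pad_id

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        seq_len = input_ids.shape[-1]
        # Right-pad the sequence dimension up to the fixed length.
        padded = F.pad(input_ids, (0, self.max_seq_len - seq_len), value=self.pad_id)
        out = self.compiled_fn(padded)
        # Strip the padded positions from the output again.
        return out[..., :seq_len]
```

In a real setup `compiled_fn` would be the graph produced by `joint_graph_builder`; here any callable with the same shape contract works.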

@Lucaskabela Lucaskabela force-pushed the lucaskabela/rl-compilable branch from ee75c66 to 57ffef9 on February 19, 2026 at 17:53
@Lucaskabela Lucaskabela force-pushed the lucaskabela/rl-compile-wiring branch from 86b8622 to 6e0185e on February 19, 2026 at 18:40
Rewrites batch-invariant ops as torch.library.custom_op (rms_norm,
silu_and_mul, flash_attn) so they are opaque to Dynamo/AOT autograd.
Adds aten dispatch overrides for matmul/linear backward to use vLLM's
deterministic kernels.

Refactors compute_policy_gradient_loss_vllm to use per-sample gradient
accumulation: each sample's forward is immediately followed by backward,
keeping only one set of activations in memory at a time. This is a
prerequisite for torch.compile since the compiled graph processes one
sample at a time with fixed-shape inputs.
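The per-sample accumulation pattern described above can be sketched as follows. This is a simplified stand-in for `compute_policy_gradient_loss_vllm` (the helper name and loop structure here are assumptions): each sample's backward runs immediately after its forward, so only one sample's activations are alive at a time, and gradients accumulate in `.grad` across iterations.

```python
import torch


def per_sample_backward(model, loss_fn, samples, loss_scale: float):
    """Hypothetical sketch of per-sample gradient accumulation.

    For each sample, backward() runs right after the forward, freeing
    that sample's activations before the next forward starts. Scaling
    each loss by loss_scale (e.g. 1/num_samples) makes the accumulated
    gradients match a single batched backward over the mean loss.
    """
    total_loss = 0.0
    for x, y in samples:
        loss = loss_fn(model(x), y) * loss_scale
        loss.backward()  # grads sum into p.grad; activations are freed
        total_loss += loss.item()
    return total_loss
```

With fixed-shape per-sample inputs, the forward/backward inside the loop is exactly the unit a compiled joint graph can replay without recompiling.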

Changes:
- batch_invariant_backward.py: custom ops rewrite
- models/attention.py: custom op for flash_attn
- simple_rl.py: per-sample backward, loss_scale param, timing metrics
- trainer.py: move zero_grad before loss, remove loss.backward()
@Lucaskabela Lucaskabela force-pushed the lucaskabela/rl-compilable branch from 57ffef9 to e10c3b7 on February 19, 2026 at 19:17
@Lucaskabela Lucaskabela force-pushed the lucaskabela/rl-compile-wiring branch from 6e0185e to 962d80f on February 19, 2026 at 22:07
@Lucaskabela Lucaskabela force-pushed the lucaskabela/rl-compilable branch from e10c3b7 to 575a33d on February 23, 2026 at 16:27