
[RL] Add sum digits task and pre and post training evaluation#2423

Open
daniellepintz wants to merge 8 commits into main from
dp/sum_digits

Conversation

@daniellepintz (Contributor):

  • Add a sum-digits task that the model can learn and improve on within 12 training steps
  • Add an evaluate function and run evaluation before and after training
  • Convert the per-step INFO logging to DEBUG logging so loss and rewards are easier to visualize
  • Add per-token normalization of log probs so that long completions do not cause the loss to explode
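The per-token normalization in the last bullet can be sketched as follows; this is an illustrative pure-Python version, and the names (`policy_loss`, `masks`, `advantages`) are assumptions rather than the PR's actual code:

```python
def policy_loss(logprobs, masks, advantages):
    """REINFORCE-style loss with per-token log-prob normalization.

    logprobs:   per-completion lists of sampled-token log probs
    masks:      parallel 0/1 lists (1 = real completion token, 0 = padding)
    advantages: one scalar advantage per completion

    Dividing each completion's summed log prob by its token count keeps
    long completions from contributing outsized log-prob sums and
    blowing up the loss.
    """
    total = 0.0
    for lp, mask, adv in zip(logprobs, masks, advantages):
        n_tokens = max(sum(mask), 1)
        seq_logprob = sum(l * m for l, m in zip(lp, mask)) / n_tokens
        total += -seq_logprob * adv
    return total / len(advantages)
```

With this normalization, a completion twice as long but with the same per-token log probs contributes the same loss, instead of twice as much.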

Final result:

[2026-02-23 06:44:14] INFO simple_rl_multiprocess.py:205: [actor=<root>] Evaluating pre-training baseline...
[2026-02-23 06:44:24] INFO simple_rl_multiprocess.py:87: [actor=<root>] Eval: Accuracy=20% (4/20) Format=30% (6/20)
[2026-02-23 06:44:24] INFO simple_rl_multiprocess.py:209: [actor=<root>] ================================================================================
[2026-02-23 06:44:24] INFO simple_rl_multiprocess.py:210: [actor=<root>] Starting RL training for 12 steps
[2026-02-23 06:44:24] INFO simple_rl_multiprocess.py:211: [actor=<root>] ================================================================================
NCCL version 2.28.9+cuda12.9
/home/daniellepintz/torchtitan/titan-rl-env/lib/python3.12/site-packages/torch/distributed/c10d_logger.py:83: UserWarning: barrier(): using the device under current context. You can specify `device_id` in `init_process_group` to mute this warning.
  return func(*args, **kwargs)
[2026-02-23 06:45:06] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step   0 | Loss: -0.0057 | Reward: -0.450 | Correct: 11/40
[2026-02-23 06:45:41] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step   1 | Loss: -0.0051 | Reward: -0.600 | Correct: 8/40
[2026-02-23 06:46:19] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step   2 | Loss: -0.0049 | Reward: -0.100 | Correct: 18/40
[2026-02-23 06:46:55] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step   3 | Loss: -0.0047 | Reward: +0.205 | Correct: 24/40
[2026-02-23 06:47:31] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step   4 | Loss: -0.0051 | Reward: +0.100 | Correct: 22/40
[2026-02-23 06:48:07] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step   5 | Loss: -0.0046 | Reward: +0.000 | Correct: 20/40
[2026-02-23 06:48:42] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step   6 | Loss: -0.0045 | Reward: +0.100 | Correct: 22/40
[2026-02-23 06:49:17] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step   7 | Loss: -0.0043 | Reward: +0.100 | Correct: 22/40
[2026-02-23 06:49:52] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step   8 | Loss: -0.0039 | Reward: +0.450 | Correct: 29/40
[2026-02-23 06:50:28] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step   9 | Loss: -0.0039 | Reward: +0.400 | Correct: 28/40
[2026-02-23 06:51:03] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step  10 | Loss: -0.0035 | Reward: +0.650 | Correct: 33/40
[2026-02-23 06:51:39] INFO simple_rl_multiprocess.py:222: [actor=<root>] Step  11 | Loss: -0.0034 | Reward: +0.550 | Correct: 31/40
[2026-02-23 06:51:39] INFO simple_rl_multiprocess.py:249: [actor=<root>] RL Training complete
[2026-02-23 06:51:39] INFO simple_rl_multiprocess.py:250: [actor=<root>] Evaluating post-training performance...
[2026-02-23 06:51:50] INFO simple_rl_multiprocess.py:87: [actor=<root>] Eval: Accuracy=70% (14/20) Format=100% (20/20)
[2026-02-23 06:51:50] INFO simple_rl_multiprocess.py:253: [actor=<root>] ================================================================================
[2026-02-23 06:51:50] INFO simple_rl_multiprocess.py:254: [actor=<root>] Pre-training:  Accuracy=20% (4/20) Format=30% (6/20)
[2026-02-23 06:51:50] INFO simple_rl_multiprocess.py:258: [actor=<root>] Post-training: Accuracy=70% (14/20) Format=100% (20/20)
[2026-02-23 06:51:50] INFO simple_rl_multiprocess.py:262: [actor=<root>] ================================================================================
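The Accuracy/Format numbers in the eval lines above could be computed roughly as in this sketch; the `evaluate` name matches the log, but the `Answer: <n>` parsing convention is an assumption for illustration:

```python
import re

def evaluate(completions, targets):
    """Score completions on the digit-sum task.

    format:   fraction of replies containing a parseable final answer
    accuracy: fraction of replies whose parsed answer matches the target
    """
    correct = 0
    formatted = 0
    for text, target in zip(completions, targets):
        match = re.search(r"Answer:\s*(-?\d+)", text)
        if match:
            formatted += 1
            if int(match.group(1)) == target:
                correct += 1
    n = len(targets)
    return {"accuracy": correct / n, "format": formatted / n}
```

Accuracy is bounded above by format, which matches the logs: a reply that cannot be parsed can never be counted correct.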

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 23, 2026
@daniellepintz daniellepintz changed the title Dp/sum digits [RL] Add sum digits task and pre and post training evaluation Feb 23, 2026
-outputs = self.llm.generate(prompt_texts, sampling_params)
+outputs = self.llm.generate(prompt_texts, sampling_params, use_tqdm=False)
daniellepintz (author):

changed to avoid spamming terminal on each step


# Vibe coding
.claude

daniellepintz (author):

will remove this change before landing

 # only generator rank 0 saves the weight
 if torch.distributed.get_rank() == 0:
-    logger.info(f"Saving weights to {checkpoint_path}")
+    logger.debug(f"Saving weights to {checkpoint_path}")
daniellepintz (author):

changed to avoid spamming terminal on each step

-num_steps = 10
+num_steps = 12
 learning_rate = 1e-5
 max_new_tokens = 20
Member:

What's the reasoning behind this exact number of steps?

daniellepintz (author):

This is what I experimented with and it worked well, but it is somewhat arbitrary; I can change it to something else if you prefer.

task_spec = SumDigitsSpec(seed=42)
system_prompt = task_spec.get_system_prompt()

prompt_texts = []
@joecummings (Member), Feb 23, 2026:

Here's my understanding of the flow here:

  1. prompt_text of size num_prompt is gathered. Here it's 5.
  2. These prompt texts are passed to the generator at instantiation time
  3. Every step, the generator uses the SAME prompt texts to generate completions. No new prompts are sampled from the dataloader (task)

This means that while completions will vary slightly from step to step because of the high temperature, the model is only ever trained to respond to the same 5 prompts. So num_steps is really the number of PPO epochs, and the effective number of training steps is 1. I don't think this is what we want, as it doesn't follow the vanilla GRPO formulation.

It looks as though this has been here since before these changes, so please let me know if I'm missing anything here @wwwjn, but I don't immediately see where new prompts are passed to the generator.
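A minimal sketch of the fix being requested here: sample a fresh prompt batch on every step instead of reusing the batch passed at generator construction time. All names (`task.sample`, `generator.generate`, `trainer.step`) are hypothetical stand-ins:

```python
def train(task, generator, trainer, num_steps, num_prompts):
    """Vanilla GRPO-style outer loop: new prompts on every step."""
    for _ in range(num_steps):
        # Draw fresh (prompt, target) pairs from the task each step,
        # rather than training repeatedly on one fixed prompt set.
        batch = [task.sample() for _ in range(num_prompts)]
        prompts = [prompt for prompt, _target in batch]
        completions = generator.generate(prompts)
        trainer.step(batch, completions)
```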

daniellepintz (author):

Great catch, thanks Joe!

I fixed this, and luckily results are still good:

[2026-02-23 07:59:38] INFO simple_rl_multiprocess.py:237: [actor=<root>] Pre-training:  Accuracy=45% (9/20) Format=45% (9/20)
[2026-02-23 07:59:38] INFO simple_rl_multiprocess.py:241: [actor=<root>] Post-training: Accuracy=70% (14/20) Format=95% (19/20)

@allenwang28 (Contributor) left a comment:

Thanks @daniellepintz!

@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor:

A nit on file placement, but would it make more sense to put this in torchtitan/experimental/rl rather than only in unified?

@daniellepintz (author), Feb 23, 2026:

Yes, that makes sense; updated.

My thinking was that we are eventually going to collapse rl/unified and rl/vllm_compat anyway to live just under rl/. @wwwjn, is that the ultimate direction or not?

@wwwjn (Contributor):

I actually plan to remove vllm_compat once we support bit-wise identical models on the unified path, so I guess it makes sense to put sum_digits.py under unified.

Contributor:

Similar to @joecummings' comment, I would expect that we iterate over the prompts as a dataloader and pass prompts to the generator here.

@wconstab (Contributor):

nit: this caught my eye; ping me on workchat if you want help figuring out how to squelch it:

/home/daniellepintz/torchtitan/titan-rl-env/lib/python3.12/site-packages/torch/distributed/c10d_logger.py:83: UserWarning: barrier(): using the device under current context. You can specify device_id in init_process_group to mute this warning.

@wwwjn (Contributor) left a comment:

Thanks! This PR might need a rebase after #2191 landed.


logger.info(f"Loaded {len(prompt_texts)} prompts")
# Task spec
task_spec = SumDigitsSpec(seed=42)
system_prompt = task_spec.get_system_prompt()
Contributor:

The concept of a task is actually the same as a dataloader plus reward calculation. Once we rebase onto the config system change, we will need a proper config for Task as well. I think we are moving toward the "everything is configurable" idea.
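The "task = dataloader + reward calculation" idea can be sketched as one object that both samples prompts and scores completions. This interface and the reward magnitudes are illustrative and only loosely mirror the PR's `SumDigitsSpec`:

```python
import random
import re

class SumDigitsSpec:
    """Bundles prompt sampling (the dataloader role) with reward
    calculation. Illustrative sketch, not the PR's actual class."""

    def __init__(self, seed=42):
        self.rng = random.Random(seed)

    def get_system_prompt(self):
        return "Reply with the digit sum in the form 'Answer: <n>'."

    def sample(self):
        """Return a fresh (prompt, target) pair."""
        n = self.rng.randint(100, 99999)
        return f"Sum the digits of {n}.", sum(int(d) for d in str(n))

    def reward(self, completion, target):
        """Illustrative magnitudes: format failure is penalized harder
        than a wrong but well-formed answer."""
        match = re.search(r"Answer:\s*(-?\d+)", completion)
        if match is None:
            return -1.0
        return 1.0 if int(match.group(1)) == target else -0.5
```

Making such a spec a config-selectable component would fit the "everything is configurable" direction described above.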


Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.


5 participants