[RL] Add sum digits task and pre and post training evaluation#2423
daniellepintz wants to merge 8 commits into main from
Conversation
```diff
 )
-outputs = self.llm.generate(prompt_texts, sampling_params)
+outputs = self.llm.generate(prompt_texts, sampling_params, use_tqdm=False)
```
changed to avoid spamming terminal on each step
```diff
+# Vibe coding
+.claude
```
will remove this change before landing
```diff
 # only generator rank 0 saves the weight
 if torch.distributed.get_rank() == 0:
-    logger.info(f"Saving weights to {checkpoint_path}")
+    logger.debug(f"Saving weights to {checkpoint_path}")
```
changed to avoid spamming terminal on each step
```diff
-num_steps = 10
+num_steps = 12
 learning_rate = 1e-5
 max_new_tokens = 20
```
What's the reasoning behind this exact number of steps?
This is what I experimented with and it worked well, but it is somewhat arbitrary; I can change it to something else if you prefer.
```python
task_spec = SumDigitsSpec(seed=42)
system_prompt = task_spec.get_system_prompt()

prompt_texts = []
```
Here's my understanding of the flow here:

- `prompt_text` of size `num_prompt` is gathered. Here it's 5.
- These prompt texts are passed to the generator at instantiation time.
- Every step, the generator uses the SAME prompt texts to generate completions. No new prompts are sampled from the dataloader (task).

This means that over `num_steps`, while completions will be slightly different because of a high temperature, the model is only ever trained to respond to 5 prompts. Therefore, this `num_steps` is more accurately the number of PPO epochs, and the effective `num_steps` is actually 1. I don't think this is what we want, as it doesn't follow the vanilla GRPO formulation.

It looks as though this has been here since before these changes, so please let me know if I'm missing anything here @wwwjn, but I don't immediately see where new prompts are passed to the generator.
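To illustrate the distinction, here is a minimal sketch of the resampling loop the vanilla GRPO formulation expects, where a fresh batch of prompts is drawn from the task every step instead of being fixed at generator instantiation. All names here (`make_prompt`, `train`) are hypothetical illustrations, not the PR's actual API:

```python
import random

def make_prompt(rng: random.Random) -> str:
    # Hypothetical task: sum a few random digits.
    digits = [rng.randint(0, 9) for _ in range(4)]
    return f"What is the sum of {digits}?"

def train(num_steps: int, num_prompts: int, seed: int = 42) -> list[list[str]]:
    rng = random.Random(seed)
    sampled = []
    for _ in range(num_steps):
        # Vanilla GRPO: resample a fresh batch of prompts every step,
        # rather than reusing the batch passed in at instantiation time.
        prompts = [make_prompt(rng) for _ in range(num_prompts)]
        sampled.append(prompts)
        # ...generate completions for `prompts`, score them, and update...
    return sampled

batches = train(num_steps=3, num_prompts=5)
```

With the fixed-prompt behavior described above, every element of `batches` would be identical; with per-step sampling, each step sees a different batch.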
Great catch, thanks Joe!
I fixed this, and luckily the results are still good:

```
[2026-02-23 07:59:38] INFO simple_rl_multiprocess.py:237: [actor=<root>] Pre-training: Accuracy=45% (9/20) Format=45% (9/20)
[2026-02-23 07:59:38] INFO simple_rl_multiprocess.py:241: [actor=<root>] Post-training: Accuracy=70% (14/20) Format=95% (19/20)
```
allenwang28 left a comment
Thanks @daniellepintz!
```diff
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
```
a nit on file placement, but would it make more sense to put this in torchtitan/experimental/rl rather than only in unified?
Yes, that makes sense; updated.
My thinking was that we are eventually going to collapse rl/unified and rl/vllm_compat anyway to live just under rl/. @wwwjn, is that the ultimate direction, or no?
I actually plan to remove vllm_compat once we support the bit-wise identity model on the unified path, so I guess it makes sense to put sum_digits.py under unified.
Similar to @joecummings' comment: I would expect that we iterate on the prompts as a dataloader and pass prompts to the generator here.
nit: this caught my eye. Ping me on workchat if you want help figuring out how to squelch it.
```python
logger.info(f"Loaded {len(prompt_texts)} prompts")
# Task spec
task_spec = SumDigitsSpec(seed=42)
system_prompt = task_spec.get_system_prompt()
```
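For readers following along, here is a rough sketch of what a seeded sum-of-digits task spec like the one instantiated above might look like, combining prompt sampling with accuracy and format rewards. This is a hypothetical illustration; the PR's actual implementation lives in sum_digits.py and may differ:

```python
import random
import re

class SumDigitsSpec:
    """Hypothetical sketch of a sum-of-digits task spec."""

    def __init__(self, seed: int = 42, num_digits: int = 4):
        self.rng = random.Random(seed)
        self.num_digits = num_digits

    def get_system_prompt(self) -> str:
        return "Answer with only the final number."

    def sample_prompt(self) -> tuple[str, int]:
        # Draw random digits; the ground-truth answer is their sum.
        digits = [self.rng.randint(0, 9) for _ in range(self.num_digits)]
        return f"Sum these digits: {digits}", sum(digits)

    def reward(self, completion: str, answer: int) -> float:
        # Format check: the completion must contain a parseable integer.
        match = re.search(r"-?\d+", completion)
        if match is None:
            return 0.0
        # Accuracy: full credit for the correct sum, partial for format only.
        return 1.0 if int(match.group()) == answer else 0.5

spec = SumDigitsSpec(seed=42)
prompt, answer = spec.sample_prompt()
```

The split reward is one plausible way to produce the separate Accuracy and Format percentages reported in the evaluation logs above.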
The concept of a task is actually the same as dataloader + reward calculation. Once we rebase onto the config system change, we would need a proper config for Task as well. I think we are moving towards the "everything is configurable" idea.
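One way the "everything is configurable" idea could look for tasks, as a rough sketch using dataclasses. The field names here are invented for illustration and do not reflect the actual config system:

```python
from dataclasses import dataclass, field

@dataclass
class TaskConfig:
    # Hypothetical fields; the real config schema may differ.
    name: str = "sum_digits"
    seed: int = 42
    num_prompts: int = 5

@dataclass
class TrainConfig:
    # Defaults mirror the values discussed in this PR.
    num_steps: int = 12
    learning_rate: float = 1e-5
    max_new_tokens: int = 20
    task: TaskConfig = field(default_factory=TaskConfig)

cfg = TrainConfig()
```

Nesting `TaskConfig` under `TrainConfig` would let a registry pick the task (and its reward calculation) by name, the same way models and optimizers are selected.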
`evaluate` function and runs evaluation before and after training. Final result: