19 changes: 7 additions & 12 deletions torchtitan/experiments/rl/unified/README.md
@@ -36,23 +36,18 @@ uv pip install torch vllm xformers --pre \
python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=...
```

4. Run inference:
4. Run inference with the unified model definition:
```bash
python torchtitan/experiments/rl/unified/infer.py
```

Run with TP:
```bash
python torchtitan/experiments/rl/unified/infer.py --tensor-parallel-size 2

torchrun --nproc_per_node=<world_size> \
torchtitan/experiments/rl/unified/infer.py
```

5. Run the simple RL loop:
```bash
VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
```
Right now we only support VLLM_COMPAT mode, which achieves bitwise-identical trainer and generator outputs. We are working on supporting UNIFIED mode,
which uses a unified model definition for the trainer and generator.
python3 torchtitan/experiments/rl/unified/simple_grpo.py \
--trainer.checkpoint.initial_load_path=<path_to_model_checkpoint>
```
We use a unified model definition for the trainer and generator, ensuring bitwise-identical models to address a class of subtle correctness bugs in RL for LLMs.
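The bitwise-identity property above can be sanity-checked by hashing the raw bytes of every parameter in both models. The sketch below is illustrative, not part of the repo: it uses plain Python lists of floats in place of the real torch state dicts, and `state_dict_digest` is a hypothetical helper name.

```python
import hashlib
import struct


def state_dict_digest(state_dict):
    """Hash every parameter byte-for-byte; any bitwise drift changes the digest."""
    h = hashlib.sha256()
    for name in sorted(state_dict):
        h.update(name.encode())
        for value in state_dict[name]:
            # struct.pack("<d", ...) yields the exact IEEE-754 bytes of each weight,
            # so two digests match only if the weights are bitwise identical.
            h.update(struct.pack("<d", value))
    return h.hexdigest()


# Hypothetical trainer/generator weights; in the real loop these would come from
# the unified model definition's state dicts.
trainer_weights = {"layer0.weight": [0.1, -0.2, 0.3]}
generator_weights = {"layer0.weight": [0.1, -0.2, 0.3]}

assert state_dict_digest(trainer_weights) == state_dict_digest(generator_weights)
```

A check like this catches the class of RL correctness bugs where trainer and generator silently diverge (e.g. different kernels or dtypes), because even a one-ULP difference produces different bytes and therefore a different digest.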

## TODO
Work on batch invariance:
2 changes: 1 addition & 1 deletion torchtitan/experiments/rl/unified/__init__.py
@@ -58,7 +58,7 @@ def __init__(self, *, vllm_config, prefix=""):
# Register with vLLM
ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec)

logger.info(
logger.debug(
f"Successfully registered {model_name} with vLLM using ModelSpec "
f"(flavor={model_spec.flavor})"
)
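The registration call in this hunk follows a simple name-to-class registry pattern: vLLM resolves a model-name string from the config to a model class at load time. The sketch below is a minimal stand-in for that pattern, not vLLM's actual internals; the `ModelRegistry`, `resolve`, and config contents here are illustrative only.

```python
class ModelRegistry:
    """Toy name -> class registry mimicking how vLLM looks up registered models."""

    _models = {}

    @classmethod
    def register_model(cls, name, model_cls):
        cls._models[name] = model_cls

    @classmethod
    def resolve(cls, name):
        return cls._models[name]


class TorchTitanVLLMModelFromSpec:
    # Mirrors the keyword-only signature shown in the diff above.
    def __init__(self, *, vllm_config, prefix=""):
        self.vllm_config = vllm_config
        self.prefix = prefix


# Register once at import time, then resolve by name when building the engine.
ModelRegistry.register_model("TorchTitanVLLMModelFromSpec", TorchTitanVLLMModelFromSpec)
model_cls = ModelRegistry.resolve("TorchTitanVLLMModelFromSpec")
model = model_cls(vllm_config={"flavor": "qwen3-0.6b"})
```

Registering at import time (as the `__init__.py` in this PR does) is what lets a plain string in the serving config instantiate the TorchTitan model class, which also explains why the log line was demoted from `info` to `debug`: registration is routine, not noteworthy.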