175 changes: 161 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
@@ -4,35 +4,123 @@ This branch builds on top of the official TRM implementation and adds the follow
- simpler setup with `uv`
- better checkpoint saving
- simpler problems for debugging (e.g. Sudoku 4x4)
- **NEW**: Advanced Q&A datasets for reasoning evaluation
- **NEW**: Mathematical reasoning datasets (MATH & GSM8K style problems)

Nothing is changed in the model/architecture/training.

The scripts to prepare the data and train the model remain the same.

## Example on Rubik's cube 2x2x2
## 🚀 Quick Start Examples

To prepare the data:
### Rubik's Cube 2x2x2
```bash
# Prepare data
uv run dataset/build_rubik2x2_dataset.py

# Train model
./train_rubik2x2.sh

# Evaluate
uv run python evaluate.py --data-path data/rubik2x2/ --config checkpoints/trm/<yours>/all_config.yaml --checkpoint checkpoints/trm/<yours>/final_step_4500/model.pt
```

### Q&A Pairs (Natural Language Understanding)
```bash
# Prepare data
uv run dataset/build_qa_dataset.py

# Train model
./train_qa_pairs.sh

# Evaluate
uv run python evaluate.py --data-path data/qa_pairs/ --config checkpoints/trm/<yours>/all_config.yaml --checkpoint checkpoints/trm/<yours>/final_step_4500/model.pt
```

To train the model: `train_rubik2x2.sh` (this model trains in a few minutes on an A10)
### Sudoku 4x4
```bash
# Prepare data
uv run python dataset/build_sudoku_4x4_dataset.py

# Train model
./train_sudoku4x4.sh

# Evaluate
uv run python evaluate.py --data-path data/sudoku4x4/ --config checkpoints/trm/<yours>/all_config.yaml --checkpoint checkpoints/trm/<yours>/final_step_45/model.pt
```

## 🧠 Advanced Reasoning Examples

## Example on Sudoku 4x4
### Advanced Q&A Reasoning Tasks
```bash
# Ultra-advanced reasoning Q&A pairs (76.05% accuracy achieved)
uv run dataset/build_qa_dataset.py --advanced

# Train on advanced reasoning
uv run python pretrain.py --config-name cfg_qa_advanced

# Evaluate reasoning capabilities
uv run python evaluate.py --data-path data/qa_pairs_advanced/
```

To train the model: `train_sudoku4x4.sh` (this model trains in a few minutes on an A10)
### Ultra-Complex Reasoning Tasks
```bash
# Ultra-complex multi-step reasoning problems
uv run dataset/build_qa_dataset.py --ultra-complex

# Smaller version for testing
uv run dataset/build_qa_dataset.py --ultra-complex-small
```

### Mathematical Reasoning (MATH & GSM8K)
```bash
# Prepare comprehensive math dataset (10K training, 2K test)
uv run python dataset/build_math_gsm8k_dataset.py

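# Train on mathematical reasoning
# (the script name contains "&", so it must be quoted; unquoted, the shell
# would run ./train_math in the background and then look for gsmk8.sh)
"./train_math&gsmk8.sh"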

# Evaluate math capabilities
uv run python evaluate_math.py --checkpoint-dir checkpoints/TRM-Math-Reasoning/<run>/
```

**Math Dataset Composition:**
- **Basic Arithmetic**: Addition, subtraction, multiplication, division word problems
- **Algebra**: Linear equations, systems of equations, quadratic equations
- **Geometry**: Circle area/volume, triangle area, rectangle perimeter, sphere volume
- **Calculus**: Derivatives, indefinite/definite integrals, limits, Taylor series
- **Advanced Topics**: Differential equations, complex analysis, residues
- **Statistics**: Mean, standard deviation, probability distributions
- **Number Theory**: GCD, prime checking, modular arithmetic, Euler's totient
- **Discrete Math**: Combinatorics, recurrence relations, graph theory
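
To make one of the categories above concrete, here is a minimal sketch of how a single number-theory item could be produced and checked. The names `make_gcd_problem` and `euler_totient`, and the question/answer record layout, are hypothetical and chosen only for illustration; the format actually written by `dataset/build_math_gsm8k_dataset.py` is not shown in this README.

```python
import math
import random


def make_gcd_problem(rng: random.Random) -> dict:
    """Build one GCD question with its answer (hypothetical record layout)."""
    a = rng.randint(2, 500)
    b = rng.randint(2, 500)
    return {
        "category": "number_theory",
        "question": f"What is the greatest common divisor of {a} and {b}?",
        "answer": str(math.gcd(a, b)),
    }


def euler_totient(n: int) -> int:
    """Count the integers in [1, n] that are coprime to n (brute force)."""
    return sum(1 for k in range(1, n + 1) if math.gcd(k, n) == 1)


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so the sample is reproducible
    print(make_gcd_problem(rng))
    print(euler_totient(36))
```

Storing the answer as a string keeps the record uniform across categories whose answers are not integers, which is one plausible reason to prefer it in a mixed dataset.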

## 🧪 Evaluation & Analysis

### Standard Evaluation
```bash
# Evaluate any trained model
uv run python evaluate.py \
--data-path data/<dataset>/ \
--config checkpoints/trm/<run>/all_config.yaml \
--checkpoint checkpoints/trm/<run>/final_step_<N>/model.pt
```
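
The `<run>` and `<N>` placeholders in the command above have to be filled in by hand. As a convenience, the following sketch picks the highest-numbered `final_step_<N>` directory in a run folder and assembles the same command line. It assumes the `final_step_<N>/model.pt` and `all_config.yaml` layout shown in this README; the helper names `latest_checkpoint` and `build_eval_command` are hypothetical and not part of the repository.

```python
import re
from pathlib import Path
from typing import List, Optional


def latest_checkpoint(run_dir: Path) -> Optional[Path]:
    """Return model.pt under the highest-numbered final_step_<N> directory."""
    best_step, best_path = -1, None
    for child in run_dir.iterdir():
        match = re.fullmatch(r"final_step_(\d+)", child.name)
        candidate = child / "model.pt"
        if match and candidate.is_file() and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), candidate
    return best_path


def build_eval_command(data_path: str, run_dir: Path) -> List[str]:
    """Assemble the evaluate.py invocation for a given run directory."""
    checkpoint = latest_checkpoint(run_dir)
    if checkpoint is None:
        raise FileNotFoundError(f"no final_step_<N>/model.pt under {run_dir}")
    return [
        "uv", "run", "python", "evaluate.py",
        "--data-path", data_path,
        "--config", str(run_dir / "all_config.yaml"),
        "--checkpoint", str(checkpoint),
    ]
```

The returned list can be passed directly to `subprocess.run`, which avoids shell quoting problems with unusual run names.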

### Mathematical Reasoning Evaluation
```bash
# Evaluate math capabilities specifically
uv run python evaluate_math.py \
--checkpoint-dir checkpoints/TRM-Math-Reasoning/<run>/ \
--data-path data/math_gsm8k_qa \
--num-samples 100
```

`uv run python evaluate.py --data-path data/sudoku4x4/ --config checkpoints/trm/messy-earwig-of-enthusiasm/all_config.yaml --checkpoint checkpoints/trm/messy-earwig-of-enthusiasm/final_step_45/model.pt`
### Available Scripts
- `evaluate.py` - General evaluation for all puzzle types
- `evaluate_math.py` - Specialized evaluation for mathematical reasoning
- `train_math&gsmk8.sh` - Training script for math dataset
- `train_math_gsmk8.sh` - Alternative training script

## Reference

# Less is More: Recursive Reasoning with Tiny Networks

@@ -91,6 +179,14 @@ python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1

# Maze-Hard
python dataset/build_maze_dataset.py # 1000 examples, 8 augments

# NEW: Advanced Q&A Reasoning Tasks
uv run python dataset/build_qa_dataset.py --advanced # Ultra-advanced reasoning (76.05% accuracy achieved)
uv run python dataset/build_qa_dataset.py --ultra-complex # Ultra-complex multi-step reasoning
uv run python dataset/build_qa_dataset.py --ultra-complex-small # Smaller version for testing

# NEW: Mathematical Reasoning (MATH & GSM8K style)
uv run python dataset/build_math_gsm8k_dataset.py # 10K training, 2K test examples across 14 math categories
```

## Experiments
@@ -171,7 +267,58 @@ arch.H_cycles=3 arch.L_cycles=4 \

*Runtime:* < 24 hours

## Reference
### Advanced Q&A Reasoning (assuming 1 GPU):

```bash
# Ultra-advanced reasoning tasks (achieved 76.05% accuracy)
run_name="pretrain_qa_advanced"
python pretrain.py \
arch=trm \
data_paths="[data/qa_pairs_advanced]" \
evaluators="[]" \
epochs=10000 eval_interval=1000 \
lr=1e-4 puzzle_emb_lr=1e-2 weight_decay=0.1 puzzle_emb_weight_decay=0.1 \
arch.L_layers=2 \
arch.H_cycles=2 arch.L_cycles=2 \
+run_name=${run_name}

# Ultra-complex reasoning tasks
run_name="pretrain_qa_ultra_complex"
python pretrain.py \
arch=trm \
data_paths="[data/qa_pairs_ultra_complex]" \
evaluators="[]" \
epochs=50000 eval_interval=5000 \
lr=1e-4 puzzle_emb_lr=1e-2 weight_decay=0.1 puzzle_emb_weight_decay=0.1 \
arch.L_layers=2 \
arch.H_cycles=3 arch.L_cycles=4 \
+run_name=${run_name}
```

*Runtime:* 2-12 hours
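
The two runs above differ mainly in `arch.H_cycles` and `arch.L_cycles`, which largely explains the spread in runtime. As a rough guide, and assuming `arch.H_cycles` corresponds to T and `arch.L_cycles` to n in the TRM paper's effective-depth expression T(n + 1) · n_layers, the amount of recursion per supervision step can be compared with the sketch below. The function name `effective_depth` is hypothetical.

```python
def effective_depth(h_cycles: int, l_cycles: int, l_layers: int) -> int:
    """Layers traversed per supervision step, assuming T * (n + 1) * n_layers."""
    return h_cycles * (l_cycles + 1) * l_layers


if __name__ == "__main__":
    # Settings copied from the two commands above (arch.L_layers=2 in both).
    runs = {
        "pretrain_qa_advanced": (2, 2),
        "pretrain_qa_ultra_complex": (3, 4),
    }
    for name, (h_cycles, l_cycles) in runs.items():
        print(name, effective_depth(h_cycles, l_cycles, l_layers=2))
```

Under that assumption the ultra-complex configuration does 2.5 times as much recursive computation per step (30 versus 12), on top of running five times as many epochs.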

### Mathematical Reasoning (MATH & GSM8K) (assuming 1 GPU):

```bash
# Comprehensive mathematical reasoning training
run_name="pretrain_math_gsm8k"
python pretrain.py --config-name cfg_math_pretrain

# Quick test version (10 epochs)
python pretrain.py --config-name cfg_math_test
```

*Runtime:* 4-120 hours (depending on configuration)

**Math Dataset Composition:**
- **Basic Arithmetic**: Addition, subtraction, multiplication, division word problems
- **Algebra**: Linear equations, systems of equations, quadratic equations
- **Geometry**: Circle area/volume, triangle area, rectangle perimeter, sphere volume
- **Calculus**: Derivatives, indefinite/definite integrals, limits, Taylor series
- **Advanced Topics**: Differential equations, complex analysis, residues
- **Statistics**: Mean, standard deviation, probability distributions
- **Number Theory**: GCD, prime checking, modular arithmetic, Euler's totient
- **Discrete Math**: Combinatorics, recurrence relations, graph theory

If you find our work useful, please consider citing:

47 changes: 47 additions & 0 deletions config/cfg_math_pretrain.yaml
@@ -0,0 +1,47 @@
# MATH GSM8K training config

defaults:
  - arch: trm
  - _self_

hydra:
  output_subdir: null

# Data path
data_paths: ['data/math_gsm8k_qa']
data_paths_test: []

evaluators: []

# Hyperparams - Training
global_batch_size: 4 # Reduced batch size

epochs: 100 # Medium training for better results
eval_interval: 25 # Evaluate every 25 epochs
checkpoint_every_eval: True

lr: 1e-4
lr_min_ratio: 1.0
lr_warmup_steps: 1000

# Standard hyperparameter settings for LM, as used in Llama
beta1: 0.9
beta2: 0.95
weight_decay: 0.1
puzzle_emb_weight_decay: 0.1

# Hyperparams - Puzzle embeddings training
puzzle_emb_lr: 1e-2

seed: 0
min_eval_interval: 0 # when to start the eval

ema: False # use Exponential-Moving-Average
ema_rate: 0.999 # EMA-rate
freeze_weights: False # If True, freeze weights and only learn the embeddings
use_wandb: false # Disable wandb for now

# Project settings
project_name: "TRM-Math-Reasoning"
entity: null
run_name: null
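
A note on the learning-rate keys in this config: the sketch below shows how `lr`, `lr_warmup_steps`, and `lr_min_ratio` would interact if the trainer uses linear warmup followed by cosine decay towards `lr * lr_min_ratio`. That schedule shape is an assumption, since the training code is not part of this diff, and `lr_at_step` is a hypothetical name. Under that assumption, `lr_min_ratio: 1.0` disables the decay entirely, so the rate stays at `lr` once warmup ends.

```python
import math


def lr_at_step(step: int, total_steps: int, base_lr: float = 1e-4,
               warmup_steps: int = 1000, min_ratio: float = 1.0) -> float:
    """Linear warmup to base_lr, then cosine decay towards base_lr * min_ratio."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return base_lr * (min_ratio + (1.0 - min_ratio) * cosine)


if __name__ == "__main__":
    # Defaults mirror cfg_math_pretrain.yaml; total_steps is illustrative only.
    for step in (0, 500, 1000, 5000, 10000):
        print(step, lr_at_step(step, total_steps=10000))
```

With `global_batch_size: 4`, a 1000-step warmup covers only 4000 training examples, so it is worth checking that the run is long enough for warmup to finish.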
46 changes: 46 additions & 0 deletions config/cfg_math_test.yaml
@@ -0,0 +1,46 @@
# Quick test config for MATH training

defaults:
  - arch: trm
  - _self_

hydra:
  output_subdir: null

# Data path
data_paths: ['data/math_gsm8k_qa']
data_paths_test: []

evaluators: []

# Hyperparams - Training (very short test)
global_batch_size: 2 # Even smaller batch size

epochs: 10 # Just 10 epochs for testing
eval_interval: 5
checkpoint_every_eval: True

lr: 1e-3 # Higher learning rate for faster testing
lr_min_ratio: 1.0
lr_warmup_steps: 1

# Standard hyperparameter settings
beta1: 0.9
beta2: 0.95
weight_decay: 0.1
puzzle_emb_weight_decay: 0.1

# Hyperparams - Puzzle embeddings training
puzzle_emb_lr: 1e-2

seed: 0
min_eval_interval: 0

ema: False
freeze_weights: False
use_wandb: false

# Project settings
project_name: "TRM-Math-Test"
entity: null
run_name: null