15 changes: 15 additions & 0 deletions README.md
@@ -77,3 +77,18 @@ prefix = your_s3_prefix
```

[AWS configuration and credential files](https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-files.html) can be used for authentication and settings. Authentication credentials should be specified in `~/.aws/credentials`, and settings such as `region`, `output`, and `endpoint_url` should be specified in `~/.aws/config`. Multiple profiles can be defined, and the profile to use can be selected in the `aws` section of the `.ini` file.
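
For example, a hypothetical profile (all names and values below are placeholders) might look like:

```ini
# ~/.aws/credentials
[simplexity]
aws_access_key_id = AKIAXXXXXXXXXXXXXXXX
aws_secret_access_key = your_secret_access_key

# ~/.aws/config
[profile simplexity]
region = us-east-1
output = json
endpoint_url = https://s3.example.com
```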

### TransformerLens example (optional)

We include a minimal TransformerLens setup mirroring the "basic_mess3" example from simplex-research. The config aligns the model's context length with the data pipeline and derives its vocabulary size from the data generator.

Run with the convenience config:

```bash
uv run python -m simplexity.run --config-name transformerlens_mess3
```

Notes:
- `d_vocab` is interpolated from `${training_data_generator.vocab_size}`.
- `n_ctx` must equal the effective input length implied by `training.sequence_len` and the BOS/EOS tokens; this is checked at runtime (see the sketch below).
- The PyTorch trainer automatically moves batches to the model device and casts token inputs to the `long` dtype.
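
The relationship that this runtime check enforces can be sketched as follows; this is a minimal illustration using the values from the `transformerlens_mess3` config, and the helper function is hypothetical rather than part of the package:

```python
def effective_inputs_len(sequence_len: int, bos_token: int | None, eos_token: int | None) -> int:
    """Length of the token sequence fed to the model, mirroring the check in `simplexity/run.py`."""
    bos = 1 if bos_token is not None else 0
    eos = 1 if eos_token is not None else 0
    # The generator emits sequence_len tokens plus optional BOS/EOS tokens;
    # inputs and labels are offset by one, so the model input is one token shorter.
    return sequence_len + bos + eos - 1


# mess3_085 sets bos_token=3 and no EOS, and training.sequence_len is 6,
# so n_ctx in transformer_lens_2L2H.yaml must be 6.
assert effective_inputs_len(6, bos_token=3, eos_token=None) == 6
```
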
6 changes: 6 additions & 0 deletions simplexity/configs/evaluation/transformerlens.yaml
@@ -0,0 +1,6 @@
seed: ${seed}
sequence_len: 6
batch_size: 4
num_steps: 10
log_every:

10 changes: 10 additions & 0 deletions simplexity/configs/generative_process/mess3_085.yaml
@@ -0,0 +1,10 @@
name: mess3_085
vocab_size: 4
instance:
  _target_: simplexity.generative_processes.builder.build_hidden_markov_model
  process_name: mess3
  x: 0.05
  a: 0.85
bos_token: 3
eos_token:

2 changes: 1 addition & 1 deletion simplexity/configs/predictive_model/gruformer.yaml
@@ -9,7 +9,7 @@ instance:
  projection_dim: 16
  mlp_hidden_dim: 16
  num_decoder_blocks: 2
  vocab_size: 2
  vocab_size: ${training_data_generator.vocab_size}
  mlp_variant: geglu_approx
  tie_embedder_and_logits: false
  rope_wavelength: 10000
2 changes: 1 addition & 1 deletion simplexity/configs/predictive_model/transformer.yaml
@@ -9,7 +9,7 @@ instance:
  projection_dim: 16
  mlp_hidden_dim: 16
  num_decoder_blocks: 2
  vocab_size: 2
  vocab_size: ${training_data_generator.vocab_size}
  mlp_variant: geglu_approx
  tie_embedder_and_logits: false
  rope_wavelength: 10000
20 changes: 20 additions & 0 deletions simplexity/configs/predictive_model/transformer_lens_2L2H.yaml
@@ -0,0 +1,20 @@
name: transformer_lens_2L2H
instance:
  _target_: transformer_lens.HookedTransformer
  cfg:
    _target_: transformer_lens.HookedTransformerConfig
    n_layers: 2
    n_heads: 2
    d_model: 64
    d_head: 32
    n_ctx: 6
    d_vocab: ${training_data_generator.vocab_size}
    act_fn: "gelu"
    normalization_type: "LN"
    positional_embedding_type: "standard"
    attn_only: false
    seed: ${seed}
    device: cpu

load_checkpoint_step:

12 changes: 12 additions & 0 deletions simplexity/configs/training/transformerlens.yaml
@@ -0,0 +1,12 @@
defaults:
- _self_
- optimizer: pytorch_adam

seed: ${seed}
sequence_len: 6 # With BOS: generator emits 6 + BOS -> tokens 7; inputs length stays 6
batch_size: 32
num_steps: 100
log_every: 10
validate_every: 20
checkpoint_every: 50

14 changes: 14 additions & 0 deletions simplexity/configs/transformerlens_mess3.yaml
@@ -0,0 +1,14 @@
defaults:
- _self_
- generative_process@training_data_generator: mess3_085
- generative_process@validation_data_generator: mess3_085
- predictive_model: transformer_lens_2L2H
- persistence: s3_persister
- logging: mlflow_logger
- training: transformerlens
- evaluation@validation: transformerlens

seed: 42
experiment_name: transformerlens_mess3
run_name: ${now:%Y-%m-%d_%H-%M-%S}_${experiment_name}_${seed}

11 changes: 11 additions & 0 deletions simplexity/evaluation/evaluate_pytorch_model.py
@@ -25,7 +25,18 @@ def evaluation_step(
) -> dict[str, jax.Array]:
    """Compute evaluation metrics for a PyTorch model."""
    model.eval()
    try:
        device = next(model.parameters()).device
    except Exception:
        device = None
    with torch.no_grad():
        # Ensure data on model device and dtype for token indices
        if device is not None:
            inputs = inputs.to(device)
            labels = labels.to(device)
        if inputs.dtype != torch.long:
            inputs = inputs.long()

        logits: torch.Tensor = model(inputs)
        metrics = {}

109 changes: 93 additions & 16 deletions simplexity/run.py
@@ -1,14 +1,15 @@
from contextlib import nullcontext

import hydra
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf

Check failure on line 4 in simplexity/run.py (GitHub Actions / static-analysis): Ruff F401, `omegaconf.OmegaConf` imported but unused

from simplexity.configs.config import Config, validate_config
from simplexity.generative_processes.generative_process import GenerativeProcess
from simplexity.logging.logger import Logger
from simplexity.persistence.model_persister import ModelPersister
from simplexity.predictive_models.predictive_model import PredictiveModel

Check failure on line 10 in simplexity/run.py (GitHub Actions / static-analysis): Ruff F401, `simplexity.predictive_models.predictive_model.PredictiveModel` imported but unused
from simplexity.training.train_model import train
from simplexity.training.train_model import train as train_jax
from simplexity.training.train_pytorch_model import train as train_torch
from simplexity.utils.hydra import typed_instantiate


@@ -25,6 +26,7 @@
    else:
        logger = None

    # Instantiate data generators
    training_data_generator = typed_instantiate(cfg.training_data_generator.instance, GenerativeProcess)

    if cfg.validation_data_generator:
@@ -36,7 +38,46 @@
        validation_bos_token = None
        validation_eos_token = None

    model = typed_instantiate(cfg.predictive_model.instance, PredictiveModel)
    # Ensure model vocab_size matches data generator where applicable before instantiation
    vocab_size = int(cfg.training_data_generator.vocab_size)
    # Attempt to set common locations for vocab size (top-level or nested under config/cfg)
    inst = cfg.predictive_model.instance
    try:
        if hasattr(inst, "vocab_size") or (isinstance(inst, DictConfig) and "vocab_size" in inst):
            cfg.predictive_model.instance.vocab_size = vocab_size
        # penzai-style nested config
        if isinstance(inst, DictConfig) and "config" in inst and "vocab_size" in inst["config"]:
            cfg.predictive_model.instance.config.vocab_size = vocab_size
        # transformer-lens-style nested cfg
        if isinstance(inst, DictConfig) and "cfg" in inst and "d_vocab" in inst["cfg"]:
            cfg.predictive_model.instance.cfg.d_vocab = vocab_size
    except Exception:
        # Keep going; we will still validate post-instantiation
        pass

    # Consistency checks for sequence lengths and BOS/EOS
    bos = cfg.training_data_generator.bos_token
    eos = cfg.training_data_generator.eos_token
    B = 1 if bos is not None else 0
    E = 1 if eos is not None else 0

    # Training/validation sequence length alignment
    if cfg.validation is not None:
        assert (
            cfg.validation.sequence_len == cfg.training.sequence_len
        ), "validation.sequence_len must match training.sequence_len for consistent context"

    # BOS/EOS validity
    if bos is not None:
        assert 0 <= bos < vocab_size, "bos_token must be within [0, vocab_size)"
    if eos is not None:
        assert 0 <= eos < vocab_size, "eos_token must be within [0, vocab_size)"

    # Effective model input length (for transparency and potential model checks)
    effective_inputs_len = cfg.training.sequence_len + B + E - 1

    # Instantiate model without enforcing a specific protocol, then route to the right trainer
    model = hydra.utils.instantiate(cfg.predictive_model.instance)

    persister_context = (
        typed_instantiate(cfg.persistence.instance, ModelPersister) if cfg.persistence else nullcontext()
@@ -50,19 +91,55 @@
        else:
            train_persister = None

        _, loss = train(
            model,
            cfg.training,
            training_data_generator,
            logger,
            cfg.validation,
            validation_data_generator,
            train_persister,
            training_bos_token=cfg.training_data_generator.bos_token,
            training_eos_token=cfg.training_data_generator.eos_token,
            validation_bos_token=validation_bos_token,
            validation_eos_token=validation_eos_token,
        )
        # Choose trainer based on model type
        try:
            import torch.nn as nn  # defer import to avoid requiring torch for JAX runs
        except Exception:
            nn = None  # type: ignore

        if nn is not None and isinstance(model, nn.Module):
            # If TransformerLens-style config exists, assert n_ctx alignment when set
            try:
                n_ctx = None
                inst = cfg.predictive_model.instance
                if isinstance(inst, DictConfig) and "cfg" in inst and "n_ctx" in inst["cfg"]:
                    n_ctx = int(inst["cfg"]["n_ctx"])  # type: ignore
                if n_ctx is not None:
                    assert (
                        n_ctx == effective_inputs_len
                    ), f"predictive_model.cfg.n_ctx ({n_ctx}) must equal effective inputs length ({effective_inputs_len}) computed from sequence_len and BOS/EOS"

Check failure on line 110 in simplexity/run.py (GitHub Actions / static-analysis): Ruff E501, line too long (161 > 120)
            except Exception:
                # Be permissive; downstream will surface shape mismatches
                pass

            _, loss = train_torch(
                model,
                cfg.training,
                training_data_generator,
                logger,
                cfg.validation,
                validation_data_generator,
                train_persister,
                training_bos_token=cfg.training_data_generator.bos_token,
                training_eos_token=cfg.training_data_generator.eos_token,
                validation_bos_token=validation_bos_token,
                validation_eos_token=validation_eos_token,
            )
        else:
            # Default JAX/Penzai/Equinox path
            _, loss = train_jax(
                model,  # type: ignore[arg-type]
                cfg.training,
                training_data_generator,
                logger,
                cfg.validation,
                validation_data_generator,
                train_persister,
                training_bos_token=cfg.training_data_generator.bos_token,
                training_eos_token=cfg.training_data_generator.eos_token,
                validation_bos_token=validation_bos_token,
                validation_eos_token=validation_eos_token,
            )

    if logger:
        logger.close()
11 changes: 11 additions & 0 deletions simplexity/training/train_pytorch_model.py
@@ -55,6 +55,17 @@ def train(
        model.train()
        optimizer.zero_grad()

        # Move to model device and ensure integer token dtype for index-based models
        try:
            device = next(model.parameters()).device
            inputs = inputs.to(device)
            labels = labels.to(device)
        except Exception:
            pass

        if inputs.dtype != torch.long:
            inputs = inputs.long()

        logits = model(inputs)

        vocab_size = logits.shape[2]