From 0014f6f7ddec3fdc99e35f2da9aed016cca04bff Mon Sep 17 00:00:00 2001 From: adamimos Date: Tue, 30 Sep 2025 14:54:21 -0700 Subject: [PATCH 1/5] configs: interpolate vocab size from generator in penzai transformer/gruformer YAMLs --- simplexity/configs/predictive_model/gruformer.yaml | 2 +- simplexity/configs/predictive_model/transformer.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/simplexity/configs/predictive_model/gruformer.yaml b/simplexity/configs/predictive_model/gruformer.yaml index b4a0dbf6..3f692bfd 100644 --- a/simplexity/configs/predictive_model/gruformer.yaml +++ b/simplexity/configs/predictive_model/gruformer.yaml @@ -9,7 +9,7 @@ instance: projection_dim: 16 mlp_hidden_dim: 16 num_decoder_blocks: 2 - vocab_size: 2 + vocab_size: ${training_data_generator.vocab_size} mlp_variant: geglu_approx tie_embedder_and_logits: false rope_wavelength: 10000 diff --git a/simplexity/configs/predictive_model/transformer.yaml b/simplexity/configs/predictive_model/transformer.yaml index 0bf13c1b..162f0e05 100644 --- a/simplexity/configs/predictive_model/transformer.yaml +++ b/simplexity/configs/predictive_model/transformer.yaml @@ -9,7 +9,7 @@ instance: projection_dim: 16 mlp_hidden_dim: 16 num_decoder_blocks: 2 - vocab_size: 2 + vocab_size: ${training_data_generator.vocab_size} mlp_variant: geglu_approx tie_embedder_and_logits: false rope_wavelength: 10000 From 49e1a61560e3c24b6a9a73f573afd212913f6660 Mon Sep 17 00:00:00 2001 From: adamimos Date: Tue, 30 Sep 2025 14:54:41 -0700 Subject: [PATCH 2/5] run: inject vocab from generator; add BOS/EOS and seq_len consistency checks; route PyTorch models to torch trainer with n_ctx alignment --- simplexity/run.py | 109 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 93 insertions(+), 16 deletions(-) diff --git a/simplexity/run.py b/simplexity/run.py index 12cdb3f2..d8ef063e 100644 --- a/simplexity/run.py +++ b/simplexity/run.py @@ -1,14 +1,15 @@ from contextlib import nullcontext import hydra -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from simplexity.configs.config import Config, validate_config from simplexity.generative_processes.generative_process import GenerativeProcess from simplexity.logging.logger import Logger from simplexity.persistence.model_persister import ModelPersister from simplexity.predictive_models.predictive_model import PredictiveModel -from simplexity.training.train_model import train +from simplexity.training.train_model import train as train_jax +from simplexity.training.train_pytorch_model import train as train_torch from simplexity.utils.hydra import typed_instantiate @@ -25,6 +26,7 @@ def train_model(cfg: Config) -> float: else: logger = None + # Instantiate data generators training_data_generator = typed_instantiate(cfg.training_data_generator.instance, GenerativeProcess) if cfg.validation_data_generator: @@ -36,7 +38,46 @@ def train_model(cfg: Config) -> float: validation_bos_token = None validation_eos_token = None - model = typed_instantiate(cfg.predictive_model.instance, PredictiveModel) + # Ensure model vocab_size matches data generator where applicable before instantiation + vocab_size = int(cfg.training_data_generator.vocab_size) + # Attempt to set common locations for vocab size (top-level or nested under config/cfg) + inst = cfg.predictive_model.instance + try: + if hasattr(inst, "vocab_size") or (isinstance(inst, DictConfig) and "vocab_size" in inst): + cfg.predictive_model.instance.vocab_size = vocab_size + # penzai-style nested config 
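+        # (builders that take a nested `config` mapping expose vocab_size there rather than at the top level)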
+ if isinstance(inst, DictConfig) and "config" in inst and "vocab_size" in inst["config"]: + cfg.predictive_model.instance.config.vocab_size = vocab_size + # transformer-lens-style nested cfg + if isinstance(inst, DictConfig) and "cfg" in inst and "d_vocab" in inst["cfg"]: + cfg.predictive_model.instance.cfg.d_vocab = vocab_size + except Exception: + # Keep going; we will still validate post-instantiation + pass + + # Consistency checks for sequence lengths and BOS/EOS + bos = cfg.training_data_generator.bos_token + eos = cfg.training_data_generator.eos_token + B = 1 if bos is not None else 0 + E = 1 if eos is not None else 0 + + # Training/validation sequence length alignment + if cfg.validation is not None: + assert ( + cfg.validation.sequence_len == cfg.training.sequence_len + ), "validation.sequence_len must match training.sequence_len for consistent context" + + # BOS/EOS validity + if bos is not None: + assert 0 <= bos < vocab_size, "bos_token must be within [0, vocab_size)" + if eos is not None: + assert 0 <= eos < vocab_size, "eos_token must be within [0, vocab_size)" + + # Effective model input length (for transparency and potential model checks) + effective_inputs_len = cfg.training.sequence_len + B + E - 1 + + # Instantiate model without enforcing a specific protocol, then route to the right trainer + model = hydra.utils.instantiate(cfg.predictive_model.instance) persister_context = ( typed_instantiate(cfg.persistence.instance, ModelPersister) if cfg.persistence else nullcontext() @@ -50,19 +91,55 @@ def train_model(cfg: Config) -> float: else: train_persister = None - _, loss = train( - model, - cfg.training, - training_data_generator, - logger, - cfg.validation, - validation_data_generator, - train_persister, - training_bos_token=cfg.training_data_generator.bos_token, - training_eos_token=cfg.training_data_generator.eos_token, - validation_bos_token=validation_bos_token, - validation_eos_token=validation_eos_token, - ) + # Choose trainer based on model type + try: + import torch.nn as nn # defer import to avoid requiring torch for JAX runs + except Exception: + nn = None # type: ignore + + if nn is not None and isinstance(model, nn.Module): + # If TransformerLens-style config exists, assert n_ctx alignment when set + try: + n_ctx = None + inst = cfg.predictive_model.instance + if isinstance(inst, DictConfig) and "cfg" in inst and "n_ctx" in inst["cfg"]: + n_ctx = int(inst["cfg"]["n_ctx"]) # type: ignore + if n_ctx is not None: + assert ( + n_ctx == effective_inputs_len + ), f"predictive_model.cfg.n_ctx ({n_ctx}) must equal effective inputs length ({effective_inputs_len}) computed from sequence_len and BOS/EOS" + except Exception: + # Be permissive; downstream will surface shape mismatches + pass + + _, loss = train_torch( + model, + cfg.training, + training_data_generator, + logger, + cfg.validation, + validation_data_generator, + train_persister, + training_bos_token=cfg.training_data_generator.bos_token, + training_eos_token=cfg.training_data_generator.eos_token, + validation_bos_token=validation_bos_token, + validation_eos_token=validation_eos_token, + ) + else: + # Default JAX/Penzai/Equinox path + _, loss = train_jax( + model, # type: ignore[arg-type] + cfg.training, + training_data_generator, + logger, + cfg.validation, + validation_data_generator, + train_persister, + training_bos_token=cfg.training_data_generator.bos_token, + training_eos_token=cfg.training_data_generator.eos_token, + validation_bos_token=validation_bos_token, + 
validation_eos_token=validation_eos_token, + ) if logger: logger.close() From 10210a5e697c58ae2465971006c4c22cea573da4 Mon Sep 17 00:00:00 2001 From: adamimos Date: Tue, 30 Sep 2025 14:54:53 -0700 Subject: [PATCH 3/5] torch: move batch to model device; ensure token dtype long in training and evaluation --- simplexity/evaluation/evaluate_pytorch_model.py | 11 +++++++++++ simplexity/training/train_pytorch_model.py | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/simplexity/evaluation/evaluate_pytorch_model.py b/simplexity/evaluation/evaluate_pytorch_model.py index 6c2da15a..8ed37e69 100644 --- a/simplexity/evaluation/evaluate_pytorch_model.py +++ b/simplexity/evaluation/evaluate_pytorch_model.py @@ -25,7 +25,18 @@ def evaluation_step( ) -> dict[str, jax.Array]: """Compute evaluation metrics for a PyTorch model.""" model.eval() + try: + device = next(model.parameters()).device + except Exception: + device = None with torch.no_grad(): + # Ensure data on model device and dtype for token indices + if device is not None: + inputs = inputs.to(device) + labels = labels.to(device) + if inputs.dtype != torch.long: + inputs = inputs.long() + logits: torch.Tensor = model(inputs) metrics = {} diff --git a/simplexity/training/train_pytorch_model.py b/simplexity/training/train_pytorch_model.py index 98bab9d5..f404f62c 100644 --- a/simplexity/training/train_pytorch_model.py +++ b/simplexity/training/train_pytorch_model.py @@ -55,6 +55,17 @@ def train( model.train() optimizer.zero_grad() + # Move to model device and ensure integer token dtype for index-based models + try: + device = next(model.parameters()).device + inputs = inputs.to(device) + labels = labels.to(device) + except Exception: + pass + + if inputs.dtype != torch.long: + inputs = inputs.long() + logits = model(inputs) vocab_size = logits.shape[2] From 5a75f3bf33d3a1ed10611a0d7c5cadfa521baf67 Mon Sep 17 00:00:00 2001 From: adamimos Date: Tue, 30 Sep 2025 14:55:20 -0700 Subject: [PATCH 4/5] configs: add TransformerLens example (predictive_model, training, evaluation) and mess3_085 generator; include convenience top-level TL config --- .../configs/evaluation/transformerlens.yaml | 6 ++++++ .../configs/generative_process/mess3_085.yaml | 10 ++++++++++ .../transformer_lens_2L2H.yaml | 20 +++++++++++++++++++ .../configs/training/transformerlens.yaml | 12 +++++++++++ simplexity/configs/transformerlens_mess3.yaml | 14 +++++++++++++ 5 files changed, 62 insertions(+) create mode 100644 simplexity/configs/evaluation/transformerlens.yaml create mode 100644 simplexity/configs/generative_process/mess3_085.yaml create mode 100644 simplexity/configs/predictive_model/transformer_lens_2L2H.yaml create mode 100644 simplexity/configs/training/transformerlens.yaml create mode 100644 simplexity/configs/transformerlens_mess3.yaml diff --git a/simplexity/configs/evaluation/transformerlens.yaml b/simplexity/configs/evaluation/transformerlens.yaml new file mode 100644 index 00000000..43e08840 --- /dev/null +++ b/simplexity/configs/evaluation/transformerlens.yaml @@ -0,0 +1,6 @@ +seed: ${seed} +sequence_len: 6 +batch_size: 4 +num_steps: 10 +log_every: + diff --git a/simplexity/configs/generative_process/mess3_085.yaml b/simplexity/configs/generative_process/mess3_085.yaml new file mode 100644 index 00000000..962befde --- /dev/null +++ b/simplexity/configs/generative_process/mess3_085.yaml @@ -0,0 +1,10 @@ +name: mess3_085 +vocab_size: 4 +instance: + _target_: simplexity.generative_processes.builder.build_hidden_markov_model + process_name: mess3 + x: 0.05 
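+  # mess3 emits 3 symbols (0-2); vocab_size above is 4 so that token 3 is free to act as the BOS token below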
+ a: 0.85 +bos_token: 3 +eos_token: + diff --git a/simplexity/configs/predictive_model/transformer_lens_2L2H.yaml b/simplexity/configs/predictive_model/transformer_lens_2L2H.yaml new file mode 100644 index 00000000..b0dbdacf --- /dev/null +++ b/simplexity/configs/predictive_model/transformer_lens_2L2H.yaml @@ -0,0 +1,20 @@ +name: transformer_lens_2L2H +instance: + _target_: transformer_lens.HookedTransformer + cfg: + _target_: transformer_lens.HookedTransformerConfig + n_layers: 2 + n_heads: 2 + d_model: 64 + d_head: 32 + n_ctx: 6 + d_vocab: ${training_data_generator.vocab_size} + act_fn: "gelu" + normalization_type: "LN" + positional_embedding_type: "standard" + attn_only: false + seed: ${seed} + device: cpu + +load_checkpoint_step: + diff --git a/simplexity/configs/training/transformerlens.yaml b/simplexity/configs/training/transformerlens.yaml new file mode 100644 index 00000000..f2834d49 --- /dev/null +++ b/simplexity/configs/training/transformerlens.yaml @@ -0,0 +1,12 @@ +defaults: + - _self_ + - optimizer: pytorch_adam + +seed: ${seed} +sequence_len: 6 # With BOS: generator emits 6 + BOS -> tokens 7; inputs length stays 6 +batch_size: 32 +num_steps: 100 +log_every: 10 +validate_every: 20 +checkpoint_every: 50 + diff --git a/simplexity/configs/transformerlens_mess3.yaml b/simplexity/configs/transformerlens_mess3.yaml new file mode 100644 index 00000000..12d972e5 --- /dev/null +++ b/simplexity/configs/transformerlens_mess3.yaml @@ -0,0 +1,14 @@ +defaults: + - _self_ + - generative_process@training_data_generator: mess3_085 + - generative_process@validation_data_generator: mess3_085 + - predictive_model: transformer_lens_2L2H + - persistence: s3_persister + - logging: mlflow_logger + - training: transformerlens + - evaluation@validation: transformerlens + +seed: 42 +experiment_name: transformerlens_mess3 +run_name: ${now:%Y-%m-%d_%H-%M-%S}_${experiment_name}_${seed} + From ed6507ac485fdd4f322202e895c418b837183965 Mon Sep 17 00:00:00 2001 From: adamimos Date: Tue, 30 Sep 2025 14:56:05 -0700 Subject: [PATCH 5/5] docs: add TransformerLens example usage and notes --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 3dd52fd8..f5b1e852 100644 --- a/README.md +++ b/README.md @@ -77,3 +77,18 @@ prefix = your_s3_prefix ``` [AWS configuration and credential files](https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-files.html) can be used for authentication and settings. Authentication credentials should be specified in `~/.aws/credentials`. Settings like `region`, `output`, `endpoint_url` should be specified in `~/.aws/config`. Multiple different profiles can be defined and the specific profile to use can be specified in the `aws` section of the `.ini` file. + +### TransformerLens example (optional) + +We include a minimal TransformerLens setup mirroring the "basic_mess3" example from simplex-research. It aligns model context with the data pipeline and derives vocab from the generator. + +Run with the convenience config: + +```bash +uv run python -m simplexity.run --config-name transformerlens_mess3 +``` + +Notes: +- `d_vocab` is interpolated from `${training_data_generator.vocab_size}`. +- `n_ctx` must match the effective inputs length (checked at runtime) implied by `training.sequence_len` and BOS/EOS. +- The PyTorch trainer automatically moves batches to the model device and uses `long` token dtype.
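+
+For example, to train with a longer context (hypothetical values; any length works as long as the three settings agree), override the sequence lengths and `n_ctx` together. With a BOS token and no EOS token, the effective inputs length equals `training.sequence_len` (sequence_len + 1 for BOS - 1 for the input/label shift):
+
+```bash
+uv run python -m simplexity.run --config-name transformerlens_mess3 \
+  training.sequence_len=10 \
+  validation.sequence_len=10 \
+  predictive_model.instance.cfg.n_ctx=10
+```
+
+`d_vocab` needs no separate override because it is interpolated from the generator's `vocab_size`.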