7 changes: 5 additions & 2 deletions .gitignore
@@ -6,5 +6,8 @@ docs/build
temp
.coverage
*.ipynb_checkpoints
*/.cache
*/lightning_logs
.cache
lightning_logs
wandb
outputs
checkpoints
2 changes: 2 additions & 0 deletions cortex/cmdline/train_cortex_model.py
@@ -77,6 +77,8 @@ def execute(cfg):
model = hydra.utils.instantiate(cfg.tree)
model.build_tree(cfg, skip_task_setup=False)

model = torch.compile(model, backend="inductor", mode="reduce-overhead")

trainer.fit(
model,
train_dataloaders=model.get_dataloader(split="train"),
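The new line wraps the assembled tree with torch.compile (inductor backend, "reduce-overhead" mode) before trainer.fit. A minimal sketch of the same pattern on a toy LightningModule, just to show where the call sits relative to fit; the module and data below are placeholders, not cortex classes:

import torch
import torch.nn.functional as F
import lightning.pytorch as pl
from torch.utils.data import DataLoader, TensorDataset

class ToyRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)

model = ToyRegressor()
# Compile before handing the module to the Trainer, mirroring execute() above.
model = torch.compile(model, backend="inductor", mode="reduce-overhead")

loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
trainer = pl.Trainer(max_epochs=1, accelerator="cpu", logger=False, enable_checkpointing=False)
trainer.fit(model, train_dataloaders=loader)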
3 changes: 3 additions & 0 deletions cortex/config/__init__.py
@@ -0,0 +1,3 @@
from .neural_tree_config import NeuralTreeConfig, RootConfig

__all__ = ["NeuralTreeConfig", "RootConfig"]
9 changes: 9 additions & 0 deletions cortex/config/hydra/roots/huggingface_protein.yaml
@@ -0,0 +1,9 @@
_target_: cortex.model.root.HuggingFaceRoot
model_name_or_path: Rostlab/prot_bert_bfd
model_max_length: 512
freeze: false
pooling_type: mean
# HuggingFace model config overrides
model_config:
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
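Hydra resolves this file into a HuggingFaceRoot instance via the _target_ key. A rough sketch of that resolution, assuming (as the YAML implies) that HuggingFaceRoot accepts these keys as keyword arguments; the first call will download Rostlab/prot_bert_bfd:

from hydra.utils import instantiate
from omegaconf import OmegaConf

root_cfg = OmegaConf.create(
    {
        "_target_": "cortex.model.root.HuggingFaceRoot",
        "model_name_or_path": "Rostlab/prot_bert_bfd",
        "model_max_length": 512,
        "freeze": False,
        "pooling_type": "mean",
        "model_config": {"hidden_dropout_prob": 0.1, "attention_probs_dropout_prob": 0.1},
    }
)
root = instantiate(root_cfg)  # builds the root exactly as the training script would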
@@ -0,0 +1,26 @@
log_fluorescence_hf:
_target_: cortex.task.RegressionTask
# Single root key mapping for HF tokenized inputs
input_map:
protein_seq: [] # Empty list since we'll use tokenized inputs directly
outcome_cols: ["label"]
root_key: protein_seq
corrupt_train_inputs: false
corrupt_inference_inputs: false
nominal_label_var: 0.01

data_module:
_target_: cortex.data.data_module.HFTaskDataModule
dataset_config:
_target_: datasets.load_dataset
path: "proteinglm/fluorescence_prediction"
name: default
trust_remote_code: true
batch_size: ${fit.batch_size}
num_workers: ${num_workers}
drop_last: true
text_field: "seq"
label_field: "label"
add_spaces_between_chars: true
tokenization_batch_size: 1000
tokenization_num_proc: 4
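The data module delegates to datasets.load_dataset with the arguments under dataset_config. A quick way to inspect what that call returns before wiring it into training (assumes network access and that the dataset's columns match the text_field/label_field names above):

from datasets import load_dataset

ds = load_dataset("proteinglm/fluorescence_prediction", name="default", trust_remote_code=True)
print(ds)  # DatasetDict; split names depend on the dataset
first = next(iter(ds.values()))[0]
print(first["seq"][:60], first["label"])  # fields referenced by text_field / label_field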
89 changes: 89 additions & 0 deletions cortex/config/hydra/train_hf_protein_model.yaml
@@ -0,0 +1,89 @@
defaults:
- _self_
- logging: default
- tasks:
- protein_property/log_fluorescence_hf

# HuggingFace protein model training configuration
job_name: hf_protein_model

# Tree configuration
tree:
_target_: cortex.model.tree.NeuralTreeLightningV2
root_nodes:
_target_: torch.nn.ModuleDict
trunk_node: null
branch_nodes:
_target_: torch.nn.ModuleDict
leaf_nodes:
_target_: torch.nn.ModuleDict
seed: 42
num_workers: 4 # Reduced for Mac
download_datasets: true
dataset_root_dir: ${hydra:runtime.cwd}/data
data_dir: ${hydra:runtime.cwd}/data
save_ckpt: true
ckpt_file: model.ckpt
ckpt_cfg: model.yaml
warnings_filter: default

# Wandb config
wandb_mode: offline

# Model configuration - using tiny BERT for Mac
roots:
protein_seq:
_target_: cortex.model.root.HuggingFaceRoot
model_name_or_path: prajjwal1/bert-tiny # 4.4M params instead of 420M
# model_name_or_path: facebook/esm2_t30_150M_UR50D
pooling_strategy: none # Keep sequence dimension for Conv1dBranch
freeze_pretrained: false
# cropped_max_len: 256

trunk:
_target_: cortex.model.trunk.SumTrunk
out_dim: 128 # Smaller for tiny model

branches:
protein_property:
_target_: cortex.model.branch.Conv1dBranch
out_dim: 64 # Smaller output dim for the tiny backbone
num_blocks: 1 # Fewer blocks
kernel_size: 5
dilation_base: 1
channel_dim: 128 # Smaller channel dim
channel_dim_mult: 1
dropout_p: 0.1

# Task configuration using HuggingFace datasets
tasks:
protein_property:
log_fluorescence_hf:
ensemble_size: 1

# Training configuration
fit:
batch_size: 128 # Smaller batch size for Mac
optimizer:
_target_: torch.optim.AdamW
lr: 1e-4 # Higher LR for faster convergence in demo
weight_decay: 0.0
lr_scheduler:
_target_: torch.optim.lr_scheduler.CosineAnnealingLR
T_max: 10 # Shorter schedule
eta_min: 1e-5

trainer:
_target_: lightning.pytorch.Trainer
max_epochs: 64
accelerator: gpu
devices: 1
precision: 16-mixed # Mixed precision
# gradient_clip_val: 1.0 # Not supported with manual optimization
accumulate_grad_batches: 1
log_every_n_steps: 10
val_check_interval: 1.0
enable_checkpointing: true
enable_progress_bar: true
enable_model_summary: true
num_sanity_val_steps: 0
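To check how the interpolations (${fit.batch_size}, ${num_workers}) and the defaults list resolve, the file can be composed outside a full training run. A sketch using Hydra's compose API; the config directory path is an assumption about your checkout layout:

from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf

config_dir = "/abs/path/to/cortex/config/hydra"  # assumption: point at the repo's hydra config dir
with initialize_config_dir(config_dir=config_dir, version_base=None):
    cfg = compose(config_name="train_hf_protein_model")

# Note: ${hydra:runtime.cwd} only resolves inside a real Hydra run;
# OmegaConf.to_yaml does not force resolution by default, so this still prints.
print(OmegaConf.to_yaml(cfg.roots.protein_seq))
print(cfg.fit.batch_size, cfg.tasks.protein_property.log_fluorescence_hf.ensemble_size)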
4 changes: 4 additions & 0 deletions cortex/config/hydra/tree/neural_tree_lightning_v2.yaml
@@ -0,0 +1,4 @@
_target_: cortex.model.tree.NeuralTreeLightningV2
model_type: seq2float
trunk:
_target_: cortex.model.trunk.SumTrunk
206 changes: 206 additions & 0 deletions cortex/config/neural_tree_config.py
@@ -0,0 +1,206 @@
"""HuggingFace-compatible configuration for NeuralTree models."""

from dataclasses import dataclass
from typing import Any, Dict, Optional

from transformers import PretrainedConfig


@dataclass
class RootConfig:
"""Configuration for a single root node in the neural tree."""

# Dual mode: HuggingFace or custom
use_hf_model: bool = False
hf_config: Optional[Dict[str, Any]] = None
cortex_config: Optional[Dict[str, Any]] = None
processor_name: Optional[str] = None

def __post_init__(self):
if self.use_hf_model and self.hf_config is None:
raise ValueError("hf_config must be provided when use_hf_model=True")
if not self.use_hf_model and self.cortex_config is None:
raise ValueError("cortex_config must be provided when use_hf_model=False")


class NeuralTreeConfig(PretrainedConfig):
"""
Configuration class for NeuralTree models that preserves Hydra composition
while enabling HuggingFace ecosystem integration.

This configuration supports both traditional cortex components and modern
HuggingFace pretrained models within the same neural tree architecture.
"""

model_type = "neural_tree"

def __init__(
self,
roots: Optional[Dict[str, Any]] = None,
trunk: Optional[Dict[str, Any]] = None,
branches: Optional[Dict[str, Dict[str, Any]]] = None,
tasks: Optional[Dict[str, Dict[str, Any]]] = None,
processors: Optional[Dict[str, str]] = None,
optimizer_config: Optional[Dict[str, Any]] = None,
lr_scheduler_config: Optional[Dict[str, Any]] = None,
ensemble_size: int = 1,
channel_dim: int = 64,
dropout_prob: float = 0.0,
enable_torch_compile: bool = False,
compile_mode: str = "default",
**kwargs,
):
super().__init__(**kwargs)

# Core tree architecture (preserved from existing cortex)
self.roots = roots or {}
self.trunk = trunk or {}
self.branches = branches or {}
self.tasks = tasks or {}

# New: Transform and processor registry for dataloader execution
self.processors = processors or {} # root_name -> processor_name

# Training configuration (migrated from fit_cfg)
self.optimizer_config = optimizer_config or {}
self.lr_scheduler_config = lr_scheduler_config or {}

# Model global settings
self.ensemble_size = ensemble_size
self.channel_dim = channel_dim
self.dropout_prob = dropout_prob

# Compilation and performance settings
self.enable_torch_compile = enable_torch_compile
self.compile_mode = compile_mode # "default", "reduce-overhead", "max-autotune"

# Convert root configs to RootConfig objects if they're dicts
if self.roots:
for root_name, root_cfg in self.roots.items():
if isinstance(root_cfg, dict):
self.roots[root_name] = RootConfig(**root_cfg)

def to_dict(self):
"""Convert to dictionary for JSON serialization."""
output = super().to_dict()

# Convert RootConfig objects to dicts
if hasattr(self, "roots") and self.roots:
roots_dict = {}
for root_name, root_config in self.roots.items():
if isinstance(root_config, RootConfig):
roots_dict[root_name] = {
"use_hf_model": root_config.use_hf_model,
"hf_config": root_config.hf_config,
"cortex_config": root_config.cortex_config,
"processor_name": root_config.processor_name,
}
else:
roots_dict[root_name] = root_config
output["roots"] = roots_dict

return output

@classmethod
def from_dict(cls, config_dict, **kwargs):
"""Create from dictionary (used during deserialization)."""
# Convert root configs back to RootConfig objects
if "roots" in config_dict:
roots = {}
for root_name, root_data in config_dict["roots"].items():
if isinstance(root_data, dict) and "use_hf_model" in root_data:
roots[root_name] = RootConfig(**root_data)
else:
roots[root_name] = root_data
config_dict["roots"] = roots

return super().from_dict(config_dict, **kwargs)

def add_root(self, name: str, root_config: RootConfig):
"""Add a root node configuration."""
self.roots[name] = root_config

def add_hf_root(self, name: str, model_name_or_path: str, processor_name: Optional[str] = None):
"""Convenience method to add a HuggingFace pretrained root."""
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained(model_name_or_path)
root_config = RootConfig(
use_hf_model=True, hf_config=hf_config.to_dict(), processor_name=processor_name or model_name_or_path
)
self.add_root(name, root_config)

# Register processor for dataloader execution
if processor_name:
self.processors[name] = processor_name

def add_cortex_root(self, name: str, cortex_config: Dict[str, Any], processor_name: Optional[str] = None):
"""Convenience method to add a traditional cortex root."""
root_config = RootConfig(use_hf_model=False, cortex_config=cortex_config, processor_name=processor_name)
self.add_root(name, root_config)

# Register processor if provided
if processor_name:
self.processors[name] = processor_name

def to_hydra_config(self) -> Dict[str, Any]:
"""
Convert back to Hydra-style configuration for backwards compatibility.
This allows existing training scripts to work with minimal changes.
"""
hydra_config = {
"roots": {},
"trunk": self.trunk,
"branches": self.branches,
"tasks": self.tasks,
"ensemble_size": self.ensemble_size,
"channel_dim": self.channel_dim,
"dropout_prob": self.dropout_prob,
}

# Convert root configs back to Hydra format
for root_name, root_config in self.roots.items():
if root_config.use_hf_model:
# For HF models, build a wrapper config targeting HuggingFaceRoot
hydra_config["roots"][root_name] = {
"_target_": "cortex.model.root.HuggingFaceRoot",
"model_name_or_path": root_config.processor_name,
"config": root_config.hf_config,
}
else:
# Use existing cortex config directly
hydra_config["roots"][root_name] = root_config.cortex_config

return hydra_config

@classmethod
def from_hydra_config(cls, hydra_config: Dict[str, Any]) -> "NeuralTreeConfig":
"""
Create NeuralTreeConfig from existing Hydra configuration.
This enables migration from existing configs.
"""
config = cls()

# Extract core tree components
config.trunk = hydra_config.get("trunk", {})
config.branches = hydra_config.get("branches", {})
config.tasks = hydra_config.get("tasks", {})

# Extract global settings
config.ensemble_size = hydra_config.get("ensemble_size", 1)
config.channel_dim = hydra_config.get("channel_dim", 64)
config.dropout_prob = hydra_config.get("dropout_prob", 0.0)

# Convert root configurations
for root_name, root_cfg in hydra_config.get("roots", {}).items():
if isinstance(root_cfg, dict):
# Detect if this is a HF model based on target or model_name_or_path
target = root_cfg.get("_target_", "")
if "HuggingFace" in target or "model_name_or_path" in root_cfg:
config.add_hf_root(
root_name, root_cfg.get("model_name_or_path", ""), root_cfg.get("processor_name")
)
else:
config.add_cortex_root(root_name, root_cfg)

return config
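A short usage sketch based only on the methods defined above; the model name and the custom-root _target_ are illustrative placeholders, and save_pretrained/from_pretrained come from the PretrainedConfig base class:

from cortex.config import NeuralTreeConfig

config = NeuralTreeConfig(ensemble_size=2, channel_dim=128)

# HuggingFace-backed root: fetches the backbone config via AutoConfig (needs network or a local cache).
config.add_hf_root("protein_seq", "prajjwal1/bert-tiny")

# Traditional cortex root from a Hydra-style dict (placeholder _target_ for illustration).
config.add_cortex_root("antibody_seq", {"_target_": "cortex.model.root.SomeCortexRoot"})

hydra_style = config.to_hydra_config()          # dict usable by existing Hydra-driven scripts
round_trip = NeuralTreeConfig.from_hydra_config(hydra_style)

config.save_pretrained("./neural_tree_config")  # JSON serialization via PretrainedConfig machinery
reloaded = NeuralTreeConfig.from_pretrained("./neural_tree_config")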
6 changes: 6 additions & 0 deletions cortex/corruption/__init__.py
@@ -2,4 +2,10 @@
from ._diffusion_noise_schedule import get_named_beta_schedule
from ._gaussian_corruption import GaussianCorruptionProcess
from ._mask_corruption import MaskCorruptionProcess
from ._static_corruption import (
StaticCorruptionFactory,
StaticCorruptionProcess,
StaticGaussianCorruption,
StaticMaskCorruption,
)
from ._substitution_corruption import SubstitutionCorruptionProcess