Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f50065d
feat: Add automatic git tracking and storage info to MLflowLogger
adamimos Sep 4, 2025
60f83fe
fix: Handle case where simplexity.__file__ is None
adamimos Sep 4, 2025
e3d155f
fix: Improve simplexity path detection for git tracking
adamimos Sep 4, 2025
1b8ee3e
style: Apply ruff formatting to mlflow_logger.py
adamimos Sep 4, 2025
e47a3c0
fix: Use getattr to avoid pyright type checking errors
adamimos Sep 4, 2025
06ae759
Address PR review comments: refactor git tracking to base Logger class
adamimos Sep 5, 2025
b61f945
Address remaining PR review comments
adamimos Sep 5, 2025
c091415
feat: add core FactoredGenerativeProcess class structure
adamimos Sep 7, 2025
6f7d76c
feat: implement outer product token generation
adamimos Sep 7, 2025
86b7f1c
feat: complete GenerativeProcess interface implementation
adamimos Sep 7, 2025
b628429
feat: ensure training pipeline compatibility
adamimos Sep 7, 2025
e4e865b
feat: add builder support for factored generators
adamimos Sep 7, 2025
98cdff0
docs: add comprehensive factored generator summary
adamimos Sep 7, 2025
594254e
feat: add Hydra configuration support for factored generators
adamimos Sep 7, 2025
55a8691
docs: add reviewer task and approval criteria to PR review report
adamimos Sep 7, 2025
b835c39
Add comprehensive testing for factored generator implementation
adamimos Sep 7, 2025
37f5391
Fix type checking issues in factored generator tests
adamimos Sep 7, 2025
2d01c46
Implement JAX-native vectorized factored generator with performance i…
adamimos Sep 7, 2025
905912c
Remove temporary documentation files created during development
adamimos Sep 7, 2025
6bd79bf
Update CUDA dependency in pyproject.toml to use jax[cuda12] instead o…
adamimos Sep 7, 2025
478071c
Merge branch 'main' into feature/factored-generator
ealt Sep 8, 2025
ce7992c
Move persistence configuration from config.ini to YAML files (#73)
adamimos Sep 9, 2025
1780e80
Add plot and image logging support to MLflow integration (#74)
loren-ac Sep 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ dependencies = [
"orbax-checkpoint",
"pandas",
"penzai",
"plotly",
"treescope",
]

[project.optional-dependencies]
aws = ["boto3"]
cuda = ["jax[cuda12_pip]"]
cuda = ["jax[cuda12]"]
dev = ["jaxtyping", "nbqa", "pyright", "pytest", "pytest-cov", "ruff"]
mac = ["jax-metal"]
pytorch = ["torch"]
Expand Down
22 changes: 21 additions & 1 deletion simplexity/configs/generative_process/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Literal
from typing import Any, Literal

ProcessName = Literal[
"days_of_week",
Expand All @@ -11,11 +11,14 @@
"rrxor",
"tom_quantum",
"zero_one_random",
"factored_generator",
]

ProcessBuilder = Literal[
"simplexity.generative_processes.builder.build_generalized_hidden_markov_model",
"simplexity.generative_processes.builder.build_hidden_markov_model",
"simplexity.generative_processes.builder.build_factored_generator",
"simplexity.generative_processes.builder.build_factored_hmm_generator",
]
ProcessType = ProcessName

Expand Down Expand Up @@ -113,6 +116,23 @@ class ZeroOneRandomConfig(ProcessInstanceConfig):
p: float


@dataclass
class FactoredGeneratorConfig(ProcessInstanceConfig):
    """Instance configuration for a factored generator built from mixed components.

    Intended builder target: ``build_factored_generator``
    (``_target_: build_factored_generator``).
    """

    # One spec dict per component of the factored generator.
    component_specs: list[dict[str, Any]]

    # Per-component kind, e.g. ["hmm", "ghmm", ...]; None means every
    # component is treated as a GHMM.
    component_types: list[str] | None = None


@dataclass
class FactoredHmmGeneratorConfig(ProcessInstanceConfig):
    """Instance configuration for a factored generator whose components are all HMMs.

    Intended builder target: ``build_factored_hmm_generator``
    (``_target_: build_factored_hmm_generator``).
    """

    # One spec dict per HMM component of the factored generator.
    component_specs: list[dict[str, Any]]


@dataclass
class Config:
"""Base configuration for predictive models."""
Expand Down
3 changes: 2 additions & 1 deletion simplexity/configs/persistence/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ class LocalPenzaiPersisterConfig(PersistenceInstanceConfig):
class S3PersisterConfig(PersistenceInstanceConfig):
"""Configuration for S3 persister."""

filename: str
prefix: str
model_framework: str
config_filename: str = "config.ini"


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion simplexity/configs/persistence/s3_persister.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ name: s3_persister

instance:
_target_: simplexity.persistence.s3_persister.S3Persister.from_config
filename: "config.ini"
prefix: "models"
model_framework: "equinox"
4 changes: 2 additions & 2 deletions simplexity/evaluation/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from simplexity.configs.evaluation.config import Config
from simplexity.evaluation.metric_functions import METRIC_FUNCTIONS
from simplexity.generative_processes.generative_process import GenerativeProcess
from simplexity.generative_processes.generator import generate_data_batch
from simplexity.generative_processes.generator import batch_state, generate_data_batch
from simplexity.logging.logger import Logger
from simplexity.predictive_models.predictive_model import PredictiveModel

Expand Down Expand Up @@ -36,7 +36,7 @@ def evaluate(
key = jax.random.PRNGKey(cfg.seed)

gen_state = data_generator.initial_state
gen_states = jnp.repeat(gen_state[None, :], cfg.batch_size, axis=0)
gen_states = batch_state(gen_state, cfg.batch_size)
metrics = defaultdict(lambda: jnp.array(0.0))

for step in range(1, cfg.num_steps + 1):
Expand Down
5 changes: 4 additions & 1 deletion simplexity/evaluation/evaluate_pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ def evaluate(
key = jax.random.PRNGKey(cfg.seed)

gen_state = data_generator.initial_state
gen_states = jnp.repeat(gen_state[None, :], cfg.batch_size, axis=0)
gen_states = jax.tree_util.tree_map(
lambda s: jnp.repeat(s[None, ...], cfg.batch_size, axis=0),
gen_state,
)
metrics = defaultdict(lambda: jnp.array(0.0))

for step in range(1, cfg.num_steps + 1):
Expand Down
78 changes: 78 additions & 0 deletions simplexity/generative_processes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax.numpy as jnp

from simplexity.generative_processes.factored_generator import FactoredGenerativeProcess
from simplexity.generative_processes.generalized_hidden_markov_model import GeneralizedHiddenMarkovModel
from simplexity.generative_processes.hidden_markov_model import HiddenMarkovModel
from simplexity.generative_processes.transition_matrices import (
Expand Down Expand Up @@ -113,3 +114,80 @@ def build_nonergodic_hidden_markov_model(
initial_state = jnp.zeros((num_states,), dtype=composite_transition_matrix.dtype)
initial_state = initial_state.at[num_states - 1].set(1)
return HiddenMarkovModel(composite_transition_matrix, initial_state)


def build_factored_generator(
    component_specs: Sequence[dict[str, Any]],
    component_types: Sequence[str] | None = None,
    _process_name: str | None = None,  # For Hydra compatibility, ignored
    **_kwargs,  # For Hydra compatibility, ignored
) -> FactoredGenerativeProcess:
    """Build a factored generator from component specifications.

    Each spec is turned into one component by calling either
    ``build_hidden_markov_model`` or ``build_generalized_hidden_markov_model``
    with the spec's ``process_name`` and its remaining entries as keyword
    arguments. All inputs are validated before any component is built, so a
    malformed spec or type fails fast without constructing earlier components.

    Args:
        component_specs: Sequence of component spec mappings, each with a
            'process_name' key plus any builder kwargs. The mappings are not
            mutated.
        component_types: Sequence of component types ("hmm" or "ghmm"), one per
            spec. If None, every component defaults to "ghmm".
        _process_name: Ignored, for Hydra config compatibility.
        **_kwargs: Ignored additional args, for Hydra config compatibility.

    Returns:
        FactoredGenerativeProcess with the specified components, in spec order.

    Raises:
        ValueError: If ``component_specs`` and ``component_types`` differ in
            length, or if a component type is neither "hmm" nor "ghmm".
        KeyError: If a component spec has no 'process_name' entry.

    Example:
        # Create factored generator with 2 coin components (GHMM by default)
        factored_gen = build_factored_generator([
            {"process_name": "zero_one_random", "p": 0.7},
            {"process_name": "zero_one_random", "p": 0.3}
        ])

        # Mix HMM and GHMM components
        factored_gen = build_factored_generator([
            {"process_name": "zero_one_random", "p": 0.8},
            {"process_name": "days_of_week"}
        ], component_types=["hmm", "ghmm"])
    """
    if component_types is None:
        component_types = ["ghmm"] * len(component_specs)

    if len(component_specs) != len(component_types):
        raise ValueError("component_specs and component_types must have the same length")

    # Validate up front so that no component is constructed when a later
    # entry is malformed.
    for component_type in component_types:
        if component_type not in ("hmm", "ghmm"):
            raise ValueError(f"Unknown component type: {component_type}. Must be 'hmm' or 'ghmm'")
    for index, component_spec in enumerate(component_specs):
        if "process_name" not in component_spec:
            # Same exception type as the bare dict lookup, but says which spec is at fault.
            raise KeyError(f"component_specs[{index}] is missing required key 'process_name'")

    components = []
    for component_spec, component_type in zip(component_specs, component_types, strict=True):
        # Work on a plain-dict copy so the caller's spec is left untouched and
        # any mapping type (not only dict) is accepted.
        kwargs = dict(component_spec)
        process_name = kwargs.pop("process_name")

        if component_type == "hmm":
            component = build_hidden_markov_model(process_name, **kwargs)
        else:  # "ghmm" -- the only other value that passes validation above
            component = build_generalized_hidden_markov_model(process_name, **kwargs)
        components.append(component)

    return FactoredGenerativeProcess(components)


def build_factored_hmm_generator(
    component_specs: Sequence[dict[str, Any]],
    _process_name: str | None = None,  # For Hydra compatibility, ignored
    **_kwargs,  # For Hydra compatibility, ignored
) -> FactoredGenerativeProcess:
    """Build a factored generator in which every component is an HMM.

    Shorthand for calling ``build_factored_generator`` with
    ``component_types`` set to "hmm" for each entry of ``component_specs``.

    Args:
        component_specs: Sequence of component spec dicts, each with a
            'process_name' key plus any builder kwargs.
        _process_name: Ignored, for Hydra config compatibility.
        **_kwargs: Ignored additional args, for Hydra config compatibility.

    Returns:
        The FactoredGenerativeProcess produced by ``build_factored_generator``.

    Example:
        factored_gen = build_factored_hmm_generator([
            {"process_name": "zero_one_random", "p": 0.7},
            {"process_name": "zero_one_random", "p": 0.4}
        ])
    """
    # One "hmm" tag per spec, so the delegate builds every component as an HMM.
    hmm_types = ["hmm"] * len(component_specs)
    return build_factored_generator(component_specs, component_types=hmm_types)
Loading