From c744d7060ec6a1c99a32d20237b9b3cc4d8b54e4 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 19 Sep 2025 23:46:22 +0000 Subject: [PATCH 01/17] Create MLFlow persister --- simplexity/logging/mlflow_logger.py | 10 + simplexity/persistence/mlflow_persister.py | 215 +++++++++++++++++++++ tests/persistence/test_mlflow_persister.py | 142 ++++++++++++++ 3 files changed, 367 insertions(+) create mode 100644 simplexity/persistence/mlflow_persister.py create mode 100644 tests/persistence/test_mlflow_persister.py diff --git a/simplexity/logging/mlflow_logger.py b/simplexity/logging/mlflow_logger.py index 8bed4003..5f5fa605 100644 --- a/simplexity/logging/mlflow_logger.py +++ b/simplexity/logging/mlflow_logger.py @@ -42,6 +42,16 @@ def __init__( run = self._client.create_run(experiment_id=experiment_id, run_name=run_name) self._run_id = run.info.run_id + @property + def client(self) -> mlflow.MlflowClient: + """Expose underlying MLflow client for integrations.""" + return self._client + + @property + def run_id(self) -> str: + """Expose active MLflow run identifier.""" + return self._run_id + def log_config(self, config: DictConfig, resolve: bool = False) -> None: """Log config to MLflow.""" with tempfile.TemporaryDirectory() as temp_dir: diff --git a/simplexity/persistence/mlflow_persister.py b/simplexity/persistence/mlflow_persister.py new file mode 100644 index 00000000..0c0e4ec6 --- /dev/null +++ b/simplexity/persistence/mlflow_persister.py @@ -0,0 +1,215 @@ +"""MLflow-backed model persistence utilities.""" + +from __future__ import annotations + +import shutil +import tempfile +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from simplexity.persistence.model_persister import ModelPersister +from simplexity.predictive_models.predictive_model import PredictiveModel +from simplexity.predictive_models.types import ModelFramework + +if TYPE_CHECKING: + from mlflow import MlflowClient + from simplexity.logging.mlflow_logger import MLFlowLogger + + +def _normalize_artifact_path(artifact_path: str) -> str: + """Return a normalized artifact path without surrounding slashes.""" + artifact_path = artifact_path.strip() + return artifact_path.strip("/") + + +class MLFlowPersister(ModelPersister): + """Persist model checkpoints as MLflow artifacts, optionally reusing an existing run.""" + + client: Any + run_id: str + artifact_path: str + model_framework: ModelFramework + registered_model_name: str | None + _temp_dir: tempfile.TemporaryDirectory + _base_dir: Path + _artifact_dir: Path + _local_persister: ModelPersister + _registered_model_checked: bool + _managed_run: bool + + def __init__( + self, + client: "MlflowClient | Any", + run_id: str, + *, + artifact_path: str = "models", + model_framework: ModelFramework = ModelFramework.Equinox, + registered_model_name: str | None = None, + temp_dir: tempfile.TemporaryDirectory | None = None, + managed_run: bool = False, + ): + self.client = client + self.run_id = run_id + self.artifact_path = _normalize_artifact_path(artifact_path) + self.model_framework = model_framework + self.registered_model_name = registered_model_name + self._temp_dir = temp_dir or tempfile.TemporaryDirectory() + self._managed_run = managed_run + + # Local staging directories mirror the remote artifact layout for round-tripping. + self._base_dir = Path(self._temp_dir.name) + self._artifact_dir = self._base_dir / self.artifact_path if self.artifact_path else self._base_dir + self._artifact_dir.mkdir(parents=True, exist_ok=True) + self._local_persister = self._build_local_persister(self._artifact_dir) + self._registered_model_checked = False + + @classmethod + def from_experiment( + cls, + experiment_name: str, + *, + run_name: str | None = None, + tracking_uri: str | None = None, + artifact_path: str = "models", + model_framework: ModelFramework = ModelFramework.Equinox, + registered_model_name: str | None = None, + ) -> "MLFlowPersister": + import mlflow + + client = mlflow.MlflowClient(tracking_uri=tracking_uri) + experiment = client.get_experiment_by_name(experiment_name) + if experiment: + experiment_id = experiment.experiment_id + else: + experiment_id = client.create_experiment(experiment_name) + run = client.create_run(experiment_id=experiment_id, run_name=run_name) + return cls( + client=client, + run_id=run.info.run_id, + artifact_path=artifact_path, + model_framework=model_framework, + registered_model_name=registered_model_name, + managed_run=True, + ) + + @classmethod + def from_logger( + cls, + logger: "MLFlowLogger", + *, + artifact_path: str = "models", + model_framework: ModelFramework = ModelFramework.Equinox, + registered_model_name: str | None = None, + ) -> "MLFlowPersister": + """Create a persister reusing an existing MLFlowLogger run.""" + return cls( + client=logger.client, + run_id=logger.run_id, + artifact_path=artifact_path, + model_framework=model_framework, + registered_model_name=registered_model_name, + managed_run=False, + ) + + @property + def local_persister(self) -> ModelPersister: + """Expose the backing local persister (primarily for testing).""" + return self._local_persister + + def cleanup(self) -> None: + """Remove temporary resources and optionally end the MLflow run.""" + persister_cleanup = getattr(self._local_persister, "cleanup", None) + if callable(persister_cleanup): + persister_cleanup() + if self._managed_run: + try: + self.client.set_terminated(self.run_id) + except Exception: + # Cleanup is best-effort; ignore failures when ending the run. + pass + self._temp_dir.cleanup() + + def save_weights(self, model: PredictiveModel, step: int = 0) -> None: + """Serialize weights locally and upload them as MLflow artifacts.""" + self._clear_step_dir(step) + step_dir = self._artifact_dir / str(step) + self._local_persister.save_weights(model, step) + artifact_path = self._remote_step_path(step) + try: + self.client.log_artifacts(self.run_id, str(step_dir), artifact_path=artifact_path) + except Exception as exc: # pragma: no cover - exercised via mocks + raise RuntimeError(f"Failed to log model artifacts to MLflow at step {step}") from exc + self._maybe_register_model(artifact_path) + + def load_weights(self, model: PredictiveModel, step: int = 0) -> PredictiveModel: + """Download MLflow artifacts and restore them into the provided model.""" + self._clear_step_dir(step) + artifact_path = self._remote_step_path(step) + try: + downloaded_path = Path( + self.client.download_artifacts( + self.run_id, + artifact_path, + dst_path=str(self._base_dir), + ) + ) + except Exception as exc: # pragma: no cover - exercised via mocks + raise RuntimeError(f"Failed to download model artifacts from MLflow at step {step}") from exc + + if not downloaded_path.exists(): + raise RuntimeError(f"MLflow artifact for step {step} was not found after download") + + return self._local_persister.load_weights(model, step) + + def _build_local_persister(self, directory: Path) -> ModelPersister: + if self.model_framework == ModelFramework.Equinox: + from simplexity.persistence.local_equinox_persister import LocalEquinoxPersister + + return LocalEquinoxPersister(directory) + if self.model_framework == ModelFramework.Penzai: + from simplexity.persistence.local_penzai_persister import LocalPenzaiPersister + + return LocalPenzaiPersister(directory) + if self.model_framework == ModelFramework.Pytorch: + from simplexity.persistence.local_pytorch_persister import LocalPytorchPersister + + return LocalPytorchPersister(directory) + raise ValueError(f"Unsupported model framework: {self.model_framework}") + + def _remote_step_path(self, step: int) -> str: + parts: list[str] = [] + if self.artifact_path: + parts.append(self.artifact_path) + parts.append(str(step)) + return "/".join(parts) + + def _clear_step_dir(self, step: int) -> None: + step_dir = self._artifact_dir / str(step) + if step_dir.exists(): + shutil.rmtree(step_dir) + step_dir.parent.mkdir(parents=True, exist_ok=True) + + def _maybe_register_model(self, artifact_path: str) -> None: + if not self.registered_model_name: + return + + if not self._registered_model_checked: + try: + self.client.get_registered_model(self.registered_model_name) + except Exception: + try: + self.client.create_registered_model(self.registered_model_name) + except Exception: + pass + self._registered_model_checked = True + + source = f"runs:/{self.run_id}/{artifact_path}" + try: + self.client.create_model_version( + name=self.registered_model_name, + source=source, + run_id=self.run_id, + ) + except Exception: + # Surface registration failures as warnings while allowing training to proceed. + pass diff --git a/tests/persistence/test_mlflow_persister.py b/tests/persistence/test_mlflow_persister.py new file mode 100644 index 00000000..9c524c62 --- /dev/null +++ b/tests/persistence/test_mlflow_persister.py @@ -0,0 +1,142 @@ +"""Tests for MLFlowPersister behavior.""" + +import shutil +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +chex = pytest.importorskip("chex") +jax = pytest.importorskip("jax") + +from simplexity.persistence.mlflow_persister import MLFlowPersister +from simplexity.predictive_models.gru_rnn import GRURNN +from simplexity.predictive_models.types import ModelFramework + + +def get_model(seed: int) -> GRURNN: + return GRURNN(vocab_size=2, embedding_size=4, hidden_sizes=[3, 3], key=jax.random.PRNGKey(seed)) + + +@pytest.fixture +def mlflow_client_mock(tmp_path: Path) -> tuple[MagicMock, Path]: + """Create an MlflowClient mock that simulates artifact storage.""" + remote_root = tmp_path / "remote" + remote_root.mkdir() + + client = MagicMock() + + def log_artifacts(run_id: str, local_dir: str, artifact_path: str | None = None): + assert run_id == "run_123" + destination = remote_root if artifact_path is None else remote_root / artifact_path + if destination.exists(): + shutil.rmtree(destination) + destination.parent.mkdir(parents=True, exist_ok=True) + shutil.copytree(local_dir, destination) + + def download_artifacts(run_id: str, path: str, dst_path: str | None = None) -> str: + assert run_id == "run_123" + source = remote_root / path + if not source.exists(): + raise FileNotFoundError(path) + base_dir = Path(dst_path) if dst_path else tmp_path / "downloads" + destination = base_dir / path + if destination.exists(): + shutil.rmtree(destination) + destination.parent.mkdir(parents=True, exist_ok=True) + shutil.copytree(source, destination) + return str(destination) + + client.log_artifacts.side_effect = log_artifacts + client.download_artifacts.side_effect = download_artifacts + return client, remote_root + + +def test_mlflow_persister_round_trip(tmp_path: Path, mlflow_client_mock: tuple[MagicMock, Path]): + """Model weights saved via MLFLow can be restored back into a new instance.""" + client, remote_root = mlflow_client_mock + + persister = MLFlowPersister( + client=client, + run_id="run_123", + artifact_path="models", + model_framework=ModelFramework.Equinox, + ) + + original = get_model(0) + persister.save_weights(original, step=0) + + remote_model_path = remote_root / "models" / "0" / "model.eqx" + assert remote_model_path.exists() + + updated = get_model(1) + loaded = persister.load_weights(updated, step=0) + + chex.assert_trees_all_equal(loaded, original) + client.log_artifacts.assert_called_once() + client.download_artifacts.assert_called_once() + persister.cleanup() + + +def test_mlflow_persister_registers_versions(tmp_path: Path, mlflow_client_mock: tuple[MagicMock, Path]): + """Model versions are registered when a name is provided.""" + client, _ = mlflow_client_mock + client.get_registered_model.side_effect = Exception("missing") + + persister = MLFlowPersister( + client=client, + run_id="run_123", + artifact_path="models", + model_framework=ModelFramework.Equinox, + registered_model_name="TestModel", + ) + + persister.save_weights(get_model(2), step=5) + + client.create_registered_model.assert_called_once_with("TestModel") + client.create_model_version.assert_called_once() + call_kwargs = client.create_model_version.call_args.kwargs + assert call_kwargs["name"] == "TestModel" + assert call_kwargs["source"] == "runs:/run_123/models/5" + assert call_kwargs["run_id"] == "run_123" + persister.cleanup() + + +def test_mlflow_persister_from_logger_reuses_run( + tmp_path: Path, mlflow_client_mock: tuple[MagicMock, Path] +): + """Persister created from logger uses existing client/run without terminating it.""" + + class DummyLogger: + def __init__(self, client: MagicMock): + self._client = client + self._run_id = "run_123" + + @property + def client(self) -> MagicMock: + return self._client + + @property + def run_id(self) -> str: + return self._run_id + + client, remote_root = mlflow_client_mock + logger = DummyLogger(client) + + persister = MLFlowPersister.from_logger( + logger, + artifact_path="models", + model_framework=ModelFramework.Equinox, + ) + + original = get_model(0) + persister.save_weights(original, step=1) + + remote_model_path = remote_root / "models" / "1" / "model.eqx" + assert remote_model_path.exists() + + restored = persister.load_weights(get_model(1), step=1) + chex.assert_trees_all_equal(restored, original) + + persister.cleanup() + assert not client.set_terminated.called From b93d136246397fc03b05268d541607e95ba0c11e Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Tue, 30 Sep 2025 20:58:47 +0000 Subject: [PATCH 02/17] Fix lint issues --- simplexity/persistence/mlflow_persister.py | 35 ++++++++++++---------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/simplexity/persistence/mlflow_persister.py b/simplexity/persistence/mlflow_persister.py index 0c0e4ec6..38f9a950 100644 --- a/simplexity/persistence/mlflow_persister.py +++ b/simplexity/persistence/mlflow_persister.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from mlflow import MlflowClient + from simplexity.logging.mlflow_logger import MLFlowLogger @@ -39,7 +40,7 @@ class MLFlowPersister(ModelPersister): def __init__( self, - client: "MlflowClient | Any", + client: MlflowClient | Any, run_id: str, *, artifact_path: str = "models", @@ -70,13 +71,15 @@ def from_experiment( *, run_name: str | None = None, tracking_uri: str | None = None, + registry_uri: str | None = None, artifact_path: str = "models", model_framework: ModelFramework = ModelFramework.Equinox, registered_model_name: str | None = None, - ) -> "MLFlowPersister": + ) -> MLFlowPersister: + """Create a persister from an MLflow experiment.""" import mlflow - client = mlflow.MlflowClient(tracking_uri=tracking_uri) + client = mlflow.MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri) experiment = client.get_experiment_by_name(experiment_name) if experiment: experiment_id = experiment.experiment_id @@ -95,12 +98,12 @@ def from_experiment( @classmethod def from_logger( cls, - logger: "MLFlowLogger", + logger: MLFlowLogger, *, artifact_path: str = "models", model_framework: ModelFramework = ModelFramework.Equinox, registered_model_name: str | None = None, - ) -> "MLFlowPersister": + ) -> MLFlowPersister: """Create a persister reusing an existing MLFlowLogger run.""" return cls( client=logger.client, @@ -122,11 +125,11 @@ def cleanup(self) -> None: if callable(persister_cleanup): persister_cleanup() if self._managed_run: - try: - self.client.set_terminated(self.run_id) - except Exception: + import contextlib + + with contextlib.suppress(Exception): # Cleanup is best-effort; ignore failures when ending the run. - pass + self.client.set_terminated(self.run_id) self._temp_dir.cleanup() def save_weights(self, model: PredictiveModel, step: int = 0) -> None: @@ -197,19 +200,19 @@ def _maybe_register_model(self, artifact_path: str) -> None: try: self.client.get_registered_model(self.registered_model_name) except Exception: - try: + import contextlib + + with contextlib.suppress(Exception): self.client.create_registered_model(self.registered_model_name) - except Exception: - pass self._registered_model_checked = True source = f"runs:/{self.run_id}/{artifact_path}" - try: + import contextlib + + with contextlib.suppress(Exception): + # Surface registration failures as warnings while allowing training to proceed. self.client.create_model_version( name=self.registered_model_name, source=source, run_id=self.run_id, ) - except Exception: - # Surface registration failures as warnings while allowing training to proceed. - pass From ad18345a3186b57e931e69ef1af8c34bc3a47a32 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Tue, 30 Sep 2025 20:59:22 +0000 Subject: [PATCH 03/17] Use Unity Catelog --- pyproject.toml | 2 +- simplexity/configs/logging/config.py | 1 + simplexity/configs/logging/mlflow_logger.yaml | 1 + simplexity/logging/mlflow_logger.py | 15 ++++++++++++++- 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index aa48833b..4ebc23a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "jax", "jupyter", "matplotlib", - "mlflow", + "mlflow>=3.0.0", "optax", "orbax-checkpoint", "pandas", diff --git a/simplexity/configs/logging/config.py b/simplexity/configs/logging/config.py index a44522d0..49dc5370 100644 --- a/simplexity/configs/logging/config.py +++ b/simplexity/configs/logging/config.py @@ -29,6 +29,7 @@ class MLFlowLoggerConfig(LoggingInstanceConfig): experiment_name: str run_name: str tracking_uri: str + registry_uri: str | None = None @dataclass diff --git a/simplexity/configs/logging/mlflow_logger.yaml b/simplexity/configs/logging/mlflow_logger.yaml index caf345e8..882b1fc4 100644 --- a/simplexity/configs/logging/mlflow_logger.yaml +++ b/simplexity/configs/logging/mlflow_logger.yaml @@ -4,3 +4,4 @@ instance: experiment_name: /Shared/${experiment_name} run_name: ${run_name} tracking_uri: databricks + registry_uri: databricks-uc diff --git a/simplexity/logging/mlflow_logger.py b/simplexity/logging/mlflow_logger.py index 5f5fa605..a5eb8a28 100644 --- a/simplexity/logging/mlflow_logger.py +++ b/simplexity/logging/mlflow_logger.py @@ -31,9 +31,10 @@ def __init__( experiment_name: str, run_name: str | None = None, tracking_uri: str | None = None, + registry_uri: str | None = None, ): """Initialize MLflow logger.""" - self._client = mlflow.MlflowClient(tracking_uri=tracking_uri) + self._client = mlflow.MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri) experiment = self._client.get_experiment_by_name(experiment_name) if experiment: experiment_id = experiment.experiment_id @@ -41,6 +42,8 @@ def __init__( experiment_id = self._client.create_experiment(experiment_name) run = self._client.create_run(experiment_id=experiment_id, run_name=run_name) self._run_id = run.info.run_id + self._tracking_uri = tracking_uri + self._registry_uri = registry_uri @property def client(self) -> mlflow.MlflowClient: @@ -52,6 +55,16 @@ def run_id(self) -> str: """Expose active MLflow run identifier.""" return self._run_id + @property + def tracking_uri(self) -> str | None: + """Return the tracking URI associated with this logger.""" + return self._tracking_uri + + @property + def registry_uri(self) -> str | None: + """Return the model registry URI associated with this logger.""" + return self._registry_uri + def log_config(self, config: DictConfig, resolve: bool = False) -> None: """Log config to MLflow.""" with tempfile.TemporaryDirectory() as temp_dir: From 137e2f792834760f485c06463d137bebe0dde5b6 Mon Sep 17 00:00:00 2001 From: adamimos Date: Wed, 1 Oct 2025 18:07:14 +0000 Subject: [PATCH 04/17] Add configuration resolution and utility functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added utilities to help resolve training configuration parameters and avoid redundant specification: - config_resolution.py: Functions to compute generator sequence length and model vocab size from each other and special token usage - persistence/utils.py: Checkpoint path parsing and step number formatting - Device resolution for both JAX and PyTorch frameworks 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Co-authored-by: Sculptor --- simplexity/persistence/utils.py | 77 +++++++++++++++++++++++++++ simplexity/utils/config_resolution.py | 67 +++++++++++++++++++++++ simplexity/utils/jnp.py | 33 ++++++++++++ simplexity/utils/pytorch_utils.py | 25 +++++++++ 4 files changed, 202 insertions(+) create mode 100644 simplexity/persistence/utils.py create mode 100644 simplexity/utils/config_resolution.py diff --git a/simplexity/persistence/utils.py b/simplexity/persistence/utils.py new file mode 100644 index 00000000..15d9c15c --- /dev/null +++ b/simplexity/persistence/utils.py @@ -0,0 +1,77 @@ +import re + + +def parse_checkpoint_step(path: str) -> int | None: + """Extract training step number from checkpoint path. + + Handles multiple formats: + - step_12345.pt / step-12345.pt + - 12345/model.pt + - model_weights/step_00012345.pt + + Args: + path: File path or S3 key containing checkpoint + + Returns: + Step number if found, None otherwise + + Examples: + >>> parse_checkpoint_step("model_weights/step_12345.pt") + 12345 + >>> parse_checkpoint_step("checkpoints/12345/model.pt") + 12345 + >>> parse_checkpoint_step("step-00500.pt") + 500 + """ + m = re.search(r"step[_-]?(\d+)\.pt$", path) + if m: + return int(m.group(1)) + + parts = path.split("/") + if parts and parts[-1] == "model.pt" and len(parts) >= 2: + try: + return int(parts[-2]) + except ValueError: + pass + + return None + + +def compute_step_width(max_steps: int) -> int: + """Compute zero-padding width for step numbers. + + Ensures lexicographic sorting matches chronological order. + + Args: + max_steps: Maximum number of training steps + + Returns: + Number of digits to use for zero-padding + + Examples: + >>> compute_step_width(999) + 3 + >>> compute_step_width(100000) + 6 + """ + return len(str(max_steps)) + + +def format_step_number(step: int, max_steps: int) -> str: + """Format step number with appropriate zero-padding. + + Args: + step: Current training step + max_steps: Maximum number of training steps + + Returns: + Zero-padded step string + + Examples: + >>> format_step_number(42, max_steps=100000) + '000042' + >>> format_step_number(999, max_steps=999) + '999' + """ + width = compute_step_width(max_steps) + return f"{step:0{width}d}" diff --git a/simplexity/utils/config_resolution.py b/simplexity/utils/config_resolution.py new file mode 100644 index 00000000..c6755af8 --- /dev/null +++ b/simplexity/utils/config_resolution.py @@ -0,0 +1,67 @@ +def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool) -> int: + """Compute the generator's sequence length from model context length and BOS usage. + + The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + + Solving for generator_seq_len: generator_seq_len = model_n_ctx + 1 - BOS + + Args: + model_n_ctx: The model's context length (number of input positions it processes) + use_bos: Whether a beginning-of-sequence token is prepended during data generation + + Returns: + The sequence length to configure for the data generator + + Examples: + >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True) + 512 + >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=False) + 513 + """ + return model_n_ctx + 1 - int(use_bos) + + +def compute_model_context_length(generator_seq_len: int, use_bos: bool) -> int: + """Compute the model's context length from generator sequence length and BOS usage. + + The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + + Args: + generator_seq_len: The sequence length configured for the data generator + use_bos: Whether a beginning-of-sequence token is prepended during data generation + + Returns: + The context length for the model (number of input positions it will process) + + Examples: + >>> compute_model_context_length(generator_seq_len=512, use_bos=True) + 512 + >>> compute_model_context_length(generator_seq_len=513, use_bos=False) + 512 + """ + return generator_seq_len - 1 + int(use_bos) + + +def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: bool) -> int: + """Compute the model's vocabulary size from generator vocab and special tokens. + + When BOS or EOS tokens are used during data generation, they are added to the vocabulary, + increasing the total vocab size the model needs to handle. + + Args: + generator_vocab_size: The vocabulary size of the data generator + use_bos: Whether a beginning-of-sequence token is used during data generation + use_eos: Whether an end-of-sequence token is used during data generation + + Returns: + The vocabulary size the model should be configured with + + Examples: + >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=False) + 101 + >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=True) + 102 + >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=False) + 100 + """ + return generator_vocab_size + int(use_bos) + int(use_eos) diff --git a/simplexity/utils/jnp.py b/simplexity/utils/jnp.py index eab6db00..b600656a 100644 --- a/simplexity/utils/jnp.py +++ b/simplexity/utils/jnp.py @@ -3,6 +3,39 @@ import jax.numpy as jnp +def resolve_jax_device(device_spec: str | None = "auto") -> jax.Device: + """Resolve device specification to actual JAX device. + + Args: + device_spec: One of "auto", "gpu", "cuda", "cpu", or None (treated as "auto") + + Returns: + JAX device object + + Examples: + >>> resolve_jax_device("auto") # On GPU machine + GpuDevice(id=0, ...) + >>> resolve_jax_device("cpu") + CpuDevice(id=0) + """ + if device_spec is None or device_spec == "auto": + devices = jax.devices("gpu") + if devices: + return devices[0] + return jax.devices("cpu")[0] + + if device_spec in ("gpu", "cuda"): + devices = jax.devices("gpu") + if not devices: + raise RuntimeError("GPU requested but no GPU devices available") + return devices[0] + + if device_spec == "cpu": + return jax.devices("cpu")[0] + + raise ValueError(f"Unknown device specification: {device_spec}") + + @eqx.filter_jit def entropy(probs: jax.Array, log: bool = False) -> jax.Array: """Compute the entropy of a log probability distribution.""" diff --git a/simplexity/utils/pytorch_utils.py b/simplexity/utils/pytorch_utils.py index 957f869f..2492f31d 100644 --- a/simplexity/utils/pytorch_utils.py +++ b/simplexity/utils/pytorch_utils.py @@ -78,3 +78,28 @@ def torch_to_jax(torch_tensor: torch.Tensor) -> jax.Array: numpy_array = torch_tensor.detach().cpu().numpy() jax_array = jnp.array(numpy_array) return jax_array + + +def resolve_device(device_spec: str | None = "auto") -> str: + """Resolve device specification to actual PyTorch device string. + + Args: + device_spec: One of "auto", "cuda", "mps", "cpu", or None (treated as "auto") + + Returns: + Resolved device string: "cuda", "mps", or "cpu" + + Examples: + >>> resolve_device("auto") # On CUDA machine + 'cuda' + >>> resolve_device("cpu") + 'cpu' + """ + if device_spec is None or device_spec == "auto": + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_available(): + return "mps" + else: + return "cpu" + return device_spec From d9f928329950d0cebacf76c4b2073443c8868e0c Mon Sep 17 00:00:00 2001 From: adamimos Date: Wed, 1 Oct 2025 11:33:01 -0700 Subject: [PATCH 05/17] Add comprehensive test coverage for config resolution utilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add tests/utils/test_config_resolution.py with 45 tests for compute_generator_sequence_length, compute_model_context_length, and compute_model_vocab_size - Add tests/persistence/test_utils.py with 51 tests for parse_checkpoint_step, compute_step_width, and format_step_number - Add resolve_jax_device tests to tests/utils/test_jnp.py with error handling for GPU unavailable scenarios - Add resolve_device tests to tests/utils/test_pytorch_utils.py with CUDA/MPS availability checks - Fix resolve_device() to validate input and raise ValueError for unknown specs, RuntimeError when unavailable - Fix resolve_jax_device() to handle JAX RuntimeError when GPU backend unavailable - Standardize error handling between JAX and PyTorch device resolution functions All tests pass (257 passed, 6 skipped due to hardware unavailability) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- simplexity/utils/jnp.py | 20 ++- simplexity/utils/pytorch_utils.py | 20 ++- tests/persistence/test_utils.py | 177 ++++++++++++++++++++++++++ tests/utils/test_config_resolution.py | 145 +++++++++++++++++++++ tests/utils/test_jnp.py | 85 ++++++++++++- tests/utils/test_pytorch_utils.py | 71 ++++++++++- 6 files changed, 508 insertions(+), 10 deletions(-) create mode 100644 tests/persistence/test_utils.py create mode 100644 tests/utils/test_config_resolution.py diff --git a/simplexity/utils/jnp.py b/simplexity/utils/jnp.py index b600656a..a30e9958 100644 --- a/simplexity/utils/jnp.py +++ b/simplexity/utils/jnp.py @@ -19,16 +19,22 @@ def resolve_jax_device(device_spec: str | None = "auto") -> jax.Device: CpuDevice(id=0) """ if device_spec is None or device_spec == "auto": - devices = jax.devices("gpu") - if devices: - return devices[0] + try: + devices = jax.devices("gpu") + if devices: + return devices[0] + except RuntimeError: + pass return jax.devices("cpu")[0] if device_spec in ("gpu", "cuda"): - devices = jax.devices("gpu") - if not devices: - raise RuntimeError("GPU requested but no GPU devices available") - return devices[0] + try: + devices = jax.devices("gpu") + if devices: + return devices[0] + except RuntimeError: + pass + raise RuntimeError("GPU requested but no GPU devices available") if device_spec == "cpu": return jax.devices("cpu")[0] diff --git a/simplexity/utils/pytorch_utils.py b/simplexity/utils/pytorch_utils.py index 2492f31d..7ba25be7 100644 --- a/simplexity/utils/pytorch_utils.py +++ b/simplexity/utils/pytorch_utils.py @@ -89,6 +89,10 @@ def resolve_device(device_spec: str | None = "auto") -> str: Returns: Resolved device string: "cuda", "mps", or "cpu" + Raises: + ValueError: If device_spec is not a recognized device type + RuntimeError: If a specific device is requested but unavailable + Examples: >>> resolve_device("auto") # On CUDA machine 'cuda' @@ -102,4 +106,18 @@ def resolve_device(device_spec: str | None = "auto") -> str: return "mps" else: return "cpu" - return device_spec + + if device_spec == "cuda": + if not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but CUDA is not available") + return "cuda" + + if device_spec == "mps": + if not torch.backends.mps.is_available(): + raise RuntimeError("MPS requested but MPS is not available") + return "mps" + + if device_spec == "cpu": + return "cpu" + + raise ValueError(f"Unknown device specification: {device_spec}") diff --git a/tests/persistence/test_utils.py b/tests/persistence/test_utils.py new file mode 100644 index 00000000..eecfd9c3 --- /dev/null +++ b/tests/persistence/test_utils.py @@ -0,0 +1,177 @@ +import pytest + +from simplexity.persistence.utils import compute_step_width, format_step_number, parse_checkpoint_step + + +class TestParseCheckpointStep: + """Test parse_checkpoint_step function.""" + + @pytest.mark.parametrize( + ("path", "expected"), + [ + ("model_weights/step_12345.pt", 12345), + ("step_12345.pt", 12345), + ("step_00012345.pt", 12345), + ("checkpoints/step_500.pt", 500), + ("path/to/step_999.pt", 999), + ], + ) + def test_step_underscore_format(self, path: str, expected: int): + """Test parsing step_XXXX.pt format.""" + assert parse_checkpoint_step(path) == expected + + @pytest.mark.parametrize( + ("path", "expected"), + [ + ("step-12345.pt", 12345), + ("step-00500.pt", 500), + ("model_weights/step-999.pt", 999), + ], + ) + def test_step_hyphen_format(self, path: str, expected: int): + """Test parsing step-XXXX.pt format.""" + assert parse_checkpoint_step(path) == expected + + @pytest.mark.parametrize( + ("path", "expected"), + [ + ("12345/model.pt", 12345), + ("checkpoints/12345/model.pt", 12345), + ("path/to/500/model.pt", 500), + ("0/model.pt", 0), + ], + ) + def test_directory_model_format(self, path: str, expected: int): + """Test parsing XXXX/model.pt format.""" + assert parse_checkpoint_step(path) == expected + + @pytest.mark.parametrize( + "path", + [ + "model.pt", + "checkpoint.pt", + "step.pt", + "weights/model.eqx", + "random_file.txt", + "step_abc.pt", + "nonumeric/model.pt", + ], + ) + def test_no_match_returns_none(self, path: str): + """Test paths that should not match any pattern.""" + assert parse_checkpoint_step(path) is None + + def test_zero_padded_step_numbers(self): + """Test that zero-padded step numbers are correctly parsed.""" + assert parse_checkpoint_step("step_00042.pt") == 42 + assert parse_checkpoint_step("step_00000.pt") == 0 + assert parse_checkpoint_step("0000/model.pt") == 0 + + def test_step_pattern_takes_precedence_over_directory(self): + """Test that step_*.pt pattern takes precedence over directory pattern.""" + assert parse_checkpoint_step("path/step_200.pt") == 200 + assert parse_checkpoint_step("checkpoints/step_999.pt") == 999 + + def test_windows_paths(self): + """Test Windows-style paths with backslashes.""" + path_unix = "checkpoints/12345/model.pt" + assert parse_checkpoint_step(path_unix) == 12345 + + def test_s3_style_keys(self): + """Test S3 object key formats.""" + assert parse_checkpoint_step("s3://bucket/prefix/step_12345.pt") == 12345 + assert parse_checkpoint_step("prefix/run_name/12345/model.pt") == 12345 + + +class TestComputeStepWidth: + """Test compute_step_width function.""" + + def test_single_digit(self): + """Test with max_steps requiring 1 digit.""" + assert compute_step_width(9) == 1 + assert compute_step_width(1) == 1 + + def test_two_digits(self): + """Test with max_steps requiring 2 digits.""" + assert compute_step_width(10) == 2 + assert compute_step_width(99) == 2 + + def test_three_digits(self): + """Test with max_steps requiring 3 digits.""" + assert compute_step_width(100) == 3 + assert compute_step_width(999) == 3 + + @pytest.mark.parametrize( + ("max_steps", "expected_width"), + [ + (1, 1), + (9, 1), + (10, 2), + (99, 2), + (100, 3), + (999, 3), + (1000, 4), + (9999, 4), + (10000, 5), + (99999, 5), + (100000, 6), + ], + ) + def test_parametrized_widths(self, max_steps: int, expected_width: int): + """Test various max_steps values.""" + assert compute_step_width(max_steps) == expected_width + + def test_large_step_counts(self): + """Test with very large step counts.""" + assert compute_step_width(1000000) == 7 + assert compute_step_width(10000000) == 8 + + +class TestFormatStepNumber: + """Test format_step_number function.""" + + def test_basic_formatting(self): + """Test basic zero-padding behavior.""" + assert format_step_number(42, max_steps=100) == "042" + assert format_step_number(5, max_steps=1000) == "0005" + + def test_no_padding_needed(self): + """Test when step already has maximum width.""" + assert format_step_number(999, max_steps=999) == "999" + assert format_step_number(100, max_steps=100) == "100" + + def test_zero_step(self): + """Test formatting step 0.""" + assert format_step_number(0, max_steps=100) == "000" + assert format_step_number(0, max_steps=10000) == "00000" + + @pytest.mark.parametrize( + ("step", "max_steps", "expected"), + [ + (0, 999, "000"), + (1, 999, "001"), + (42, 999, "042"), + (999, 999, "999"), + (0, 100000, "000000"), + (42, 100000, "000042"), + (12345, 100000, "012345"), + (100000, 100000, "100000"), + ], + ) + def test_parametrized_formatting(self, step: int, max_steps: int, expected: str): + """Test various step and max_steps combinations.""" + assert format_step_number(step, max_steps) == expected + + def test_lexicographic_ordering(self): + """Verify that formatted strings sort lexicographically.""" + max_steps = 10000 + formatted = [format_step_number(i, max_steps) for i in [1, 10, 100, 1000, 9999]] + assert formatted == sorted(formatted) + + def test_consistency_with_compute_step_width(self): + """Verify format_step_number uses compute_step_width correctly.""" + max_steps = 100000 + step = 42 + formatted = format_step_number(step, max_steps) + expected_width = compute_step_width(max_steps) + assert len(formatted) == expected_width diff --git a/tests/utils/test_config_resolution.py b/tests/utils/test_config_resolution.py new file mode 100644 index 00000000..2d66ce8e --- /dev/null +++ b/tests/utils/test_config_resolution.py @@ -0,0 +1,145 @@ +import pytest + +from simplexity.utils.config_resolution import ( + compute_generator_sequence_length, + compute_model_context_length, + compute_model_vocab_size, +) + + +class TestComputeGeneratorSequenceLength: + """Test compute_generator_sequence_length function.""" + + def test_with_bos_token(self): + """When BOS is used, generator_seq_len should equal model_n_ctx.""" + assert compute_generator_sequence_length(model_n_ctx=512, use_bos=True) == 512 + assert compute_generator_sequence_length(model_n_ctx=100, use_bos=True) == 100 + + def test_without_bos_token(self): + """When BOS is not used, generator_seq_len should be model_n_ctx + 1.""" + assert compute_generator_sequence_length(model_n_ctx=512, use_bos=False) == 513 + assert compute_generator_sequence_length(model_n_ctx=100, use_bos=False) == 101 + + @pytest.mark.parametrize( + ("model_n_ctx", "use_bos", "expected"), + [ + (1, True, 1), + (1, False, 2), + (64, True, 64), + (64, False, 65), + (1024, True, 1024), + (1024, False, 1025), + ], + ) + def test_parametrized_cases(self, model_n_ctx: int, use_bos: bool, expected: int): + """Test various combinations of model_n_ctx and use_bos.""" + assert compute_generator_sequence_length(model_n_ctx, use_bos) == expected + + def test_zero_context_with_bos(self): + """Edge case: zero context length with BOS.""" + assert compute_generator_sequence_length(model_n_ctx=0, use_bos=True) == 0 + + def test_zero_context_without_bos(self): + """Edge case: zero context length without BOS.""" + assert compute_generator_sequence_length(model_n_ctx=0, use_bos=False) == 1 + + +class TestComputeModelContextLength: + """Test compute_model_context_length function.""" + + def test_with_bos_token(self): + """When BOS is used, model_n_ctx should equal generator_seq_len.""" + assert compute_model_context_length(generator_seq_len=512, use_bos=True) == 512 + assert compute_model_context_length(generator_seq_len=100, use_bos=True) == 100 + + def test_without_bos_token(self): + """When BOS is not used, model_n_ctx should be generator_seq_len - 1.""" + assert compute_model_context_length(generator_seq_len=513, use_bos=False) == 512 + assert compute_model_context_length(generator_seq_len=101, use_bos=False) == 100 + + @pytest.mark.parametrize( + ("generator_seq_len", "use_bos", "expected"), + [ + (1, True, 1), + (2, False, 1), + (64, True, 64), + (65, False, 64), + (1024, True, 1024), + (1025, False, 1024), + ], + ) + def test_parametrized_cases(self, generator_seq_len: int, use_bos: bool, expected: int): + """Test various combinations of generator_seq_len and use_bos.""" + assert compute_model_context_length(generator_seq_len, use_bos) == expected + + def test_inverse_relationship_with_bos(self): + """Verify inverse relationship with compute_generator_sequence_length when using BOS.""" + model_n_ctx = 512 + use_bos = True + gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos) + recovered_n_ctx = compute_model_context_length(gen_seq_len, use_bos) + assert recovered_n_ctx == model_n_ctx + + def test_inverse_relationship_without_bos(self): + """Verify inverse relationship with compute_generator_sequence_length without BOS.""" + model_n_ctx = 512 + use_bos = False + gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos) + recovered_n_ctx = compute_model_context_length(gen_seq_len, use_bos) + assert recovered_n_ctx == model_n_ctx + + @pytest.mark.parametrize("model_n_ctx", [1, 64, 128, 512, 1024]) + @pytest.mark.parametrize("use_bos", [True, False]) + def test_round_trip_consistency(self, model_n_ctx: int, use_bos: bool): + """Verify round-trip conversion maintains original value.""" + gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos) + recovered = compute_model_context_length(gen_seq_len, use_bos) + assert recovered == model_n_ctx + + +class TestComputeModelVocabSize: + """Test compute_model_vocab_size function.""" + + def test_no_special_tokens(self): + """When no special tokens are used, vocab size should equal generator vocab.""" + assert compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=False) == 100 + + def test_with_bos_only(self): + """When only BOS is used, vocab size should be generator_vocab + 1.""" + assert compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=False) == 101 + + def test_with_eos_only(self): + """When only EOS is used, vocab size should be generator_vocab + 1.""" + assert compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=True) == 101 + + def test_with_both_special_tokens(self): + """When both BOS and EOS are used, vocab size should be generator_vocab + 2.""" + assert compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=True) == 102 + + @pytest.mark.parametrize( + ("generator_vocab_size", "use_bos", "use_eos", "expected"), + [ + (100, False, False, 100), + (100, True, False, 101), + (100, False, True, 101), + (100, True, True, 102), + (1, False, False, 1), + (1, True, True, 3), + (50257, False, False, 50257), + (50257, True, False, 50258), + (50257, True, True, 50259), + ], + ) + def test_parametrized_cases( + self, generator_vocab_size: int, use_bos: bool, use_eos: bool, expected: int + ): + """Test various combinations of vocab size and special tokens.""" + assert compute_model_vocab_size(generator_vocab_size, use_bos, use_eos) == expected + + def test_minimal_vocab_with_tokens(self): + """Edge case: minimal vocabulary with special tokens.""" + assert compute_model_vocab_size(generator_vocab_size=2, use_bos=True, use_eos=True) == 4 + + def test_large_vocab(self): + """Test with large vocabulary sizes.""" + assert compute_model_vocab_size(generator_vocab_size=100000, use_bos=True, use_eos=True) == 100002 diff --git a/tests/utils/test_jnp.py b/tests/utils/test_jnp.py index dfcf5b6c..ecbd5b88 100644 --- a/tests/utils/test_jnp.py +++ b/tests/utils/test_jnp.py @@ -1,8 +1,9 @@ import chex import jax import jax.numpy as jnp +import pytest -from simplexity.utils.jnp import log_matmul, signed_logsumexp +from simplexity.utils.jnp import log_matmul, resolve_jax_device, signed_logsumexp def test_log_matmul(): @@ -43,3 +44,85 @@ def test_signed_logsumexp(): chex.assert_trees_all_close(actual_log_abs_values, expected_log_abs_values) chex.assert_trees_all_close(actual_signs, expected_signs) + + +class TestResolveJaxDevice: + """Test resolve_jax_device function.""" + + def test_auto_mode_returns_device(self): + """Test auto mode returns a valid JAX device.""" + device = resolve_jax_device("auto") + assert isinstance(device, jax.Device) + + def test_none_treated_as_auto(self): + """Test None is treated as auto mode.""" + device = resolve_jax_device(None) + assert isinstance(device, jax.Device) + + def test_cpu_returns_cpu_device(self): + """Test explicit CPU request returns CPU device.""" + device = resolve_jax_device("cpu") + assert isinstance(device, jax.Device) + assert "cpu" in str(device).lower() + + def test_gpu_when_available(self): + """Test GPU request when GPU is available.""" + try: + gpu_devices = jax.devices("gpu") + if not gpu_devices: + pytest.skip("GPU not available") + except RuntimeError: + pytest.skip("GPU not available") + + device = resolve_jax_device("gpu") + assert isinstance(device, jax.Device) + assert "gpu" in str(device).lower() or "cuda" in str(device).lower() + + def test_cuda_when_available(self): + """Test CUDA request when GPU is available.""" + try: + gpu_devices = jax.devices("gpu") + if not gpu_devices: + pytest.skip("GPU not available") + except RuntimeError: + pytest.skip("GPU not available") + + device = resolve_jax_device("cuda") + assert isinstance(device, jax.Device) + assert "gpu" in str(device).lower() or "cuda" in str(device).lower() + + def test_gpu_unavailable_raises_runtime_error(self): + """Test GPU request raises RuntimeError when GPU unavailable.""" + try: + gpu_devices = jax.devices("gpu") + if gpu_devices: + pytest.skip("GPU is available, cannot test unavailable case") + except RuntimeError: + pass + + with pytest.raises(RuntimeError, match="GPU requested but no GPU devices available"): + resolve_jax_device("gpu") + + def test_cuda_unavailable_raises_runtime_error(self): + """Test CUDA request raises RuntimeError when GPU unavailable.""" + try: + gpu_devices = jax.devices("gpu") + if gpu_devices: + pytest.skip("GPU is available, cannot test unavailable case") + except RuntimeError: + pass + + with pytest.raises(RuntimeError, match="GPU requested but no GPU devices available"): + resolve_jax_device("cuda") + + def test_invalid_spec_raises_value_error(self): + """Test invalid device spec raises ValueError.""" + with pytest.raises(ValueError, match="Unknown device specification"): + resolve_jax_device("invalid_device") + + def test_unknown_specs_raise_value_error(self): + """Test various unknown specs raise ValueError.""" + invalid_specs = ["tpu", "gpu0", "cuda:0", "mps", "unknown"] + for spec in invalid_specs: + with pytest.raises(ValueError, match="Unknown device specification"): + resolve_jax_device(spec) diff --git a/tests/utils/test_pytorch_utils.py b/tests/utils/test_pytorch_utils.py index 34c736f0..711e28cd 100644 --- a/tests/utils/test_pytorch_utils.py +++ b/tests/utils/test_pytorch_utils.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from simplexity.utils.pytorch_utils import jax_to_torch, torch_to_jax +from simplexity.utils.pytorch_utils import jax_to_torch, resolve_device, torch_to_jax try: import torch @@ -40,3 +40,72 @@ def test_torch_to_jax(device: str): assert jax_array.shape == (2, 2) assert jax_array.dtype == jnp.float32 np.testing.assert_array_equal(jax_array, torch_tensor.cpu().numpy()) + + +class TestResolveDevice: + """Test resolve_device function.""" + + def test_auto_mode_returns_valid_device(self): + """Test auto mode returns a valid PyTorch device string.""" + device = resolve_device("auto") + assert device in ("cuda", "mps", "cpu") + + def test_none_treated_as_auto(self): + """Test None is treated as auto mode.""" + device = resolve_device(None) + assert device in ("cuda", "mps", "cpu") + + def test_cpu_always_available(self): + """Test CPU is always available.""" + device = resolve_device("cpu") + assert device == "cpu" + + def test_cuda_when_available(self): + """Test CUDA request when CUDA is available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + device = resolve_device("cuda") + assert device == "cuda" + + def test_cuda_unavailable_raises_runtime_error(self): + """Test CUDA request raises RuntimeError when CUDA unavailable.""" + if torch.cuda.is_available(): + pytest.skip("CUDA is available, cannot test unavailable case") + with pytest.raises(RuntimeError, match="CUDA requested but CUDA is not available"): + resolve_device("cuda") + + def test_mps_when_available(self): + """Test MPS request when MPS is available.""" + if not torch.backends.mps.is_available(): + pytest.skip("MPS not available") + device = resolve_device("mps") + assert device == "mps" + + def test_mps_unavailable_raises_runtime_error(self): + """Test MPS request raises RuntimeError when MPS unavailable.""" + if torch.backends.mps.is_available(): + pytest.skip("MPS is available, cannot test unavailable case") + with pytest.raises(RuntimeError, match="MPS requested but MPS is not available"): + resolve_device("mps") + + def test_invalid_spec_raises_value_error(self): + """Test invalid device spec raises ValueError.""" + with pytest.raises(ValueError, match="Unknown device specification"): + resolve_device("invalid_device") + + def test_unknown_specs_raise_value_error(self): + """Test various unknown specs raise ValueError.""" + invalid_specs = ["gpu", "cuda:0", "cuda:1", "tpu", "unknown"] + for spec in invalid_specs: + with pytest.raises(ValueError, match="Unknown device specification"): + resolve_device(spec) + + def test_auto_mode_priority_order(self): + """Test auto mode follows CUDA -> MPS -> CPU priority.""" + device = resolve_device("auto") + if torch.cuda.is_available(): + assert device == "cuda" + elif torch.backends.mps.is_available(): + assert device == "mps" + else: + assert device == "cpu" From bf2a3a6d140a3c719bf4c64a8c101550e72d5a17 Mon Sep 17 00:00:00 2001 From: adamimos Date: Wed, 1 Oct 2025 11:36:25 -0700 Subject: [PATCH 06/17] Apply ruff formatting to test_config_resolution.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- tests/utils/test_config_resolution.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/utils/test_config_resolution.py b/tests/utils/test_config_resolution.py index 2d66ce8e..73e3bfc3 100644 --- a/tests/utils/test_config_resolution.py +++ b/tests/utils/test_config_resolution.py @@ -130,9 +130,7 @@ def test_with_both_special_tokens(self): (50257, True, True, 50259), ], ) - def test_parametrized_cases( - self, generator_vocab_size: int, use_bos: bool, use_eos: bool, expected: int - ): + def test_parametrized_cases(self, generator_vocab_size: int, use_bos: bool, use_eos: bool, expected: int): """Test various combinations of vocab size and special tokens.""" assert compute_model_vocab_size(generator_vocab_size, use_bos, use_eos) == expected From 5aa788756d6cc755b3002501e3e656cd2f37d58e Mon Sep 17 00:00:00 2001 From: adamimos Date: Thu, 2 Oct 2025 16:13:00 -0700 Subject: [PATCH 07/17] Address PR review feedback from ealt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Inline compute_step_width() into format_step_number() - Add input validation to format_step_number() (assert 0 <= step <= max_steps) - Remove unused step_*.pt and step-*.pt naming patterns from parse_checkpoint_step() - Add get_checkpoint_path() utility function and update local_pytorch_persister to use it - Add use_eos parameter to compute_generator_sequence_length() and compute_model_context_length() - Remove Windows and S3 tests from test_utils.py - Remove TestComputeStepWidth class (function inlined) - Add comprehensive tests for use_eos parameter in config resolution tests All tests pass (100/100), ruff formatting/linting passes, pyright type checking passes. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../persistence/local_pytorch_persister.py | 8 +- simplexity/persistence/utils.py | 62 +++++----- simplexity/utils/config_resolution.py | 32 ++--- tests/persistence/test_utils.py | 110 ++++-------------- tests/utils/test_config_resolution.py | 79 +++++++++---- 5 files changed, 130 insertions(+), 161 deletions(-) diff --git a/simplexity/persistence/local_pytorch_persister.py b/simplexity/persistence/local_pytorch_persister.py index 7696406d..4ad3618d 100644 --- a/simplexity/persistence/local_pytorch_persister.py +++ b/simplexity/persistence/local_pytorch_persister.py @@ -1,6 +1,7 @@ from pathlib import Path from simplexity.persistence.local_persister import LocalPersister +from simplexity.persistence.utils import get_checkpoint_path try: import torch @@ -20,7 +21,7 @@ def __init__(self, directory: str | Path, filename: str = "model.pt"): # TODO: This is a hack to get the type checker to work. def save_weights(self, model: torch.nn.Module, step: int = 0, overwrite_existing: bool = False) -> None: # type: ignore """Saves a PyTorch model to the local filesystem.""" - path = self._get_path(step) + path = get_checkpoint_path(self.directory, step, self.filename) path.parent.mkdir(parents=True, exist_ok=True) if overwrite_existing and path.exists(): @@ -31,11 +32,8 @@ def save_weights(self, model: torch.nn.Module, step: int = 0, overwrite_existing # TODO: This is a hack to get the type checker to work. def load_weights(self, model: torch.nn.Module, step: int = 0) -> torch.nn.Module: # type: ignore """Loads weights into a PyTorch model from the local filesystem.""" - path = self._get_path(step) + path = get_checkpoint_path(self.directory, step, self.filename) device = next(model.parameters()).device if list(model.parameters()) else "cpu" state_dict = torch.load(path, map_location=device) model.load_state_dict(state_dict) return model - - def _get_path(self, step: int) -> Path: - return self.directory / str(step) / self.filename diff --git a/simplexity/persistence/utils.py b/simplexity/persistence/utils.py index 15d9c15c..a698c6ce 100644 --- a/simplexity/persistence/utils.py +++ b/simplexity/persistence/utils.py @@ -1,13 +1,30 @@ -import re +from pathlib import Path + + +def get_checkpoint_path(directory: Path, step: int, filename: str = "model.pt") -> Path: + """Construct checkpoint path following the standard naming convention. + + Args: + directory: Base directory for checkpoints + step: Training step number + filename: Checkpoint filename (default: "model.pt") + + Returns: + Path to checkpoint file: {directory}/{step}/{filename} + + Examples: + >>> get_checkpoint_path(Path("checkpoints"), 12345) + PosixPath('checkpoints/12345/model.pt') + >>> get_checkpoint_path(Path("weights"), 100, "state.pt") + PosixPath('weights/100/state.pt') + """ + return directory / str(step) / filename def parse_checkpoint_step(path: str) -> int | None: """Extract training step number from checkpoint path. - Handles multiple formats: - - step_12345.pt / step-12345.pt - - 12345/model.pt - - model_weights/step_00012345.pt + Handles the format: {step}/model.pt or {step}/{filename} Args: path: File path or S3 key containing checkpoint @@ -16,19 +33,13 @@ def parse_checkpoint_step(path: str) -> int | None: Step number if found, None otherwise Examples: - >>> parse_checkpoint_step("model_weights/step_12345.pt") - 12345 >>> parse_checkpoint_step("checkpoints/12345/model.pt") 12345 - >>> parse_checkpoint_step("step-00500.pt") - 500 + >>> parse_checkpoint_step("12345/model.pt") + 12345 """ - m = re.search(r"step[_-]?(\d+)\.pt$", path) - if m: - return int(m.group(1)) - parts = path.split("/") - if parts and parts[-1] == "model.pt" and len(parts) >= 2: + if len(parts) >= 2 and parts[-1].endswith(".pt"): try: return int(parts[-2]) except ValueError: @@ -37,26 +48,6 @@ def parse_checkpoint_step(path: str) -> int | None: return None -def compute_step_width(max_steps: int) -> int: - """Compute zero-padding width for step numbers. - - Ensures lexicographic sorting matches chronological order. - - Args: - max_steps: Maximum number of training steps - - Returns: - Number of digits to use for zero-padding - - Examples: - >>> compute_step_width(999) - 3 - >>> compute_step_width(100000) - 6 - """ - return len(str(max_steps)) - - def format_step_number(step: int, max_steps: int) -> str: """Format step number with appropriate zero-padding. @@ -73,5 +64,6 @@ def format_step_number(step: int, max_steps: int) -> str: >>> format_step_number(999, max_steps=999) '999' """ - width = compute_step_width(max_steps) + assert 0 <= step <= max_steps, f"Step {step} must be between 0 and {max_steps}" + width = len(str(max_steps)) return f"{step:0{width}d}" diff --git a/simplexity/utils/config_resolution.py b/simplexity/utils/config_resolution.py index c6755af8..33e64ce9 100644 --- a/simplexity/utils/config_resolution.py +++ b/simplexity/utils/config_resolution.py @@ -1,45 +1,51 @@ -def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool) -> int: - """Compute the generator's sequence length from model context length and BOS usage. +def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool, use_eos: bool = False) -> int: + """Compute the generator's sequence length from model context length and special token usage. - The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS - Solving for generator_seq_len: generator_seq_len = model_n_ctx + 1 - BOS + Solving for generator_seq_len: generator_seq_len = model_n_ctx + 1 - BOS - EOS Args: model_n_ctx: The model's context length (number of input positions it processes) use_bos: Whether a beginning-of-sequence token is prepended during data generation + use_eos: Whether an end-of-sequence token is appended during data generation Returns: The sequence length to configure for the data generator Examples: - >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True) + >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=False) 512 - >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=False) + >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=False, use_eos=False) 513 + >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=True) + 511 """ - return model_n_ctx + 1 - int(use_bos) + return model_n_ctx + 1 - int(use_bos) - int(use_eos) -def compute_model_context_length(generator_seq_len: int, use_bos: bool) -> int: - """Compute the model's context length from generator sequence length and BOS usage. +def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: bool = False) -> int: + """Compute the model's context length from generator sequence length and special token usage. - The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS Args: generator_seq_len: The sequence length configured for the data generator use_bos: Whether a beginning-of-sequence token is prepended during data generation + use_eos: Whether an end-of-sequence token is appended during data generation Returns: The context length for the model (number of input positions it will process) Examples: - >>> compute_model_context_length(generator_seq_len=512, use_bos=True) + >>> compute_model_context_length(generator_seq_len=512, use_bos=True, use_eos=False) 512 - >>> compute_model_context_length(generator_seq_len=513, use_bos=False) + >>> compute_model_context_length(generator_seq_len=513, use_bos=False, use_eos=False) + 512 + >>> compute_model_context_length(generator_seq_len=511, use_bos=True, use_eos=True) 512 """ - return generator_seq_len - 1 + int(use_bos) + return generator_seq_len - 1 + int(use_bos) + int(use_eos) def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: bool) -> int: diff --git a/tests/persistence/test_utils.py b/tests/persistence/test_utils.py index eecfd9c3..72933c76 100644 --- a/tests/persistence/test_utils.py +++ b/tests/persistence/test_utils.py @@ -1,37 +1,13 @@ +from pathlib import Path + import pytest -from simplexity.persistence.utils import compute_step_width, format_step_number, parse_checkpoint_step +from simplexity.persistence.utils import format_step_number, get_checkpoint_path, parse_checkpoint_step class TestParseCheckpointStep: """Test parse_checkpoint_step function.""" - @pytest.mark.parametrize( - ("path", "expected"), - [ - ("model_weights/step_12345.pt", 12345), - ("step_12345.pt", 12345), - ("step_00012345.pt", 12345), - ("checkpoints/step_500.pt", 500), - ("path/to/step_999.pt", 999), - ], - ) - def test_step_underscore_format(self, path: str, expected: int): - """Test parsing step_XXXX.pt format.""" - assert parse_checkpoint_step(path) == expected - - @pytest.mark.parametrize( - ("path", "expected"), - [ - ("step-12345.pt", 12345), - ("step-00500.pt", 500), - ("model_weights/step-999.pt", 999), - ], - ) - def test_step_hyphen_format(self, path: str, expected: int): - """Test parsing step-XXXX.pt format.""" - assert parse_checkpoint_step(path) == expected - @pytest.mark.parametrize( ("path", "expected"), [ @@ -39,10 +15,11 @@ def test_step_hyphen_format(self, path: str, expected: int): ("checkpoints/12345/model.pt", 12345), ("path/to/500/model.pt", 500), ("0/model.pt", 0), + ("prefix/run_name/12345/model.pt", 12345), ], ) def test_directory_model_format(self, path: str, expected: int): - """Test parsing XXXX/model.pt format.""" + """Test parsing {step}/model.pt format.""" assert parse_checkpoint_step(path) == expected @pytest.mark.parametrize( @@ -50,10 +27,8 @@ def test_directory_model_format(self, path: str, expected: int): [ "model.pt", "checkpoint.pt", - "step.pt", "weights/model.eqx", "random_file.txt", - "step_abc.pt", "nonumeric/model.pt", ], ) @@ -63,68 +38,33 @@ def test_no_match_returns_none(self, path: str): def test_zero_padded_step_numbers(self): """Test that zero-padded step numbers are correctly parsed.""" - assert parse_checkpoint_step("step_00042.pt") == 42 - assert parse_checkpoint_step("step_00000.pt") == 0 assert parse_checkpoint_step("0000/model.pt") == 0 - def test_step_pattern_takes_precedence_over_directory(self): - """Test that step_*.pt pattern takes precedence over directory pattern.""" - assert parse_checkpoint_step("path/step_200.pt") == 200 - assert parse_checkpoint_step("checkpoints/step_999.pt") == 999 - def test_windows_paths(self): - """Test Windows-style paths with backslashes.""" - path_unix = "checkpoints/12345/model.pt" - assert parse_checkpoint_step(path_unix) == 12345 +class TestGetCheckpointPath: + """Test get_checkpoint_path function.""" - def test_s3_style_keys(self): - """Test S3 object key formats.""" - assert parse_checkpoint_step("s3://bucket/prefix/step_12345.pt") == 12345 - assert parse_checkpoint_step("prefix/run_name/12345/model.pt") == 12345 + def test_basic_path_construction(self): + """Test basic checkpoint path construction.""" + path = get_checkpoint_path(Path("checkpoints"), 12345) + assert path == Path("checkpoints/12345/model.pt") - -class TestComputeStepWidth: - """Test compute_step_width function.""" - - def test_single_digit(self): - """Test with max_steps requiring 1 digit.""" - assert compute_step_width(9) == 1 - assert compute_step_width(1) == 1 - - def test_two_digits(self): - """Test with max_steps requiring 2 digits.""" - assert compute_step_width(10) == 2 - assert compute_step_width(99) == 2 - - def test_three_digits(self): - """Test with max_steps requiring 3 digits.""" - assert compute_step_width(100) == 3 - assert compute_step_width(999) == 3 + def test_custom_filename(self): + """Test with custom filename.""" + path = get_checkpoint_path(Path("weights"), 100, "state.pt") + assert path == Path("weights/100/state.pt") @pytest.mark.parametrize( - ("max_steps", "expected_width"), + ("directory", "step", "filename", "expected"), [ - (1, 1), - (9, 1), - (10, 2), - (99, 2), - (100, 3), - (999, 3), - (1000, 4), - (9999, 4), - (10000, 5), - (99999, 5), - (100000, 6), + (Path("checkpoints"), 0, "model.pt", Path("checkpoints/0/model.pt")), + (Path("runs/exp1"), 1000, "checkpoint.pt", Path("runs/exp1/1000/checkpoint.pt")), + (Path("."), 42, "model.pt", Path("42/model.pt")), ], ) - def test_parametrized_widths(self, max_steps: int, expected_width: int): - """Test various max_steps values.""" - assert compute_step_width(max_steps) == expected_width - - def test_large_step_counts(self): - """Test with very large step counts.""" - assert compute_step_width(1000000) == 7 - assert compute_step_width(10000000) == 8 + def test_parametrized_paths(self, directory: Path, step: int, filename: str, expected: Path): + """Test various path combinations.""" + assert get_checkpoint_path(directory, step, filename) == expected class TestFormatStepNumber: @@ -168,10 +108,10 @@ def test_lexicographic_ordering(self): formatted = [format_step_number(i, max_steps) for i in [1, 10, 100, 1000, 9999]] assert formatted == sorted(formatted) - def test_consistency_with_compute_step_width(self): - """Verify format_step_number uses compute_step_width correctly.""" + def test_width_computation(self): + """Verify format_step_number computes width correctly.""" max_steps = 100000 step = 42 formatted = format_step_number(step, max_steps) - expected_width = compute_step_width(max_steps) + expected_width = len(str(max_steps)) assert len(formatted) == expected_width diff --git a/tests/utils/test_config_resolution.py b/tests/utils/test_config_resolution.py index 73e3bfc3..94ac5ba5 100644 --- a/tests/utils/test_config_resolution.py +++ b/tests/utils/test_config_resolution.py @@ -20,20 +20,36 @@ def test_without_bos_token(self): assert compute_generator_sequence_length(model_n_ctx=512, use_bos=False) == 513 assert compute_generator_sequence_length(model_n_ctx=100, use_bos=False) == 101 + def test_with_eos_token(self): + """When EOS is used, generator_seq_len should be model_n_ctx.""" + assert compute_generator_sequence_length(model_n_ctx=512, use_bos=False, use_eos=True) == 512 + assert compute_generator_sequence_length(model_n_ctx=100, use_bos=False, use_eos=True) == 100 + + def test_with_bos_and_eos(self): + """When both BOS and EOS are used, generator_seq_len should be model_n_ctx - 1.""" + assert compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=True) == 511 + assert compute_generator_sequence_length(model_n_ctx=100, use_bos=True, use_eos=True) == 99 + @pytest.mark.parametrize( - ("model_n_ctx", "use_bos", "expected"), + ("model_n_ctx", "use_bos", "use_eos", "expected"), [ - (1, True, 1), - (1, False, 2), - (64, True, 64), - (64, False, 65), - (1024, True, 1024), - (1024, False, 1025), + (1, True, False, 1), + (1, False, False, 2), + (1, False, True, 1), + (1, True, True, 0), + (64, True, False, 64), + (64, False, False, 65), + (64, False, True, 64), + (64, True, True, 63), + (1024, True, False, 1024), + (1024, False, False, 1025), + (1024, False, True, 1024), + (1024, True, True, 1023), ], ) - def test_parametrized_cases(self, model_n_ctx: int, use_bos: bool, expected: int): - """Test various combinations of model_n_ctx and use_bos.""" - assert compute_generator_sequence_length(model_n_ctx, use_bos) == expected + def test_parametrized_cases(self, model_n_ctx: int, use_bos: bool, use_eos: bool, expected: int): + """Test various combinations of model_n_ctx, use_bos, and use_eos.""" + assert compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) == expected def test_zero_context_with_bos(self): """Edge case: zero context length with BOS.""" @@ -57,20 +73,36 @@ def test_without_bos_token(self): assert compute_model_context_length(generator_seq_len=513, use_bos=False) == 512 assert compute_model_context_length(generator_seq_len=101, use_bos=False) == 100 + def test_with_eos_token(self): + """When EOS is used, model_n_ctx should be generator_seq_len.""" + assert compute_model_context_length(generator_seq_len=512, use_bos=False, use_eos=True) == 512 + assert compute_model_context_length(generator_seq_len=100, use_bos=False, use_eos=True) == 100 + + def test_with_bos_and_eos(self): + """When both BOS and EOS are used, model_n_ctx should be generator_seq_len + 1.""" + assert compute_model_context_length(generator_seq_len=511, use_bos=True, use_eos=True) == 512 + assert compute_model_context_length(generator_seq_len=99, use_bos=True, use_eos=True) == 100 + @pytest.mark.parametrize( - ("generator_seq_len", "use_bos", "expected"), + ("generator_seq_len", "use_bos", "use_eos", "expected"), [ - (1, True, 1), - (2, False, 1), - (64, True, 64), - (65, False, 64), - (1024, True, 1024), - (1025, False, 1024), + (1, True, False, 1), + (2, False, False, 1), + (1, False, True, 1), + (0, True, True, 1), + (64, True, False, 64), + (65, False, False, 64), + (64, False, True, 64), + (63, True, True, 64), + (1024, True, False, 1024), + (1025, False, False, 1024), + (1024, False, True, 1024), + (1023, True, True, 1024), ], ) - def test_parametrized_cases(self, generator_seq_len: int, use_bos: bool, expected: int): - """Test various combinations of generator_seq_len and use_bos.""" - assert compute_model_context_length(generator_seq_len, use_bos) == expected + def test_parametrized_cases(self, generator_seq_len: int, use_bos: bool, use_eos: bool, expected: int): + """Test various combinations of generator_seq_len, use_bos, and use_eos.""" + assert compute_model_context_length(generator_seq_len, use_bos, use_eos) == expected def test_inverse_relationship_with_bos(self): """Verify inverse relationship with compute_generator_sequence_length when using BOS.""" @@ -90,10 +122,11 @@ def test_inverse_relationship_without_bos(self): @pytest.mark.parametrize("model_n_ctx", [1, 64, 128, 512, 1024]) @pytest.mark.parametrize("use_bos", [True, False]) - def test_round_trip_consistency(self, model_n_ctx: int, use_bos: bool): + @pytest.mark.parametrize("use_eos", [True, False]) + def test_round_trip_consistency(self, model_n_ctx: int, use_bos: bool, use_eos: bool): """Verify round-trip conversion maintains original value.""" - gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos) - recovered = compute_model_context_length(gen_seq_len, use_bos) + gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) + recovered = compute_model_context_length(gen_seq_len, use_bos, use_eos) assert recovered == model_n_ctx From dd124903dce226f9c6405583dd27e0510dc4b7fc Mon Sep 17 00:00:00 2001 From: adamimos Date: Thu, 2 Oct 2025 17:17:55 -0700 Subject: [PATCH 08/17] Add default use_eos=False to compute_model_vocab_size for API consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All three config resolution functions now have use_eos: bool = False as default parameter. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- simplexity/utils/config_resolution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simplexity/utils/config_resolution.py b/simplexity/utils/config_resolution.py index 33e64ce9..792621ac 100644 --- a/simplexity/utils/config_resolution.py +++ b/simplexity/utils/config_resolution.py @@ -48,7 +48,7 @@ def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: return generator_seq_len - 1 + int(use_bos) + int(use_eos) -def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: bool) -> int: +def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: bool = False) -> int: """Compute the model's vocabulary size from generator vocab and special tokens. When BOS or EOS tokens are used during data generation, they are added to the vocabulary, From 4ede1d1ee2e9d163184ccad23385ef34f09746c2 Mon Sep 17 00:00:00 2001 From: adamimos Date: Thu, 2 Oct 2025 17:29:37 -0700 Subject: [PATCH 09/17] Add comprehensive input validation to prevent production issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit HIGH PRIORITY fixes: - Replace assert with if/raise ValueError in format_step_number (asserts removed with -O flag) - Add validation to get_checkpoint_path for negative step values - Add validation to all config_resolution functions for invalid inputs MEDIUM PRIORITY fixes: - Prevent non-positive sequence lengths in compute_generator_sequence_length - Prevent non-positive context lengths in compute_model_context_length - Prevent non-positive vocab sizes in compute_model_vocab_size Test updates: - Add comprehensive error testing for all validation cases - Update round_trip_consistency test to skip invalid configurations - Remove edge case tests that now correctly raise errors All 101 tests pass (1 skipped), ruff and pyright pass. Addresses automated review feedback on production code safety. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- simplexity/persistence/utils.py | 13 +++++++++-- simplexity/utils/config_resolution.py | 33 +++++++++++++++++++++++++-- tests/persistence/test_utils.py | 12 ++++++++++ tests/utils/test_config_resolution.py | 30 ++++++++++++++++++++---- 4 files changed, 79 insertions(+), 9 deletions(-) diff --git a/simplexity/persistence/utils.py b/simplexity/persistence/utils.py index a698c6ce..2b9e7cc5 100644 --- a/simplexity/persistence/utils.py +++ b/simplexity/persistence/utils.py @@ -6,18 +6,23 @@ def get_checkpoint_path(directory: Path, step: int, filename: str = "model.pt") Args: directory: Base directory for checkpoints - step: Training step number + step: Training step number (must be non-negative) filename: Checkpoint filename (default: "model.pt") Returns: Path to checkpoint file: {directory}/{step}/{filename} + Raises: + ValueError: If step is negative + Examples: >>> get_checkpoint_path(Path("checkpoints"), 12345) PosixPath('checkpoints/12345/model.pt') >>> get_checkpoint_path(Path("weights"), 100, "state.pt") PosixPath('weights/100/state.pt') """ + if step < 0: + raise ValueError(f"Step must be non-negative, got {step}") return directory / str(step) / filename @@ -58,12 +63,16 @@ def format_step_number(step: int, max_steps: int) -> str: Returns: Zero-padded step string + Raises: + ValueError: If step is not between 0 and max_steps + Examples: >>> format_step_number(42, max_steps=100000) '000042' >>> format_step_number(999, max_steps=999) '999' """ - assert 0 <= step <= max_steps, f"Step {step} must be between 0 and {max_steps}" + if not 0 <= step <= max_steps: + raise ValueError(f"Step {step} must be between 0 and {max_steps}") width = len(str(max_steps)) return f"{step:0{width}d}" diff --git a/simplexity/utils/config_resolution.py b/simplexity/utils/config_resolution.py index 792621ac..fc54d9dc 100644 --- a/simplexity/utils/config_resolution.py +++ b/simplexity/utils/config_resolution.py @@ -13,6 +13,9 @@ def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool, use_eos: Returns: The sequence length to configure for the data generator + Raises: + ValueError: If the resulting generator sequence length would be non-positive + Examples: >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=False) 512 @@ -21,7 +24,16 @@ def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool, use_eos: >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=True) 511 """ - return model_n_ctx + 1 - int(use_bos) - int(use_eos) + if model_n_ctx < 0: + raise ValueError(f"model_n_ctx must be non-negative, got {model_n_ctx}") + + result = model_n_ctx + 1 - int(use_bos) - int(use_eos) + if result <= 0: + raise ValueError( + f"Invalid configuration: model_n_ctx={model_n_ctx}, use_bos={use_bos}, use_eos={use_eos} " + f"results in non-positive generator sequence length ({result})" + ) + return result def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: bool = False) -> int: @@ -37,6 +49,9 @@ def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: Returns: The context length for the model (number of input positions it will process) + Raises: + ValueError: If the resulting model context length would be non-positive + Examples: >>> compute_model_context_length(generator_seq_len=512, use_bos=True, use_eos=False) 512 @@ -45,7 +60,16 @@ def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: >>> compute_model_context_length(generator_seq_len=511, use_bos=True, use_eos=True) 512 """ - return generator_seq_len - 1 + int(use_bos) + int(use_eos) + if generator_seq_len <= 0: + raise ValueError(f"generator_seq_len must be positive, got {generator_seq_len}") + + result = generator_seq_len - 1 + int(use_bos) + int(use_eos) + if result <= 0: + raise ValueError( + f"Invalid configuration: generator_seq_len={generator_seq_len}, use_bos={use_bos}, use_eos={use_eos} " + f"results in non-positive model context length ({result})" + ) + return result def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: bool = False) -> int: @@ -62,6 +86,9 @@ def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: Returns: The vocabulary size the model should be configured with + Raises: + ValueError: If generator_vocab_size is non-positive + Examples: >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=False) 101 @@ -70,4 +97,6 @@ def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=False) 100 """ + if generator_vocab_size <= 0: + raise ValueError(f"generator_vocab_size must be positive, got {generator_vocab_size}") return generator_vocab_size + int(use_bos) + int(use_eos) diff --git a/tests/persistence/test_utils.py b/tests/persistence/test_utils.py index 72933c76..d2b35d06 100644 --- a/tests/persistence/test_utils.py +++ b/tests/persistence/test_utils.py @@ -66,6 +66,11 @@ def test_parametrized_paths(self, directory: Path, step: int, filename: str, exp """Test various path combinations.""" assert get_checkpoint_path(directory, step, filename) == expected + def test_negative_step_raises_error(self): + """Test that negative step values raise ValueError.""" + with pytest.raises(ValueError, match="must be non-negative"): + get_checkpoint_path(Path("checkpoints"), -1) + class TestFormatStepNumber: """Test format_step_number function.""" @@ -115,3 +120,10 @@ def test_width_computation(self): formatted = format_step_number(step, max_steps) expected_width = len(str(max_steps)) assert len(formatted) == expected_width + + def test_invalid_step_raises_error(self): + """Test that invalid step values raise ValueError.""" + with pytest.raises(ValueError, match="must be between 0 and"): + format_step_number(-1, max_steps=100) + with pytest.raises(ValueError, match="must be between 0 and"): + format_step_number(101, max_steps=100) diff --git a/tests/utils/test_config_resolution.py b/tests/utils/test_config_resolution.py index 94ac5ba5..f393498a 100644 --- a/tests/utils/test_config_resolution.py +++ b/tests/utils/test_config_resolution.py @@ -36,7 +36,6 @@ def test_with_bos_and_eos(self): (1, True, False, 1), (1, False, False, 2), (1, False, True, 1), - (1, True, True, 0), (64, True, False, 64), (64, False, False, 65), (64, False, True, 64), @@ -51,9 +50,14 @@ def test_parametrized_cases(self, model_n_ctx: int, use_bos: bool, use_eos: bool """Test various combinations of model_n_ctx, use_bos, and use_eos.""" assert compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) == expected - def test_zero_context_with_bos(self): - """Edge case: zero context length with BOS.""" - assert compute_generator_sequence_length(model_n_ctx=0, use_bos=True) == 0 + def test_invalid_configuration_raises_error(self): + """Test that invalid configurations raise ValueError.""" + with pytest.raises(ValueError, match="non-positive generator sequence length"): + compute_generator_sequence_length(model_n_ctx=1, use_bos=True, use_eos=True) + with pytest.raises(ValueError, match="non-positive generator sequence length"): + compute_generator_sequence_length(model_n_ctx=0, use_bos=True, use_eos=False) + with pytest.raises(ValueError, match="non-negative"): + compute_generator_sequence_length(model_n_ctx=-1, use_bos=False, use_eos=False) def test_zero_context_without_bos(self): """Edge case: zero context length without BOS.""" @@ -89,7 +93,6 @@ def test_with_bos_and_eos(self): (1, True, False, 1), (2, False, False, 1), (1, False, True, 1), - (0, True, True, 1), (64, True, False, 64), (65, False, False, 64), (64, False, True, 64), @@ -104,6 +107,13 @@ def test_parametrized_cases(self, generator_seq_len: int, use_bos: bool, use_eos """Test various combinations of generator_seq_len, use_bos, and use_eos.""" assert compute_model_context_length(generator_seq_len, use_bos, use_eos) == expected + def test_invalid_inputs_raise_error(self): + """Test that invalid inputs raise ValueError.""" + with pytest.raises(ValueError, match="must be positive"): + compute_model_context_length(generator_seq_len=0, use_bos=True, use_eos=True) + with pytest.raises(ValueError, match="must be positive"): + compute_model_context_length(generator_seq_len=-1, use_bos=False, use_eos=False) + def test_inverse_relationship_with_bos(self): """Verify inverse relationship with compute_generator_sequence_length when using BOS.""" model_n_ctx = 512 @@ -125,6 +135,9 @@ def test_inverse_relationship_without_bos(self): @pytest.mark.parametrize("use_eos", [True, False]) def test_round_trip_consistency(self, model_n_ctx: int, use_bos: bool, use_eos: bool): """Verify round-trip conversion maintains original value.""" + # Skip invalid configurations that would produce non-positive sequence lengths + if model_n_ctx == 1 and use_bos and use_eos: + pytest.skip("Configuration would produce invalid sequence length") gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) recovered = compute_model_context_length(gen_seq_len, use_bos, use_eos) assert recovered == model_n_ctx @@ -174,3 +187,10 @@ def test_minimal_vocab_with_tokens(self): def test_large_vocab(self): """Test with large vocabulary sizes.""" assert compute_model_vocab_size(generator_vocab_size=100000, use_bos=True, use_eos=True) == 100002 + + def test_invalid_vocab_size_raises_error(self): + """Test that non-positive vocab sizes raise ValueError.""" + with pytest.raises(ValueError, match="must be positive"): + compute_model_vocab_size(generator_vocab_size=0, use_bos=True, use_eos=False) + with pytest.raises(ValueError, match="must be positive"): + compute_model_vocab_size(generator_vocab_size=-1, use_bos=False, use_eos=False) From 886b7543388169cfb4db258421693dd6309be85e Mon Sep 17 00:00:00 2001 From: adamimos Date: Thu, 2 Oct 2025 17:41:22 -0700 Subject: [PATCH 10/17] Address all PR review feedback from ealt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace ValueError validation with assertions in config resolution - Consolidate test cases using parametrize - Add test coverage for different filenames and zero-padding - Remove redundant test cases - Reduce total test count from ~100 to 42 while maintaining coverage 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- simplexity/utils/config_resolution.py | 9 +- tests/persistence/test_utils.py | 52 +++------- tests/utils/test_config_resolution.py | 142 +++++++++----------------- tests/utils/test_jnp.py | 7 -- tests/utils/test_pytorch_utils.py | 17 --- 5 files changed, 66 insertions(+), 161 deletions(-) diff --git a/simplexity/utils/config_resolution.py b/simplexity/utils/config_resolution.py index fc54d9dc..ae868c9a 100644 --- a/simplexity/utils/config_resolution.py +++ b/simplexity/utils/config_resolution.py @@ -24,8 +24,7 @@ def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool, use_eos: >>> compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=True) 511 """ - if model_n_ctx < 0: - raise ValueError(f"model_n_ctx must be non-negative, got {model_n_ctx}") + assert model_n_ctx > 0, f"model_n_ctx must be positive, got {model_n_ctx}" result = model_n_ctx + 1 - int(use_bos) - int(use_eos) if result <= 0: @@ -60,8 +59,7 @@ def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: >>> compute_model_context_length(generator_seq_len=511, use_bos=True, use_eos=True) 512 """ - if generator_seq_len <= 0: - raise ValueError(f"generator_seq_len must be positive, got {generator_seq_len}") + assert generator_seq_len > 0, f"generator_seq_len must be positive, got {generator_seq_len}" result = generator_seq_len - 1 + int(use_bos) + int(use_eos) if result <= 0: @@ -97,6 +95,5 @@ def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: >>> compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=False) 100 """ - if generator_vocab_size <= 0: - raise ValueError(f"generator_vocab_size must be positive, got {generator_vocab_size}") + assert generator_vocab_size > 0, f"generator_vocab_size must be positive, got {generator_vocab_size}" return generator_vocab_size + int(use_bos) + int(use_eos) diff --git a/tests/persistence/test_utils.py b/tests/persistence/test_utils.py index d2b35d06..d907e03c 100644 --- a/tests/persistence/test_utils.py +++ b/tests/persistence/test_utils.py @@ -16,10 +16,14 @@ class TestParseCheckpointStep: ("path/to/500/model.pt", 500), ("0/model.pt", 0), ("prefix/run_name/12345/model.pt", 12345), + ("0000/model.pt", 0), + ("12345/checkpoint.pt", 12345), + ("12345/state.pt", 12345), + ("12345/weights.eqx", 12345), ], ) def test_directory_model_format(self, path: str, expected: int): - """Test parsing {step}/model.pt format.""" + """Test parsing {step}/filename format with various filenames.""" assert parse_checkpoint_step(path) == expected @pytest.mark.parametrize( @@ -30,40 +34,31 @@ def test_directory_model_format(self, path: str, expected: int): "weights/model.eqx", "random_file.txt", "nonumeric/model.pt", + "abc123/model.pt", + "123abc/model.pt", ], ) def test_no_match_returns_none(self, path: str): - """Test paths that should not match any pattern.""" + """Test paths with numbers but invalid format return None.""" assert parse_checkpoint_step(path) is None - def test_zero_padded_step_numbers(self): - """Test that zero-padded step numbers are correctly parsed.""" - assert parse_checkpoint_step("0000/model.pt") == 0 - class TestGetCheckpointPath: """Test get_checkpoint_path function.""" - def test_basic_path_construction(self): - """Test basic checkpoint path construction.""" - path = get_checkpoint_path(Path("checkpoints"), 12345) - assert path == Path("checkpoints/12345/model.pt") - - def test_custom_filename(self): - """Test with custom filename.""" - path = get_checkpoint_path(Path("weights"), 100, "state.pt") - assert path == Path("weights/100/state.pt") - @pytest.mark.parametrize( ("directory", "step", "filename", "expected"), [ (Path("checkpoints"), 0, "model.pt", Path("checkpoints/0/model.pt")), + (Path("checkpoints"), 12345, "model.pt", Path("checkpoints/12345/model.pt")), (Path("runs/exp1"), 1000, "checkpoint.pt", Path("runs/exp1/1000/checkpoint.pt")), + (Path("weights"), 100, "state.pt", Path("weights/100/state.pt")), (Path("."), 42, "model.pt", Path("42/model.pt")), + (Path("checkpoints"), 99999, "model.pt", Path("checkpoints/99999/model.pt")), ], ) def test_parametrized_paths(self, directory: Path, step: int, filename: str, expected: Path): - """Test various path combinations.""" + """Test various path combinations including custom filenames and zero-padding.""" assert get_checkpoint_path(directory, step, filename) == expected def test_negative_step_raises_error(self): @@ -75,21 +70,6 @@ def test_negative_step_raises_error(self): class TestFormatStepNumber: """Test format_step_number function.""" - def test_basic_formatting(self): - """Test basic zero-padding behavior.""" - assert format_step_number(42, max_steps=100) == "042" - assert format_step_number(5, max_steps=1000) == "0005" - - def test_no_padding_needed(self): - """Test when step already has maximum width.""" - assert format_step_number(999, max_steps=999) == "999" - assert format_step_number(100, max_steps=100) == "100" - - def test_zero_step(self): - """Test formatting step 0.""" - assert format_step_number(0, max_steps=100) == "000" - assert format_step_number(0, max_steps=10000) == "00000" - @pytest.mark.parametrize( ("step", "max_steps", "expected"), [ @@ -113,14 +93,6 @@ def test_lexicographic_ordering(self): formatted = [format_step_number(i, max_steps) for i in [1, 10, 100, 1000, 9999]] assert formatted == sorted(formatted) - def test_width_computation(self): - """Verify format_step_number computes width correctly.""" - max_steps = 100000 - step = 42 - formatted = format_step_number(step, max_steps) - expected_width = len(str(max_steps)) - assert len(formatted) == expected_width - def test_invalid_step_raises_error(self): """Test that invalid step values raise ValueError.""" with pytest.raises(ValueError, match="must be between 0 and"): diff --git a/tests/utils/test_config_resolution.py b/tests/utils/test_config_resolution.py index f393498a..07554b80 100644 --- a/tests/utils/test_config_resolution.py +++ b/tests/utils/test_config_resolution.py @@ -10,25 +10,18 @@ class TestComputeGeneratorSequenceLength: """Test compute_generator_sequence_length function.""" - def test_with_bos_token(self): - """When BOS is used, generator_seq_len should equal model_n_ctx.""" - assert compute_generator_sequence_length(model_n_ctx=512, use_bos=True) == 512 - assert compute_generator_sequence_length(model_n_ctx=100, use_bos=True) == 100 - - def test_without_bos_token(self): - """When BOS is not used, generator_seq_len should be model_n_ctx + 1.""" - assert compute_generator_sequence_length(model_n_ctx=512, use_bos=False) == 513 - assert compute_generator_sequence_length(model_n_ctx=100, use_bos=False) == 101 - - def test_with_eos_token(self): - """When EOS is used, generator_seq_len should be model_n_ctx.""" - assert compute_generator_sequence_length(model_n_ctx=512, use_bos=False, use_eos=True) == 512 - assert compute_generator_sequence_length(model_n_ctx=100, use_bos=False, use_eos=True) == 100 - - def test_with_bos_and_eos(self): - """When both BOS and EOS are used, generator_seq_len should be model_n_ctx - 1.""" - assert compute_generator_sequence_length(model_n_ctx=512, use_bos=True, use_eos=True) == 511 - assert compute_generator_sequence_length(model_n_ctx=100, use_bos=True, use_eos=True) == 99 + @pytest.mark.parametrize( + ("model_n_ctx", "use_bos", "use_eos", "expected"), + [ + (512, False, False, 513), + (512, True, False, 512), + (512, False, True, 512), + (512, True, True, 511), + ], + ) + def test_bos_eos_combinations(self, model_n_ctx: int, use_bos: bool, use_eos: bool, expected: int): + """Test all combinations of BOS and EOS tokens with same model_n_ctx.""" + assert compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) == expected @pytest.mark.parametrize( ("model_n_ctx", "use_bos", "use_eos", "expected"), @@ -54,38 +47,25 @@ def test_invalid_configuration_raises_error(self): """Test that invalid configurations raise ValueError.""" with pytest.raises(ValueError, match="non-positive generator sequence length"): compute_generator_sequence_length(model_n_ctx=1, use_bos=True, use_eos=True) - with pytest.raises(ValueError, match="non-positive generator sequence length"): + with pytest.raises(AssertionError, match="must be positive"): compute_generator_sequence_length(model_n_ctx=0, use_bos=True, use_eos=False) - with pytest.raises(ValueError, match="non-negative"): - compute_generator_sequence_length(model_n_ctx=-1, use_bos=False, use_eos=False) - - def test_zero_context_without_bos(self): - """Edge case: zero context length without BOS.""" - assert compute_generator_sequence_length(model_n_ctx=0, use_bos=False) == 1 class TestComputeModelContextLength: """Test compute_model_context_length function.""" - def test_with_bos_token(self): - """When BOS is used, model_n_ctx should equal generator_seq_len.""" - assert compute_model_context_length(generator_seq_len=512, use_bos=True) == 512 - assert compute_model_context_length(generator_seq_len=100, use_bos=True) == 100 - - def test_without_bos_token(self): - """When BOS is not used, model_n_ctx should be generator_seq_len - 1.""" - assert compute_model_context_length(generator_seq_len=513, use_bos=False) == 512 - assert compute_model_context_length(generator_seq_len=101, use_bos=False) == 100 - - def test_with_eos_token(self): - """When EOS is used, model_n_ctx should be generator_seq_len.""" - assert compute_model_context_length(generator_seq_len=512, use_bos=False, use_eos=True) == 512 - assert compute_model_context_length(generator_seq_len=100, use_bos=False, use_eos=True) == 100 - - def test_with_bos_and_eos(self): - """When both BOS and EOS are used, model_n_ctx should be generator_seq_len + 1.""" - assert compute_model_context_length(generator_seq_len=511, use_bos=True, use_eos=True) == 512 - assert compute_model_context_length(generator_seq_len=99, use_bos=True, use_eos=True) == 100 + @pytest.mark.parametrize( + ("generator_seq_len", "use_bos", "use_eos", "expected"), + [ + (513, False, False, 512), + (512, True, False, 512), + (512, False, True, 512), + (511, True, True, 512), + ], + ) + def test_bos_eos_combinations(self, generator_seq_len: int, use_bos: bool, use_eos: bool, expected: int): + """Test all combinations of BOS and EOS tokens with resulting model_n_ctx=512.""" + assert compute_model_context_length(generator_seq_len, use_bos, use_eos) == expected @pytest.mark.parametrize( ("generator_seq_len", "use_bos", "use_eos", "expected"), @@ -108,26 +88,24 @@ def test_parametrized_cases(self, generator_seq_len: int, use_bos: bool, use_eos assert compute_model_context_length(generator_seq_len, use_bos, use_eos) == expected def test_invalid_inputs_raise_error(self): - """Test that invalid inputs raise ValueError.""" - with pytest.raises(ValueError, match="must be positive"): + """Test that invalid inputs raise AssertionError.""" + with pytest.raises(AssertionError, match="must be positive"): compute_model_context_length(generator_seq_len=0, use_bos=True, use_eos=True) - with pytest.raises(ValueError, match="must be positive"): - compute_model_context_length(generator_seq_len=-1, use_bos=False, use_eos=False) - - def test_inverse_relationship_with_bos(self): - """Verify inverse relationship with compute_generator_sequence_length when using BOS.""" - model_n_ctx = 512 - use_bos = True - gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos) - recovered_n_ctx = compute_model_context_length(gen_seq_len, use_bos) - assert recovered_n_ctx == model_n_ctx - def test_inverse_relationship_without_bos(self): - """Verify inverse relationship with compute_generator_sequence_length without BOS.""" + @pytest.mark.parametrize( + ("use_bos", "use_eos"), + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], + ) + def test_inverse_relationship(self, use_bos: bool, use_eos: bool): + """Verify inverse relationship with compute_generator_sequence_length.""" model_n_ctx = 512 - use_bos = False - gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos) - recovered_n_ctx = compute_model_context_length(gen_seq_len, use_bos) + gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) + recovered_n_ctx = compute_model_context_length(gen_seq_len, use_bos, use_eos) assert recovered_n_ctx == model_n_ctx @pytest.mark.parametrize("model_n_ctx", [1, 64, 128, 512, 1024]) @@ -135,7 +113,6 @@ def test_inverse_relationship_without_bos(self): @pytest.mark.parametrize("use_eos", [True, False]) def test_round_trip_consistency(self, model_n_ctx: int, use_bos: bool, use_eos: bool): """Verify round-trip conversion maintains original value.""" - # Skip invalid configurations that would produce non-positive sequence lengths if model_n_ctx == 1 and use_bos and use_eos: pytest.skip("Configuration would produce invalid sequence length") gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) @@ -146,22 +123,6 @@ def test_round_trip_consistency(self, model_n_ctx: int, use_bos: bool, use_eos: class TestComputeModelVocabSize: """Test compute_model_vocab_size function.""" - def test_no_special_tokens(self): - """When no special tokens are used, vocab size should equal generator vocab.""" - assert compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=False) == 100 - - def test_with_bos_only(self): - """When only BOS is used, vocab size should be generator_vocab + 1.""" - assert compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=False) == 101 - - def test_with_eos_only(self): - """When only EOS is used, vocab size should be generator_vocab + 1.""" - assert compute_model_vocab_size(generator_vocab_size=100, use_bos=False, use_eos=True) == 101 - - def test_with_both_special_tokens(self): - """When both BOS and EOS are used, vocab size should be generator_vocab + 2.""" - assert compute_model_vocab_size(generator_vocab_size=100, use_bos=True, use_eos=True) == 102 - @pytest.mark.parametrize( ("generator_vocab_size", "use_bos", "use_eos", "expected"), [ @@ -169,6 +130,15 @@ def test_with_both_special_tokens(self): (100, True, False, 101), (100, False, True, 101), (100, True, True, 102), + ], + ) + def test_bos_eos_combinations(self, generator_vocab_size: int, use_bos: bool, use_eos: bool, expected: int): + """Test all combinations of BOS and EOS tokens with same generator_vocab_size.""" + assert compute_model_vocab_size(generator_vocab_size, use_bos, use_eos) == expected + + @pytest.mark.parametrize( + ("generator_vocab_size", "use_bos", "use_eos", "expected"), + [ (1, False, False, 1), (1, True, True, 3), (50257, False, False, 50257), @@ -180,17 +150,7 @@ def test_parametrized_cases(self, generator_vocab_size: int, use_bos: bool, use_ """Test various combinations of vocab size and special tokens.""" assert compute_model_vocab_size(generator_vocab_size, use_bos, use_eos) == expected - def test_minimal_vocab_with_tokens(self): - """Edge case: minimal vocabulary with special tokens.""" - assert compute_model_vocab_size(generator_vocab_size=2, use_bos=True, use_eos=True) == 4 - - def test_large_vocab(self): - """Test with large vocabulary sizes.""" - assert compute_model_vocab_size(generator_vocab_size=100000, use_bos=True, use_eos=True) == 100002 - def test_invalid_vocab_size_raises_error(self): - """Test that non-positive vocab sizes raise ValueError.""" - with pytest.raises(ValueError, match="must be positive"): + """Test that non-positive vocab sizes raise AssertionError.""" + with pytest.raises(AssertionError, match="must be positive"): compute_model_vocab_size(generator_vocab_size=0, use_bos=True, use_eos=False) - with pytest.raises(ValueError, match="must be positive"): - compute_model_vocab_size(generator_vocab_size=-1, use_bos=False, use_eos=False) diff --git a/tests/utils/test_jnp.py b/tests/utils/test_jnp.py index ecbd5b88..1b5d5fb1 100644 --- a/tests/utils/test_jnp.py +++ b/tests/utils/test_jnp.py @@ -119,10 +119,3 @@ def test_invalid_spec_raises_value_error(self): """Test invalid device spec raises ValueError.""" with pytest.raises(ValueError, match="Unknown device specification"): resolve_jax_device("invalid_device") - - def test_unknown_specs_raise_value_error(self): - """Test various unknown specs raise ValueError.""" - invalid_specs = ["tpu", "gpu0", "cuda:0", "mps", "unknown"] - for spec in invalid_specs: - with pytest.raises(ValueError, match="Unknown device specification"): - resolve_jax_device(spec) diff --git a/tests/utils/test_pytorch_utils.py b/tests/utils/test_pytorch_utils.py index 711e28cd..9f39fd74 100644 --- a/tests/utils/test_pytorch_utils.py +++ b/tests/utils/test_pytorch_utils.py @@ -92,20 +92,3 @@ def test_invalid_spec_raises_value_error(self): """Test invalid device spec raises ValueError.""" with pytest.raises(ValueError, match="Unknown device specification"): resolve_device("invalid_device") - - def test_unknown_specs_raise_value_error(self): - """Test various unknown specs raise ValueError.""" - invalid_specs = ["gpu", "cuda:0", "cuda:1", "tpu", "unknown"] - for spec in invalid_specs: - with pytest.raises(ValueError, match="Unknown device specification"): - resolve_device(spec) - - def test_auto_mode_priority_order(self): - """Test auto mode follows CUDA -> MPS -> CPU priority.""" - device = resolve_device("auto") - if torch.cuda.is_available(): - assert device == "cuda" - elif torch.backends.mps.is_available(): - assert device == "mps" - else: - assert device == "cpu" From b63dbc575995b533f793ad863b06be98ca28ba1e Mon Sep 17 00:00:00 2001 From: adamimos Date: Thu, 2 Oct 2025 22:04:55 -0700 Subject: [PATCH 11/17] Fix parse_checkpoint_step to handle .eqx extension MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Equinox (JAX) checkpoints use .eqx extension, update parser to accept both .pt and .eqx checkpoint files. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- simplexity/persistence/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simplexity/persistence/utils.py b/simplexity/persistence/utils.py index 2b9e7cc5..98627a85 100644 --- a/simplexity/persistence/utils.py +++ b/simplexity/persistence/utils.py @@ -44,7 +44,7 @@ def parse_checkpoint_step(path: str) -> int | None: 12345 """ parts = path.split("/") - if len(parts) >= 2 and parts[-1].endswith(".pt"): + if len(parts) >= 2 and parts[-1].endswith((".pt", ".eqx")): try: return int(parts[-2]) except ValueError: From 2844fa60afd3114c63f4560f094e3c0a4e3a973f Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Fri, 3 Oct 2025 19:39:18 +0000 Subject: [PATCH 12/17] Address PR review feedback - Add filename validation helper function with support for multiple extensions (.pt, .eqx, .pkl, .ckpt, .pth) - Update get_checkpoint_path to optionally use format_step_number for zero-padded paths - Make use_bos and use_eos keyword-only arguments with default values (False) in all config resolution functions - Consolidate test cases to minimum necessary while maintaining comprehensive coverage - Update all test calls to use keyword arguments for boolean parameters Co-authored-by: ealt --- simplexity/persistence/utils.py | 43 ++++++++++- simplexity/utils/config_resolution.py | 6 +- tests/persistence/test_utils.py | 49 ++++++------ tests/utils/test_config_resolution.py | 103 +++++++++----------------- 4 files changed, 101 insertions(+), 100 deletions(-) diff --git a/simplexity/persistence/utils.py b/simplexity/persistence/utils.py index 98627a85..120dc261 100644 --- a/simplexity/persistence/utils.py +++ b/simplexity/persistence/utils.py @@ -1,29 +1,64 @@ from pathlib import Path +SUPPORTED_EXTENSIONS = (".pt", ".eqx", ".pkl", ".ckpt", ".pth") -def get_checkpoint_path(directory: Path, step: int, filename: str = "model.pt") -> Path: + +def _is_valid_checkpoint_filename(filename: str) -> bool: + """Check if filename is a valid checkpoint filename with supported extension. + + Args: + filename: The checkpoint filename to validate + + Returns: + True if filename has a supported extension, False otherwise + + Examples: + >>> _is_valid_checkpoint_filename("model.pt") + True + >>> _is_valid_checkpoint_filename("state.eqx") + True + >>> _is_valid_checkpoint_filename("invalid.txt") + False + """ + return filename.endswith(SUPPORTED_EXTENSIONS) + + +def get_checkpoint_path( + directory: Path, step: int, filename: str = "model.pt", max_steps: int | None = None +) -> Path: """Construct checkpoint path following the standard naming convention. Args: directory: Base directory for checkpoints step: Training step number (must be non-negative) filename: Checkpoint filename (default: "model.pt") + max_steps: Maximum number of training steps. If provided, step will be zero-padded Returns: Path to checkpoint file: {directory}/{step}/{filename} Raises: - ValueError: If step is negative + ValueError: If step is negative or filename has unsupported extension Examples: >>> get_checkpoint_path(Path("checkpoints"), 12345) PosixPath('checkpoints/12345/model.pt') >>> get_checkpoint_path(Path("weights"), 100, "state.pt") PosixPath('weights/100/state.pt') + >>> get_checkpoint_path(Path("checkpoints"), 42, max_steps=100000) + PosixPath('checkpoints/000042/model.pt') """ if step < 0: raise ValueError(f"Step must be non-negative, got {step}") - return directory / str(step) / filename + if not _is_valid_checkpoint_filename(filename): + raise ValueError(f"Filename must have one of these extensions: {SUPPORTED_EXTENSIONS}, got {filename}") + + if max_steps is not None: + step_str = format_step_number(step, max_steps) + else: + step_str = str(step) + + return directory / step_str / filename def parse_checkpoint_step(path: str) -> int | None: @@ -44,7 +79,7 @@ def parse_checkpoint_step(path: str) -> int | None: 12345 """ parts = path.split("/") - if len(parts) >= 2 and parts[-1].endswith((".pt", ".eqx")): + if len(parts) >= 2 and _is_valid_checkpoint_filename(parts[-1]): try: return int(parts[-2]) except ValueError: diff --git a/simplexity/utils/config_resolution.py b/simplexity/utils/config_resolution.py index ae868c9a..3c03325d 100644 --- a/simplexity/utils/config_resolution.py +++ b/simplexity/utils/config_resolution.py @@ -1,4 +1,4 @@ -def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool, use_eos: bool = False) -> int: +def compute_generator_sequence_length(model_n_ctx: int, *, use_bos: bool = False, use_eos: bool = False) -> int: """Compute the generator's sequence length from model context length and special token usage. The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS @@ -35,7 +35,7 @@ def compute_generator_sequence_length(model_n_ctx: int, use_bos: bool, use_eos: return result -def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: bool = False) -> int: +def compute_model_context_length(generator_seq_len: int, *, use_bos: bool = False, use_eos: bool = False) -> int: """Compute the model's context length from generator sequence length and special token usage. The relationship is: model_n_ctx = generator_seq_len - 1 + BOS + EOS @@ -70,7 +70,7 @@ def compute_model_context_length(generator_seq_len: int, use_bos: bool, use_eos: return result -def compute_model_vocab_size(generator_vocab_size: int, use_bos: bool, use_eos: bool = False) -> int: +def compute_model_vocab_size(generator_vocab_size: int, *, use_bos: bool = False, use_eos: bool = False) -> int: """Compute the model's vocabulary size from generator vocab and special tokens. When BOS or EOS tokens are used during data generation, they are added to the vocabulary, diff --git a/tests/persistence/test_utils.py b/tests/persistence/test_utils.py index d907e03c..dd4e919e 100644 --- a/tests/persistence/test_utils.py +++ b/tests/persistence/test_utils.py @@ -12,34 +12,27 @@ class TestParseCheckpointStep: ("path", "expected"), [ ("12345/model.pt", 12345), - ("checkpoints/12345/model.pt", 12345), - ("path/to/500/model.pt", 500), - ("0/model.pt", 0), - ("prefix/run_name/12345/model.pt", 12345), - ("0000/model.pt", 0), - ("12345/checkpoint.pt", 12345), - ("12345/state.pt", 12345), - ("12345/weights.eqx", 12345), + ("checkpoints/12345/checkpoint.pt", 12345), + ("path/to/500/state.pt", 500), + ("0000/weights.eqx", 0), + ("prefix/run_name/12345/model.pkl", 12345), ], ) def test_directory_model_format(self, path: str, expected: int): - """Test parsing {step}/filename format with various filenames.""" + """Test parsing {step}/filename format with various filenames and zero-padding.""" assert parse_checkpoint_step(path) == expected @pytest.mark.parametrize( "path", [ "model.pt", - "checkpoint.pt", "weights/model.eqx", - "random_file.txt", - "nonumeric/model.pt", "abc123/model.pt", - "123abc/model.pt", + "123abc/checkpoint.pt", ], ) def test_no_match_returns_none(self, path: str): - """Test paths with numbers but invalid format return None.""" + """Test paths with numbers in invalid positions return None.""" assert parse_checkpoint_step(path) is None @@ -47,25 +40,30 @@ class TestGetCheckpointPath: """Test get_checkpoint_path function.""" @pytest.mark.parametrize( - ("directory", "step", "filename", "expected"), + ("directory", "step", "filename", "max_steps", "expected"), [ - (Path("checkpoints"), 0, "model.pt", Path("checkpoints/0/model.pt")), - (Path("checkpoints"), 12345, "model.pt", Path("checkpoints/12345/model.pt")), - (Path("runs/exp1"), 1000, "checkpoint.pt", Path("runs/exp1/1000/checkpoint.pt")), - (Path("weights"), 100, "state.pt", Path("weights/100/state.pt")), - (Path("."), 42, "model.pt", Path("42/model.pt")), - (Path("checkpoints"), 99999, "model.pt", Path("checkpoints/99999/model.pt")), + (Path("checkpoints"), 12345, "model.pt", None, Path("checkpoints/12345/model.pt")), + (Path("runs/exp1"), 1000, "checkpoint.pt", None, Path("runs/exp1/1000/checkpoint.pt")), + (Path("weights"), 42, "state.eqx", 100000, Path("weights/000042/state.eqx")), + (Path("checkpoints"), 0, "model.pt", 999, Path("checkpoints/000/model.pt")), ], ) - def test_parametrized_paths(self, directory: Path, step: int, filename: str, expected: Path): + def test_parametrized_paths( + self, directory: Path, step: int, filename: str, max_steps: int | None, expected: Path + ): """Test various path combinations including custom filenames and zero-padding.""" - assert get_checkpoint_path(directory, step, filename) == expected + assert get_checkpoint_path(directory, step, filename, max_steps) == expected def test_negative_step_raises_error(self): """Test that negative step values raise ValueError.""" with pytest.raises(ValueError, match="must be non-negative"): get_checkpoint_path(Path("checkpoints"), -1) + def test_invalid_filename_raises_error(self): + """Test that invalid filenames raise ValueError.""" + with pytest.raises(ValueError, match="must have one of these extensions"): + get_checkpoint_path(Path("checkpoints"), 100, "invalid.txt") + class TestFormatStepNumber: """Test format_step_number function.""" @@ -74,17 +72,14 @@ class TestFormatStepNumber: ("step", "max_steps", "expected"), [ (0, 999, "000"), - (1, 999, "001"), (42, 999, "042"), (999, 999, "999"), - (0, 100000, "000000"), (42, 100000, "000042"), - (12345, 100000, "012345"), (100000, 100000, "100000"), ], ) def test_parametrized_formatting(self, step: int, max_steps: int, expected: str): - """Test various step and max_steps combinations.""" + """Test various step and max_steps combinations with zero-padding.""" assert format_step_number(step, max_steps) == expected def test_lexicographic_ordering(self): diff --git a/tests/utils/test_config_resolution.py b/tests/utils/test_config_resolution.py index 07554b80..d7219c28 100644 --- a/tests/utils/test_config_resolution.py +++ b/tests/utils/test_config_resolution.py @@ -11,37 +11,17 @@ class TestComputeGeneratorSequenceLength: """Test compute_generator_sequence_length function.""" @pytest.mark.parametrize( - ("model_n_ctx", "use_bos", "use_eos", "expected"), + ("use_bos", "use_eos", "expected"), [ - (512, False, False, 513), - (512, True, False, 512), - (512, False, True, 512), - (512, True, True, 511), + (False, False, 65), + (True, False, 64), + (False, True, 64), + (True, True, 63), ], ) - def test_bos_eos_combinations(self, model_n_ctx: int, use_bos: bool, use_eos: bool, expected: int): - """Test all combinations of BOS and EOS tokens with same model_n_ctx.""" - assert compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) == expected - - @pytest.mark.parametrize( - ("model_n_ctx", "use_bos", "use_eos", "expected"), - [ - (1, True, False, 1), - (1, False, False, 2), - (1, False, True, 1), - (64, True, False, 64), - (64, False, False, 65), - (64, False, True, 64), - (64, True, True, 63), - (1024, True, False, 1024), - (1024, False, False, 1025), - (1024, False, True, 1024), - (1024, True, True, 1023), - ], - ) - def test_parametrized_cases(self, model_n_ctx: int, use_bos: bool, use_eos: bool, expected: int): - """Test various combinations of model_n_ctx, use_bos, and use_eos.""" - assert compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) == expected + def test_bos_eos_combinations(self, use_bos: bool, use_eos: bool, expected: int): + """Test all combinations of BOS and EOS tokens with model_n_ctx=64.""" + assert compute_generator_sequence_length(64, use_bos=use_bos, use_eos=use_eos) == expected def test_invalid_configuration_raises_error(self): """Test that invalid configurations raise ValueError.""" @@ -55,37 +35,30 @@ class TestComputeModelContextLength: """Test compute_model_context_length function.""" @pytest.mark.parametrize( - ("generator_seq_len", "use_bos", "use_eos", "expected"), + ("use_bos", "use_eos", "expected"), [ - (513, False, False, 512), - (512, True, False, 512), - (512, False, True, 512), - (511, True, True, 512), + (False, False, 63), + (True, False, 64), + (False, True, 64), + (True, True, 65), ], ) - def test_bos_eos_combinations(self, generator_seq_len: int, use_bos: bool, use_eos: bool, expected: int): - """Test all combinations of BOS and EOS tokens with resulting model_n_ctx=512.""" - assert compute_model_context_length(generator_seq_len, use_bos, use_eos) == expected + def test_bos_eos_combinations(self, use_bos: bool, use_eos: bool, expected: int): + """Test all combinations of BOS and EOS tokens with generator_seq_len=64.""" + assert compute_model_context_length(64, use_bos=use_bos, use_eos=use_eos) == expected @pytest.mark.parametrize( - ("generator_seq_len", "use_bos", "use_eos", "expected"), + ("use_bos", "use_eos", "expected"), [ - (1, True, False, 1), - (2, False, False, 1), - (1, False, True, 1), - (64, True, False, 64), - (65, False, False, 64), - (64, False, True, 64), - (63, True, True, 64), - (1024, True, False, 1024), - (1025, False, False, 1024), - (1024, False, True, 1024), - (1023, True, True, 1024), + (False, False, 511), + (True, False, 512), + (False, True, 512), + (True, True, 513), ], ) - def test_parametrized_cases(self, generator_seq_len: int, use_bos: bool, use_eos: bool, expected: int): - """Test various combinations of generator_seq_len, use_bos, and use_eos.""" - assert compute_model_context_length(generator_seq_len, use_bos, use_eos) == expected + def test_parametrized_cases(self, use_bos: bool, use_eos: bool, expected: int): + """Test all combinations with generator_seq_len=512.""" + assert compute_model_context_length(512, use_bos=use_bos, use_eos=use_eos) == expected def test_invalid_inputs_raise_error(self): """Test that invalid inputs raise AssertionError.""" @@ -104,8 +77,8 @@ def test_invalid_inputs_raise_error(self): def test_inverse_relationship(self, use_bos: bool, use_eos: bool): """Verify inverse relationship with compute_generator_sequence_length.""" model_n_ctx = 512 - gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) - recovered_n_ctx = compute_model_context_length(gen_seq_len, use_bos, use_eos) + gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos=use_bos, use_eos=use_eos) + recovered_n_ctx = compute_model_context_length(gen_seq_len, use_bos=use_bos, use_eos=use_eos) assert recovered_n_ctx == model_n_ctx @pytest.mark.parametrize("model_n_ctx", [1, 64, 128, 512, 1024]) @@ -115,8 +88,8 @@ def test_round_trip_consistency(self, model_n_ctx: int, use_bos: bool, use_eos: """Verify round-trip conversion maintains original value.""" if model_n_ctx == 1 and use_bos and use_eos: pytest.skip("Configuration would produce invalid sequence length") - gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos, use_eos) - recovered = compute_model_context_length(gen_seq_len, use_bos, use_eos) + gen_seq_len = compute_generator_sequence_length(model_n_ctx, use_bos=use_bos, use_eos=use_eos) + recovered = compute_model_context_length(gen_seq_len, use_bos=use_bos, use_eos=use_eos) assert recovered == model_n_ctx @@ -124,23 +97,21 @@ class TestComputeModelVocabSize: """Test compute_model_vocab_size function.""" @pytest.mark.parametrize( - ("generator_vocab_size", "use_bos", "use_eos", "expected"), + ("use_bos", "use_eos", "expected"), [ - (100, False, False, 100), - (100, True, False, 101), - (100, False, True, 101), - (100, True, True, 102), + (False, False, 100), + (True, False, 101), + (False, True, 101), + (True, True, 102), ], ) - def test_bos_eos_combinations(self, generator_vocab_size: int, use_bos: bool, use_eos: bool, expected: int): - """Test all combinations of BOS and EOS tokens with same generator_vocab_size.""" - assert compute_model_vocab_size(generator_vocab_size, use_bos, use_eos) == expected + def test_bos_eos_combinations(self, use_bos: bool, use_eos: bool, expected: int): + """Test all combinations of BOS and EOS tokens with generator_vocab_size=100.""" + assert compute_model_vocab_size(100, use_bos=use_bos, use_eos=use_eos) == expected @pytest.mark.parametrize( ("generator_vocab_size", "use_bos", "use_eos", "expected"), [ - (1, False, False, 1), - (1, True, True, 3), (50257, False, False, 50257), (50257, True, False, 50258), (50257, True, True, 50259), @@ -148,7 +119,7 @@ def test_bos_eos_combinations(self, generator_vocab_size: int, use_bos: bool, us ) def test_parametrized_cases(self, generator_vocab_size: int, use_bos: bool, use_eos: bool, expected: int): """Test various combinations of vocab size and special tokens.""" - assert compute_model_vocab_size(generator_vocab_size, use_bos, use_eos) == expected + assert compute_model_vocab_size(generator_vocab_size, use_bos=use_bos, use_eos=use_eos) == expected def test_invalid_vocab_size_raises_error(self): """Test that non-positive vocab sizes raise AssertionError.""" From bb0098398ff09c18fc0946b68a63bb49e97a5772 Mon Sep 17 00:00:00 2001 From: adamimos Date: Fri, 3 Oct 2025 14:16:41 -0700 Subject: [PATCH 13/17] Apply ruff formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- simplexity/persistence/utils.py | 4 +--- tests/persistence/test_utils.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/simplexity/persistence/utils.py b/simplexity/persistence/utils.py index 120dc261..bd69dfbf 100644 --- a/simplexity/persistence/utils.py +++ b/simplexity/persistence/utils.py @@ -23,9 +23,7 @@ def _is_valid_checkpoint_filename(filename: str) -> bool: return filename.endswith(SUPPORTED_EXTENSIONS) -def get_checkpoint_path( - directory: Path, step: int, filename: str = "model.pt", max_steps: int | None = None -) -> Path: +def get_checkpoint_path(directory: Path, step: int, filename: str = "model.pt", max_steps: int | None = None) -> Path: """Construct checkpoint path following the standard naming convention. Args: diff --git a/tests/persistence/test_utils.py b/tests/persistence/test_utils.py index dd4e919e..8e91f51e 100644 --- a/tests/persistence/test_utils.py +++ b/tests/persistence/test_utils.py @@ -48,9 +48,7 @@ class TestGetCheckpointPath: (Path("checkpoints"), 0, "model.pt", 999, Path("checkpoints/000/model.pt")), ], ) - def test_parametrized_paths( - self, directory: Path, step: int, filename: str, max_steps: int | None, expected: Path - ): + def test_parametrized_paths(self, directory: Path, step: int, filename: str, max_steps: int | None, expected: Path): """Test various path combinations including custom filenames and zero-padding.""" assert get_checkpoint_path(directory, step, filename, max_steps) == expected From 4416f0d26c6527124f3a89804dcd273d192f1db1 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 3 Oct 2025 22:32:51 +0000 Subject: [PATCH 14/17] Switch from Unity Catalog to Workspace Model Registry --- simplexity/configs/logging/mlflow_logger.yaml | 2 +- simplexity/logging/mlflow_logger.py | 6 +- simplexity/persistence/mlflow_persister.py | 4 +- simplexity/utils/mlflow_utils.py | 62 +++++++++++++++++++ tests/utils/test_mlflow_utils.py | 38 ++++++++++++ 5 files changed, 108 insertions(+), 4 deletions(-) create mode 100644 simplexity/utils/mlflow_utils.py create mode 100644 tests/utils/test_mlflow_utils.py diff --git a/simplexity/configs/logging/mlflow_logger.yaml b/simplexity/configs/logging/mlflow_logger.yaml index 882b1fc4..b9068b7e 100644 --- a/simplexity/configs/logging/mlflow_logger.yaml +++ b/simplexity/configs/logging/mlflow_logger.yaml @@ -4,4 +4,4 @@ instance: experiment_name: /Shared/${experiment_name} run_name: ${run_name} tracking_uri: databricks - registry_uri: databricks-uc + registry_uri: databricks diff --git a/simplexity/logging/mlflow_logger.py b/simplexity/logging/mlflow_logger.py index a5eb8a28..4c9aca86 100644 --- a/simplexity/logging/mlflow_logger.py +++ b/simplexity/logging/mlflow_logger.py @@ -18,6 +18,7 @@ from omegaconf import DictConfig, OmegaConf from simplexity.logging.logger import Logger +from simplexity.utils.mlflow_utils import resolve_registry_uri dotenv.load_dotenv() _DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") @@ -34,7 +35,8 @@ def __init__( registry_uri: str | None = None, ): """Initialize MLflow logger.""" - self._client = mlflow.MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri) + resolved_registry_uri = resolve_registry_uri(tracking_uri, registry_uri) + self._client = mlflow.MlflowClient(tracking_uri=tracking_uri, registry_uri=resolved_registry_uri) experiment = self._client.get_experiment_by_name(experiment_name) if experiment: experiment_id = experiment.experiment_id @@ -43,7 +45,7 @@ def __init__( run = self._client.create_run(experiment_id=experiment_id, run_name=run_name) self._run_id = run.info.run_id self._tracking_uri = tracking_uri - self._registry_uri = registry_uri + self._registry_uri = resolved_registry_uri @property def client(self) -> mlflow.MlflowClient: diff --git a/simplexity/persistence/mlflow_persister.py b/simplexity/persistence/mlflow_persister.py index 38f9a950..f242e6ec 100644 --- a/simplexity/persistence/mlflow_persister.py +++ b/simplexity/persistence/mlflow_persister.py @@ -10,6 +10,7 @@ from simplexity.persistence.model_persister import ModelPersister from simplexity.predictive_models.predictive_model import PredictiveModel from simplexity.predictive_models.types import ModelFramework +from simplexity.utils.mlflow_utils import resolve_registry_uri if TYPE_CHECKING: from mlflow import MlflowClient @@ -79,7 +80,8 @@ def from_experiment( """Create a persister from an MLflow experiment.""" import mlflow - client = mlflow.MlflowClient(tracking_uri=tracking_uri, registry_uri=registry_uri) + resolved_registry_uri = resolve_registry_uri(tracking_uri, registry_uri) + client = mlflow.MlflowClient(tracking_uri=tracking_uri, registry_uri=resolved_registry_uri) experiment = client.get_experiment_by_name(experiment_name) if experiment: experiment_id = experiment.experiment_id diff --git a/simplexity/utils/mlflow_utils.py b/simplexity/utils/mlflow_utils.py new file mode 100644 index 00000000..ddf3c48a --- /dev/null +++ b/simplexity/utils/mlflow_utils.py @@ -0,0 +1,62 @@ +"""Utilities for working with MLflow in different Databricks environments.""" + +from __future__ import annotations + +import warnings +from typing import Final + +_UC_PREFIX: Final = "databricks-uc" +_WORKSPACE_PREFIX: Final = "databricks" +_SCHEME_SEPARATOR: Final = "://" + + +def _normalize_databricks_uri(uri: str) -> tuple[str, bool]: + """Convert Databricks Unity Catalog URIs to workspace-compatible equivalents.""" + if uri == _UC_PREFIX: + return _WORKSPACE_PREFIX, True + prefix = f"{_UC_PREFIX}{_SCHEME_SEPARATOR}" + if uri.startswith(prefix): + suffix = uri.split(_SCHEME_SEPARATOR, 1)[1] + return f"{_WORKSPACE_PREFIX}{_SCHEME_SEPARATOR}{suffix}", True + return uri, False + + +def resolve_registry_uri(tracking_uri: str | None, registry_uri: str | None) -> str | None: + """Determine a workspace model registry URI for MLflow operations. + + - If an explicit registry URI is provided, convert Unity Catalog URIs to their + workspace equivalents while warning the caller about the downgrade. + - If no registry URI is provided, infer one from a Databricks tracking URI. + - For non-Databricks configurations, return ``None`` so MLflow uses its defaults. + """ + if registry_uri: + normalized, converted = _normalize_databricks_uri(registry_uri) + if converted: + warnings.warn( + ( + f"Unity Catalog registry URI '{registry_uri}' is not supported by this environment; " + f"using workspace registry URI '{normalized}' instead." + ), + stacklevel=2, + ) + return normalized + + if not tracking_uri: + return None + + normalized, converted = _normalize_databricks_uri(tracking_uri) + if normalized.startswith(_WORKSPACE_PREFIX): + if converted: + warnings.warn( + ( + f"Unity Catalog tracking URI '{tracking_uri}' detected; " + f"falling back to workspace registry URI '{normalized}'." + ), + stacklevel=2, + ) + return normalized + + return None + + +__all__ = ["resolve_registry_uri"] diff --git a/tests/utils/test_mlflow_utils.py b/tests/utils/test_mlflow_utils.py new file mode 100644 index 00000000..fa4f849c --- /dev/null +++ b/tests/utils/test_mlflow_utils.py @@ -0,0 +1,38 @@ +"""Tests for MLflow registry URI resolution helpers.""" + +from __future__ import annotations + +import pytest + +from simplexity.utils.mlflow_utils import resolve_registry_uri + + +def test_resolve_registry_uri_prefers_explicit_workspace() -> None: + """Explicit workspace URIs are returned unchanged.""" + assert resolve_registry_uri("databricks", "databricks") == "databricks" + + +def test_resolve_registry_uri_converts_uc_registry_uri(recwarn: pytest.WarningsRecorder) -> None: + """Unity Catalog registry URIs are downgraded to workspace URIs with a warning.""" + result = resolve_registry_uri(None, "databricks-uc") + assert result == "databricks" + warning = recwarn.pop(UserWarning) + assert "Unity Catalog" in str(warning.message) + + +def test_resolve_registry_uri_infers_from_tracking() -> None: + """Databricks tracking URIs are reused for the registry by default.""" + assert resolve_registry_uri("databricks://profile", None) == "databricks://profile" + + +def test_resolve_registry_uri_demotes_tracking_uc(recwarn: pytest.WarningsRecorder) -> None: + """Unity Catalog tracking URIs fall back to workspace registry URIs.""" + result = resolve_registry_uri("databricks-uc://profile", None) + assert result == "databricks://profile" + warning = recwarn.pop(UserWarning) + assert "Unity Catalog tracking URI" in str(warning.message) + + +def test_resolve_registry_uri_non_databricks() -> None: + """Non-Databricks tracking URIs leave the registry unset.""" + assert resolve_registry_uri("file:///tmp", None) is None From 5b210c773bf537861ec655f9984ec4027536c8c3 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 3 Oct 2025 22:40:34 +0000 Subject: [PATCH 15/17] Add workspace fallback, document potential migration --- docs/databricks_model_registry.md | 38 +++++++++++++++++++ simplexity/configs/logging/config.py | 1 + simplexity/configs/logging/mlflow_logger.yaml | 1 + simplexity/logging/mlflow_logger.py | 7 +++- simplexity/persistence/mlflow_persister.py | 7 +++- simplexity/utils/mlflow_utils.py | 19 ++++++++-- tests/utils/test_mlflow_utils.py | 14 +++++++ 7 files changed, 81 insertions(+), 6 deletions(-) create mode 100644 docs/databricks_model_registry.md diff --git a/docs/databricks_model_registry.md b/docs/databricks_model_registry.md new file mode 100644 index 00000000..91228333 --- /dev/null +++ b/docs/databricks_model_registry.md @@ -0,0 +1,38 @@ +# Databricks Model Registry Guide + +This project targets the Databricks **Workspace Model Registry** by default because many subscriptions (including ours) do not provide Unity Catalog access. The integration is designed so that switching to Unity Catalog later only requires configuration changes—no code changes. + +## Default Behaviour + +- Hydra config `logging=mlflow_logger` sets `tracking_uri=databricks` and `registry_uri=databricks`. +- `simplexity.utils.mlflow_utils.resolve_registry_uri` downgrades Unity Catalog URIs (``databricks-uc``) to workspace URIs when `allow_workspace_fallback=True` (the default) and emits a warning so you know a downgrade happened. +- `MLFlowLogger` and `MLFlowPersister.from_experiment` both call `resolve_registry_uri`, so any code path that uses Simplexity helpers gets the same fallback logic. +- `examples/mlflow_workspace_registry_demo.py` mirrors this behaviour and can be used to sanity-check Databricks connectivity. + +## Preparing for a Future Unity Catalog Migration + +To keep migration friction low we expose an `allow_workspace_fallback` flag everywhere MLflow clients are created. + +- **Logger config** (`simplexity/configs/logging/mlflow_logger.yaml`): + - Set `registry_uri: databricks-uc` once your workspace is UC-enabled. + - Flip `allow_workspace_fallback: false` to stop the automatic downgrade. +- **Programmatic use**: `MLFlowLogger(..., allow_workspace_fallback=False)` or `MLFlowPersister.from_experiment(..., allow_workspace_fallback=False)` preserves Unity Catalog URIs. +- **Environment variables**: you can still rely on `MLFLOW_TRACKING_URI` / `MLFLOW_REGISTRY_URI`. When fallback is disabled those values are forwarded unchanged. + +Because the flag defaults to `True`, current jobs continue working even if a Unity Catalog URI is supplied accidentally—Simplexity automatically falls back to the workspace registry and logs a warning. When you are ready to migrate, toggling the flag allows UC usage without touching the codebase. + +## Suggested Migration Checklist + +1. **Enable Unity Catalog in Databricks** and make sure the MLflow registry permissions are set up (see the official Databricks migration guide). +2. **Create the Unity Catalog equivalents** of any workspace-registered models if you plan to keep history—Databricks provides automated migration jobs for this. +3. **Update configuration**: + - Set `registry_uri` (and optionally `tracking_uri`) to the appropriate `databricks-uc` endpoint. + - Set `allow_workspace_fallback: false` to surface real UC connectivity errors instead of silently downgrading. +4. **Smoke test** using `examples/mlflow_workspace_registry_demo.py` with the updated config. The script will now run against UC and should register the demo model there. +5. **Monitor warnings**: once fallback is disabled, any remaining downgrade warnings indicate stale configs or code paths that still pass the workspace URI. + +## Operational Notes + +- Keeping fallback enabled during the transition phase is helpful because it avoids runtime failures, but remember that models will continue to land in the workspace registry until you turn it off. +- After migration you can remove the fallback flag entirely or leave it `False` so that future regressions are caught early. +- If you need parallel workspace/UC logging (for validation) you can run two Hydrated jobs with different logger configs—no application code changes are required. diff --git a/simplexity/configs/logging/config.py b/simplexity/configs/logging/config.py index 49dc5370..222fe01b 100644 --- a/simplexity/configs/logging/config.py +++ b/simplexity/configs/logging/config.py @@ -30,6 +30,7 @@ class MLFlowLoggerConfig(LoggingInstanceConfig): run_name: str tracking_uri: str registry_uri: str | None = None + allow_workspace_fallback: bool = True @dataclass diff --git a/simplexity/configs/logging/mlflow_logger.yaml b/simplexity/configs/logging/mlflow_logger.yaml index b9068b7e..7a72dc52 100644 --- a/simplexity/configs/logging/mlflow_logger.yaml +++ b/simplexity/configs/logging/mlflow_logger.yaml @@ -5,3 +5,4 @@ instance: run_name: ${run_name} tracking_uri: databricks registry_uri: databricks + allow_workspace_fallback: true diff --git a/simplexity/logging/mlflow_logger.py b/simplexity/logging/mlflow_logger.py index 4c9aca86..81227563 100644 --- a/simplexity/logging/mlflow_logger.py +++ b/simplexity/logging/mlflow_logger.py @@ -33,9 +33,14 @@ def __init__( run_name: str | None = None, tracking_uri: str | None = None, registry_uri: str | None = None, + allow_workspace_fallback: bool = True, ): """Initialize MLflow logger.""" - resolved_registry_uri = resolve_registry_uri(tracking_uri, registry_uri) + resolved_registry_uri = resolve_registry_uri( + tracking_uri, + registry_uri, + allow_workspace_fallback=allow_workspace_fallback, + ) self._client = mlflow.MlflowClient(tracking_uri=tracking_uri, registry_uri=resolved_registry_uri) experiment = self._client.get_experiment_by_name(experiment_name) if experiment: diff --git a/simplexity/persistence/mlflow_persister.py b/simplexity/persistence/mlflow_persister.py index f242e6ec..f4220475 100644 --- a/simplexity/persistence/mlflow_persister.py +++ b/simplexity/persistence/mlflow_persister.py @@ -76,11 +76,16 @@ def from_experiment( artifact_path: str = "models", model_framework: ModelFramework = ModelFramework.Equinox, registered_model_name: str | None = None, + allow_workspace_fallback: bool = True, ) -> MLFlowPersister: """Create a persister from an MLflow experiment.""" import mlflow - resolved_registry_uri = resolve_registry_uri(tracking_uri, registry_uri) + resolved_registry_uri = resolve_registry_uri( + tracking_uri, + registry_uri, + allow_workspace_fallback=allow_workspace_fallback, + ) client = mlflow.MlflowClient(tracking_uri=tracking_uri, registry_uri=resolved_registry_uri) experiment = client.get_experiment_by_name(experiment_name) if experiment: diff --git a/simplexity/utils/mlflow_utils.py b/simplexity/utils/mlflow_utils.py index ddf3c48a..3d84e3d6 100644 --- a/simplexity/utils/mlflow_utils.py +++ b/simplexity/utils/mlflow_utils.py @@ -21,7 +21,12 @@ def _normalize_databricks_uri(uri: str) -> tuple[str, bool]: return uri, False -def resolve_registry_uri(tracking_uri: str | None, registry_uri: str | None) -> str | None: +def resolve_registry_uri( + tracking_uri: str | None, + registry_uri: str | None, + *, + allow_workspace_fallback: bool = True, +) -> str | None: """Determine a workspace model registry URI for MLflow operations. - If an explicit registry URI is provided, convert Unity Catalog URIs to their @@ -29,8 +34,14 @@ def resolve_registry_uri(tracking_uri: str | None, registry_uri: str | None) -> - If no registry URI is provided, infer one from a Databricks tracking URI. - For non-Databricks configurations, return ``None`` so MLflow uses its defaults. """ + def _convert(uri: str) -> tuple[str, bool]: + normalized, converted = _normalize_databricks_uri(uri) + if converted and not allow_workspace_fallback: + return uri, False + return normalized, converted + if registry_uri: - normalized, converted = _normalize_databricks_uri(registry_uri) + normalized, converted = _convert(registry_uri) if converted: warnings.warn( ( @@ -44,8 +55,8 @@ def resolve_registry_uri(tracking_uri: str | None, registry_uri: str | None) -> if not tracking_uri: return None - normalized, converted = _normalize_databricks_uri(tracking_uri) - if normalized.startswith(_WORKSPACE_PREFIX): + normalized, converted = _convert(tracking_uri) + if normalized.startswith((_WORKSPACE_PREFIX, _UC_PREFIX)): if converted: warnings.warn( ( diff --git a/tests/utils/test_mlflow_utils.py b/tests/utils/test_mlflow_utils.py index fa4f849c..e318342e 100644 --- a/tests/utils/test_mlflow_utils.py +++ b/tests/utils/test_mlflow_utils.py @@ -20,6 +20,13 @@ def test_resolve_registry_uri_converts_uc_registry_uri(recwarn: pytest.WarningsR assert "Unity Catalog" in str(warning.message) +def test_resolve_registry_uri_respects_disabled_fallback(recwarn: pytest.WarningsRecorder) -> None: + """Fallback can be disabled to keep Unity Catalog URIs intact.""" + result = resolve_registry_uri(None, "databricks-uc", allow_workspace_fallback=False) + assert result == "databricks-uc" + assert not recwarn.list + + def test_resolve_registry_uri_infers_from_tracking() -> None: """Databricks tracking URIs are reused for the registry by default.""" assert resolve_registry_uri("databricks://profile", None) == "databricks://profile" @@ -33,6 +40,13 @@ def test_resolve_registry_uri_demotes_tracking_uc(recwarn: pytest.WarningsRecord assert "Unity Catalog tracking URI" in str(warning.message) +def test_resolve_registry_uri_tracking_fallback_toggle(recwarn: pytest.WarningsRecorder) -> None: + """Unity Catalog tracking URIs stay untouched when fallback is disabled.""" + result = resolve_registry_uri("databricks-uc://profile", None, allow_workspace_fallback=False) + assert result == "databricks-uc://profile" + assert not recwarn.list + + def test_resolve_registry_uri_non_databricks() -> None: """Non-Databricks tracking URIs leave the registry unset.""" assert resolve_registry_uri("file:///tmp", None) is None From 85691a109111cde57e0b9b977bb7fdb7eac3cf8e Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 3 Oct 2025 22:40:44 +0000 Subject: [PATCH 16/17] Create demo --- examples/mlflow_workspace_registry_demo.py | 287 +++++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 examples/mlflow_workspace_registry_demo.py diff --git a/examples/mlflow_workspace_registry_demo.py b/examples/mlflow_workspace_registry_demo.py new file mode 100644 index 00000000..6e134a37 --- /dev/null +++ b/examples/mlflow_workspace_registry_demo.py @@ -0,0 +1,287 @@ +"""Demonstrate saving and loading a PyTorch model with the MLflow workspace registry.""" + +from __future__ import annotations + +import os +import sys +import time +import urllib.parse +from dataclasses import dataclass, field + +import hydra +import mlflow +from hydra.core.config_store import ConfigStore +from mlflow.entities.model_registry import ModelVersion +from omegaconf import MISSING + +from simplexity.utils.mlflow_utils import resolve_registry_uri + +try: + import torch + from torch import nn +except ImportError as exc: # pragma: no cover - script guard + raise SystemExit( + "PyTorch is required for this demo. Install it with `pip install torch` " + "or add the `pytorch` extra when installing this project." + ) from exc + + +WORKSPACE_REGISTRY_URI = "databricks" + + +class TinyClassifier(nn.Module): + """A tiny classifier for testing.""" + + def __init__(self) -> None: + super().__init__() + self.model = nn.Sequential( + nn.Linear(4, 16), + nn.ReLU(), + nn.Linear(16, 2), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override] + """Forward pass.""" + return self.model(x) + + +@dataclass +class DemoConfig: + """Configuration for the MLflow workspace registry demo.""" + + experiment: str = "WorkspaceRegistryDemo" + run_name: str | None = None + registered_model_name: str = MISSING + tracking_uri: str | None = field(default_factory=lambda: os.getenv("MLFLOW_TRACKING_URI")) + registry_uri: str | None = field(default_factory=lambda: os.getenv("MLFLOW_REGISTRY_URI", WORKSPACE_REGISTRY_URI)) + artifact_path: str = "pytorch-model" + poll_interval: float = 2.0 + poll_timeout: float = 300.0 + databricks_host: str | None = field(default_factory=lambda: os.getenv("DATABRICKS_HOST")) + allow_workspace_fallback: bool = True + + +CONFIG_NAME = "mlflow_workspace_registry_demo" +LEGACY_CONFIG_NAME = "mlflow_unity_catalog_demo" + +config_store = ConfigStore.instance() +config_store.store(name=CONFIG_NAME, node=DemoConfig) +config_store.store(name=LEGACY_CONFIG_NAME, node=DemoConfig) + + +def ensure_experiment(client: mlflow.MlflowClient, name: str) -> str: + """Ensure an experiment exists.""" + experiment = client.get_experiment_by_name(name) + if experiment: + return experiment.experiment_id + return client.create_experiment(name) + + +def await_model_version_ready( + client: mlflow.MlflowClient, + model_name: str, + version: str, + poll_interval: float, + poll_timeout: float, +) -> ModelVersion: + """Wait for a model version to be ready.""" + deadline = time.monotonic() + poll_timeout + while True: + current = client.get_model_version(name=model_name, version=version) + if current.status == "READY": + return current + if current.status == "FAILED": + raise RuntimeError(f"Model version {model_name}/{version} failed to register: {current.status_message}") + if time.monotonic() > deadline: + raise TimeoutError(f"Model version {model_name}/{version} did not become READY within {poll_timeout}s") + time.sleep(poll_interval) + + +def search_model_version_for_run( + client: mlflow.MlflowClient, + model_name: str, + run_id: str, +) -> ModelVersion: + """Search for a model version for a run.""" + versions = client.search_model_versions(f"name = '{model_name}' and run_id = '{run_id}'") + if not versions: + raise RuntimeError( + "No model versions were created for this run. Ensure the run has permission to register a model." + ) + # MLflow returns the newest model version first for this query. + return versions[0] + + +def build_databricks_urls( + host: str | None, + experiment_id: str, + run_id: str, + model_name: str, + model_version: str, +) -> tuple[str | None, str | None]: + """Build Databricks URLs for a model version.""" + if not host: + return None, None + base = host.rstrip("/") + encoded_name = urllib.parse.quote(model_name, safe="") + run_url = f"{base}/#mlflow/experiments/{experiment_id}/runs/{run_id}" + model_url = f"{base}/#mlflow/models/{encoded_name}/versions/{model_version}" + return run_url, model_url + + +def run_demo(config: DemoConfig) -> None: + """Run the MLflow workspace registry demo.""" + resolved_registry_uri = resolve_registry_uri( + config.tracking_uri, + config.registry_uri, + allow_workspace_fallback=config.allow_workspace_fallback, + ) + if config.tracking_uri: + mlflow.set_tracking_uri(config.tracking_uri) + if resolved_registry_uri: + mlflow.set_registry_uri(resolved_registry_uri) + + client = mlflow.MlflowClient(tracking_uri=mlflow.get_tracking_uri(), registry_uri=mlflow.get_registry_uri()) + experiment_id = ensure_experiment(client, config.experiment) + + torch.manual_seed(7) + model = TinyClassifier() + sample_input = torch.randn(4, 4) + + run_id: str = "" # Initialize to avoid "possibly unbound" error + model_version: ModelVersion | None = None # Initialize to avoid "possibly unbound" error + + with mlflow.start_run(experiment_id=experiment_id, run_name=config.run_name) as run: + run_id = run.info.run_id + mlflow.log_params({"demo": "workspace_registry", "framework": "pytorch", "layers": len(list(model.modules()))}) + + # First log the model without registering it + mlflow.pytorch.log_model( # type: ignore[attr-defined] + model, + artifact_path=config.artifact_path, + ) + + # Then register the model separately + try: + client.create_registered_model(config.registered_model_name) + print(f"Created registered model: {config.registered_model_name}") + except Exception as e: + if "already exists" in str(e).lower(): + print(f"Registered model {config.registered_model_name} already exists") + else: + raise + + # Create model version using the model URI from the logged model + model_uri = f"runs:/{run_id}/{config.artifact_path}" + model_version = client.create_model_version( + name=config.registered_model_name, + source=model_uri, + run_id=run_id, + description="Demo model from workspace registry", + ) + print(f"Created model version: {model_version.version}") + + predictions = model(sample_input).detach() + mlflow.log_artifact( + _dump_tensor(predictions, "predictions.txt"), + artifact_path="artifacts", + ) + + # Wait for model version to be ready + if model_version is None: + raise RuntimeError("Failed to create model version") + ready_version = await_model_version_ready( + client, + config.registered_model_name, + model_version.version, + config.poll_interval, + config.poll_timeout, + ) + + model_uri = f"models:/{config.registered_model_name}/{ready_version.version}" + loaded_model = mlflow.pytorch.load_model(model_uri) # type: ignore[attr-defined] + restored_model = TinyClassifier() + restored_model.load_state_dict(loaded_model.state_dict()) + + verification_input = torch.randn(2, 4) + original_output = model(verification_input) + restored_output = restored_model(verification_input) + if not torch.allclose(original_output, restored_output, atol=1e-5): + raise RuntimeError("Loaded weights differ from the original model outputs.") + + run_url, model_url = build_databricks_urls( + config.databricks_host, + experiment_id, + run_id, + config.registered_model_name, + ready_version.version, + ) + + info_lines = [ + "MLflow workspace registry demo complete!", + f"Run ID: {run_id}", + f"Model URI: {model_uri}", + f"Model version status: {ready_version.status}", + ] + if run_url: + info_lines.append(f"Run UI: {run_url}") + if model_url: + info_lines.append(f"Model UI: {model_url}") + print("\n".join(info_lines)) + + +def _dump_tensor(tensor: torch.Tensor, filename: str) -> str: + """Dump a tensor to a file.""" + path = os.path.join(_ensure_temp_dir(), filename) + with open(path, "w", encoding="utf-8") as handle: + for row in tensor.tolist(): + handle.write(",".join(f"{value:.6f}" for value in row)) + handle.write("\n") + return path + + +_TEMP_DIR: str | None = None + + +def _ensure_temp_dir() -> str: + """Ensure a temporary directory exists.""" + global _TEMP_DIR + if _TEMP_DIR is None: + import tempfile + + _TEMP_DIR = tempfile.mkdtemp(prefix="mlflow-workspace-demo-") + return _TEMP_DIR + + +def _cleanup_temp_dir() -> None: + """Cleanup the temporary directory.""" + global _TEMP_DIR + if _TEMP_DIR and os.path.isdir(_TEMP_DIR): + import shutil + + shutil.rmtree(_TEMP_DIR, ignore_errors=True) + _TEMP_DIR = None + + +def _register_atexit() -> None: + """Register an atexit handler to cleanup the temporary directory.""" + import atexit + + atexit.register(_cleanup_temp_dir) + + +_register_atexit() + + +@hydra.main(version_base="1.2", config_name=CONFIG_NAME) +def main(config: DemoConfig) -> None: + """Main entry point for the MLflow workspace registry demo.""" + try: + run_demo(config) + except (RuntimeError, TimeoutError) as error: + print(f"Error: {error}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() From ab0da89d2c590c073769606a0faeb735cea82aa4 Mon Sep 17 00:00:00 2001 From: adamimos Date: Sat, 4 Oct 2025 11:57:02 -0700 Subject: [PATCH 17/17] Fix frozen instance error in MLFlowPersister MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use object.__setattr__ to bypass Equinox frozen instance check when updating _registered_model_checked flag during model registration. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- simplexity/persistence/mlflow_persister.py | 3 +- uv.lock | 531 ++++++++++++++++++--- 2 files changed, 479 insertions(+), 55 deletions(-) diff --git a/simplexity/persistence/mlflow_persister.py b/simplexity/persistence/mlflow_persister.py index f4220475..3512fbeb 100644 --- a/simplexity/persistence/mlflow_persister.py +++ b/simplexity/persistence/mlflow_persister.py @@ -211,7 +211,8 @@ def _maybe_register_model(self, artifact_path: str) -> None: with contextlib.suppress(Exception): self.client.create_registered_model(self.registered_model_name) - self._registered_model_checked = True + # Use object.__setattr__ to bypass frozen instance check + object.__setattr__(self, "_registered_model_checked", True) source = f"runs:/{self.run_id}/{artifact_path}" import contextlib diff --git a/uv.lock b/uv.lock index ef2042c4..6b810a58 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" resolution-markers = [ "python_full_version >= '3.13' and sys_platform != 'win32'", @@ -143,6 +143,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, ] +[[package]] +name = "authlib" +version = "1.6.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cd/3f/1d3bbd0bf23bdd99276d4def22f29c27a914067b4cf66f753ff9b8bbd0f3/authlib-1.6.5.tar.gz", hash = "sha256:6aaf9c79b7cc96c900f0b284061691c5d4e61221640a948fe690b556a6d6d10b", size = 164553, upload-time = "2025-10-02T13:36:09.489Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/aa/5082412d1ee302e9e7d80b6949bc4d2a8fa1149aaab610c5fc24709605d6/authlib-1.6.5-py2.py3-none-any.whl", hash = "sha256:3e0e0507807f842b02175507bdee8957a1d5707fd4afb17c32fb43fee90b6e3a", size = 243608, upload-time = "2025-10-02T13:36:07.637Z" }, +] + [[package]] name = "autopage" version = "0.5.2" @@ -521,6 +533,41 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/59/f1/4da7717f0063a222db253e7121bd6a56f6fb1ba439dcc36659088793347c/coverage-7.8.0-py3-none-any.whl", hash = "sha256:dbf364b4c5e7bae9250528167dfe40219b62e2d573c854d74be213e1e52069f7", size = 203435, upload-time = "2025-03-30T20:36:43.61Z" }, ] +[[package]] +name = "cryptography" +version = "45.0.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/35/c495bffc2056f2dadb32434f1feedd79abde2a7f8363e1974afa9c33c7e2/cryptography-45.0.7.tar.gz", hash = "sha256:4b1654dfc64ea479c242508eb8c724044f1e964a47d1d1cacc5132292d851971", size = 744980, upload-time = "2025-09-01T11:15:03.146Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/91/925c0ac74362172ae4516000fe877912e33b5983df735ff290c653de4913/cryptography-45.0.7-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:3be4f21c6245930688bd9e162829480de027f8bf962ede33d4f8ba7d67a00cee", size = 7041105, upload-time = "2025-09-01T11:13:59.684Z" }, + { url = "https://files.pythonhosted.org/packages/fc/63/43641c5acce3a6105cf8bd5baeceeb1846bb63067d26dae3e5db59f1513a/cryptography-45.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:67285f8a611b0ebc0857ced2081e30302909f571a46bfa7a3cc0ad303fe015c6", size = 4205799, upload-time = "2025-09-01T11:14:02.517Z" }, + { url = "https://files.pythonhosted.org/packages/bc/29/c238dd9107f10bfde09a4d1c52fd38828b1aa353ced11f358b5dd2507d24/cryptography-45.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:577470e39e60a6cd7780793202e63536026d9b8641de011ed9d8174da9ca5339", size = 4430504, upload-time = "2025-09-01T11:14:04.522Z" }, + { url = "https://files.pythonhosted.org/packages/62/62/24203e7cbcc9bd7c94739428cd30680b18ae6b18377ae66075c8e4771b1b/cryptography-45.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:4bd3e5c4b9682bc112d634f2c6ccc6736ed3635fc3319ac2bb11d768cc5a00d8", size = 4209542, upload-time = "2025-09-01T11:14:06.309Z" }, + { url = "https://files.pythonhosted.org/packages/cd/e3/e7de4771a08620eef2389b86cd87a2c50326827dea5528feb70595439ce4/cryptography-45.0.7-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:465ccac9d70115cd4de7186e60cfe989de73f7bb23e8a7aa45af18f7412e75bf", size = 3889244, upload-time = "2025-09-01T11:14:08.152Z" }, + { url = "https://files.pythonhosted.org/packages/96/b8/bca71059e79a0bb2f8e4ec61d9c205fbe97876318566cde3b5092529faa9/cryptography-45.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:16ede8a4f7929b4b7ff3642eba2bf79aa1d71f24ab6ee443935c0d269b6bc513", size = 4461975, upload-time = "2025-09-01T11:14:09.755Z" }, + { url = "https://files.pythonhosted.org/packages/58/67/3f5b26937fe1218c40e95ef4ff8d23c8dc05aa950d54200cc7ea5fb58d28/cryptography-45.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8978132287a9d3ad6b54fcd1e08548033cc09dc6aacacb6c004c73c3eb5d3ac3", size = 4209082, upload-time = "2025-09-01T11:14:11.229Z" }, + { url = "https://files.pythonhosted.org/packages/0e/e4/b3e68a4ac363406a56cf7b741eeb80d05284d8c60ee1a55cdc7587e2a553/cryptography-45.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:b6a0e535baec27b528cb07a119f321ac024592388c5681a5ced167ae98e9fff3", size = 4460397, upload-time = "2025-09-01T11:14:12.924Z" }, + { url = "https://files.pythonhosted.org/packages/22/49/2c93f3cd4e3efc8cb22b02678c1fad691cff9dd71bb889e030d100acbfe0/cryptography-45.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a24ee598d10befaec178efdff6054bc4d7e883f615bfbcd08126a0f4931c83a6", size = 4337244, upload-time = "2025-09-01T11:14:14.431Z" }, + { url = "https://files.pythonhosted.org/packages/04/19/030f400de0bccccc09aa262706d90f2ec23d56bc4eb4f4e8268d0ddf3fb8/cryptography-45.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:fa26fa54c0a9384c27fcdc905a2fb7d60ac6e47d14bc2692145f2b3b1e2cfdbd", size = 4568862, upload-time = "2025-09-01T11:14:16.185Z" }, + { url = "https://files.pythonhosted.org/packages/29/56/3034a3a353efa65116fa20eb3c990a8c9f0d3db4085429040a7eef9ada5f/cryptography-45.0.7-cp311-abi3-win32.whl", hash = "sha256:bef32a5e327bd8e5af915d3416ffefdbe65ed975b646b3805be81b23580b57b8", size = 2936578, upload-time = "2025-09-01T11:14:17.638Z" }, + { url = "https://files.pythonhosted.org/packages/b3/61/0ab90f421c6194705a99d0fa9f6ee2045d916e4455fdbb095a9c2c9a520f/cryptography-45.0.7-cp311-abi3-win_amd64.whl", hash = "sha256:3808e6b2e5f0b46d981c24d79648e5c25c35e59902ea4391a0dcb3e667bf7443", size = 3405400, upload-time = "2025-09-01T11:14:18.958Z" }, + { url = "https://files.pythonhosted.org/packages/63/e8/c436233ddf19c5f15b25ace33979a9dd2e7aa1a59209a0ee8554179f1cc0/cryptography-45.0.7-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bfb4c801f65dd61cedfc61a83732327fafbac55a47282e6f26f073ca7a41c3b2", size = 7021824, upload-time = "2025-09-01T11:14:20.954Z" }, + { url = "https://files.pythonhosted.org/packages/bc/4c/8f57f2500d0ccd2675c5d0cc462095adf3faa8c52294ba085c036befb901/cryptography-45.0.7-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:81823935e2f8d476707e85a78a405953a03ef7b7b4f55f93f7c2d9680e5e0691", size = 4202233, upload-time = "2025-09-01T11:14:22.454Z" }, + { url = "https://files.pythonhosted.org/packages/eb/ac/59b7790b4ccaed739fc44775ce4645c9b8ce54cbec53edf16c74fd80cb2b/cryptography-45.0.7-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3994c809c17fc570c2af12c9b840d7cea85a9fd3e5c0e0491f4fa3c029216d59", size = 4423075, upload-time = "2025-09-01T11:14:24.287Z" }, + { url = "https://files.pythonhosted.org/packages/b8/56/d4f07ea21434bf891faa088a6ac15d6d98093a66e75e30ad08e88aa2b9ba/cryptography-45.0.7-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:dad43797959a74103cb59c5dac71409f9c27d34c8a05921341fb64ea8ccb1dd4", size = 4204517, upload-time = "2025-09-01T11:14:25.679Z" }, + { url = "https://files.pythonhosted.org/packages/e8/ac/924a723299848b4c741c1059752c7cfe09473b6fd77d2920398fc26bfb53/cryptography-45.0.7-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:ce7a453385e4c4693985b4a4a3533e041558851eae061a58a5405363b098fcd3", size = 3882893, upload-time = "2025-09-01T11:14:27.1Z" }, + { url = "https://files.pythonhosted.org/packages/83/dc/4dab2ff0a871cc2d81d3ae6d780991c0192b259c35e4d83fe1de18b20c70/cryptography-45.0.7-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:b04f85ac3a90c227b6e5890acb0edbaf3140938dbecf07bff618bf3638578cf1", size = 4450132, upload-time = "2025-09-01T11:14:28.58Z" }, + { url = "https://files.pythonhosted.org/packages/12/dd/b2882b65db8fc944585d7fb00d67cf84a9cef4e77d9ba8f69082e911d0de/cryptography-45.0.7-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:48c41a44ef8b8c2e80ca4527ee81daa4c527df3ecbc9423c41a420a9559d0e27", size = 4204086, upload-time = "2025-09-01T11:14:30.572Z" }, + { url = "https://files.pythonhosted.org/packages/5d/fa/1d5745d878048699b8eb87c984d4ccc5da4f5008dfd3ad7a94040caca23a/cryptography-45.0.7-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f3df7b3d0f91b88b2106031fd995802a2e9ae13e02c36c1fc075b43f420f3a17", size = 4449383, upload-time = "2025-09-01T11:14:32.046Z" }, + { url = "https://files.pythonhosted.org/packages/36/8b/fc61f87931bc030598e1876c45b936867bb72777eac693e905ab89832670/cryptography-45.0.7-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dd342f085542f6eb894ca00ef70236ea46070c8a13824c6bde0dfdcd36065b9b", size = 4332186, upload-time = "2025-09-01T11:14:33.95Z" }, + { url = "https://files.pythonhosted.org/packages/0b/11/09700ddad7443ccb11d674efdbe9a832b4455dc1f16566d9bd3834922ce5/cryptography-45.0.7-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1993a1bb7e4eccfb922b6cd414f072e08ff5816702a0bdb8941c247a6b1b287c", size = 4561639, upload-time = "2025-09-01T11:14:35.343Z" }, + { url = "https://files.pythonhosted.org/packages/71/ed/8f4c1337e9d3b94d8e50ae0b08ad0304a5709d483bfcadfcc77a23dbcb52/cryptography-45.0.7-cp37-abi3-win32.whl", hash = "sha256:18fcf70f243fe07252dcb1b268a687f2358025ce32f9f88028ca5c364b123ef5", size = 2926552, upload-time = "2025-09-01T11:14:36.929Z" }, + { url = "https://files.pythonhosted.org/packages/bc/ff/026513ecad58dacd45d1d24ebe52b852165a26e287177de1d545325c0c25/cryptography-45.0.7-cp37-abi3-win_amd64.whl", hash = "sha256:7285a89df4900ed3bfaad5679b1e668cb4b38a8de1ccbfc84b05f34512da0a90", size = 3392742, upload-time = "2025-09-01T11:14:38.368Z" }, +] + [[package]] name = "cycler" version = "0.12.1" @@ -530,6 +577,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321, upload-time = "2023-10-07T05:32:16.783Z" }, ] +[[package]] +name = "cyclopts" +version = "3.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "docstring-parser", marker = "python_full_version < '4'" }, + { name = "rich" }, + { name = "rich-rst" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/ca/7782da3b03242d5f0a16c20371dff99d4bd1fedafe26bc48ff82e42be8c9/cyclopts-3.24.0.tar.gz", hash = "sha256:de6964a041dfb3c57bf043b41e68c43548227a17de1bad246e3a0bfc5c4b7417", size = 76131, upload-time = "2025-09-08T15:40:57.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/8b/2c95f0645c6f40211896375e6fa51f504b8ccb29c21f6ae661fe87ab044e/cyclopts-3.24.0-py3-none-any.whl", hash = "sha256:809d04cde9108617106091140c3964ee6fceb33cecdd537f7ffa360bde13ed71", size = 86154, upload-time = "2025-09-08T15:40:56.41Z" }, +] + [[package]] name = "databricks-sdk" version = "0.49.0" @@ -590,6 +652,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/c6/ac0b6c1e2d138f1002bcf799d330bd6d85084fece321e662a14223794041/Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec", size = 9998, upload-time = "2025-01-27T10:46:09.186Z" }, ] +[[package]] +name = "dnspython" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, +] + [[package]] name = "docker" version = "7.1.0" @@ -604,6 +675,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" }, ] +[[package]] +name = "docstring-parser" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, +] + +[[package]] +name = "docutils" +version = "0.22.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/c0/89fe6215b443b919cb98a5002e107cb5026854ed1ccb6b5833e0768419d1/docutils-0.22.2.tar.gz", hash = "sha256:9fdb771707c8784c8f2728b67cb2c691305933d68137ef95a75db5f4dfbc213d", size = 2289092, upload-time = "2025-09-20T17:55:47.994Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/66/dd/f95350e853a4468ec37478414fc04ae2d61dad7a947b3015c3dcc51a09b9/docutils-0.22.2-py3-none-any.whl", hash = "sha256:b0e98d679283fc3bb0ead8a5da7f501baa632654e7056e9c5846842213d674d8", size = 632667, upload-time = "2025-09-20T17:55:43.052Z" }, +] + [[package]] name = "dotenv" version = "0.9.9" @@ -615,6 +704,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/b7/545d2c10c1fc15e48653c91efde329a790f2eecfbbf2bd16003b5db2bab0/dotenv-0.9.9-py2.py3-none-any.whl", hash = "sha256:29cf74a087b31dafdb5a446b6d7e11cbce8ed2741540e2339c69fbef92c94ce9", size = 1892, upload-time = "2025-02-19T22:15:01.647Z" }, ] +[[package]] +name = "email-validator" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dnspython" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" }, +] + [[package]] name = "equinox" version = "0.12.1" @@ -650,6 +752,18 @@ epy = [ { name = "typing-extensions" }, ] +[[package]] +name = "exceptiongroup" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, +] + [[package]] name = "executing" version = "2.2.0" @@ -682,6 +796,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/90/2b/0817a2b257fe88725c25589d89aec060581aabf668707a8d03b2e9e0cb2a/fastjsonschema-2.21.1-py3-none-any.whl", hash = "sha256:c9e5b7e908310918cf494a434eeb31384dd84a98b57a30bcb1f535015b554667", size = 23924, upload-time = "2024-12-02T10:55:07.599Z" }, ] +[[package]] +name = "fastmcp" +version = "2.12.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "authlib" }, + { name = "cyclopts" }, + { name = "exceptiongroup" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "openapi-core" }, + { name = "openapi-pydantic" }, + { name = "pydantic", extra = ["email"] }, + { name = "pyperclip" }, + { name = "python-dotenv" }, + { name = "rich" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/b2/57845353a9bc63002995a982e66f3d0be4ec761e7bcb89e7d0638518d42a/fastmcp-2.12.4.tar.gz", hash = "sha256:b55fe89537038f19d0f4476544f9ca5ac171033f61811cc8f12bdeadcbea5016", size = 7167745, upload-time = "2025-09-26T16:43:27.71Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/c7/562ff39f25de27caec01e4c1e88cbb5fcae5160802ba3d90be33165df24f/fastmcp-2.12.4-py3-none-any.whl", hash = "sha256:56188fbbc1a9df58c537063f25958c57b5c4d715f73e395c41b51550b247d140", size = 329090, upload-time = "2025-09-26T16:43:25.314Z" }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -918,6 +1054,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "httpx-sse" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6e/fa/66bd985dd0b7c109a3bcb89272ee0bfb7e2b4d06309ad7b38ff866734b2a/httpx_sse-0.4.1.tar.gz", hash = "sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e", size = 12998, upload-time = "2025-06-24T13:21:05.71Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/0a/6269e3473b09aed2dab8aa1a600c70f31f00ae1349bee30658f7e358a159/httpx_sse-0.4.1-py3-none-any.whl", hash = "sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37", size = 8054, upload-time = "2025-06-24T13:21:04.772Z" }, +] + [[package]] name = "humanize" version = "4.12.2" @@ -1066,6 +1211,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/22/2d/9c0b76f2f9cc0ebede1b9371b6f317243028ed60b90705863d493bae622e/ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245", size = 139767, upload-time = "2024-08-22T12:19:49.494Z" }, ] +[[package]] +name = "isodate" +version = "0.7.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/4d/e940025e2ce31a8ce1202635910747e5a87cc3a6a6bb2d00973375014749/isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6", size = 29705, upload-time = "2024-10-08T23:04:11.5Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/aa/0aca39a37d3c7eb941ba736ede56d689e7be91cab5d9ca846bde3999eba6/isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15", size = 22320, upload-time = "2024-10-08T23:04:09.501Z" }, +] + [[package]] name = "isoduration" version = "20.11.0" @@ -1291,6 +1445,21 @@ format-nongpl = [ { name = "webcolors" }, ] +[[package]] +name = "jsonschema-path" +version = "0.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pathable" }, + { name = "pyyaml" }, + { name = "referencing" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6e/45/41ebc679c2a4fced6a722f624c18d658dee42612b83ea24c1caf7c0eb3a8/jsonschema_path-0.3.4.tar.gz", hash = "sha256:8365356039f16cc65fddffafda5f58766e34bebab7d6d105616ab52bc4297001", size = 11159, upload-time = "2025-01-24T14:33:16.547Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/58/3485da8cb93d2f393bce453adeef16896751f14ba3e2024bc21dc9597646/jsonschema_path-0.3.4-py3-none-any.whl", hash = "sha256:f502191fdc2b22050f9a81c9237be9d27145b9001c55842bece5e94e382e52f8", size = 14810, upload-time = "2025-01-24T14:33:14.652Z" }, +] + [[package]] name = "jsonschema-specifications" version = "2024.10.1" @@ -1554,6 +1723,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4c/fa/be89a49c640930180657482a74970cdcf6f7072c8d2471e1babe17a222dc/kiwisolver-1.4.8-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:be4816dc51c8a471749d664161b434912eee82f2ea66bd7628bd14583a833e85", size = 2349213, upload-time = "2024-12-24T18:30:40.019Z" }, ] +[[package]] +name = "lazy-object-proxy" +version = "1.12.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/08/a2/69df9c6ba6d316cfd81fe2381e464db3e6de5db45f8c43c6a23504abf8cb/lazy_object_proxy-1.12.0.tar.gz", hash = "sha256:1f5a462d92fd0cfb82f1fab28b51bfb209fabbe6aabf7f0d51472c0c124c0c61", size = 43681, upload-time = "2025-08-22T13:50:06.783Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/1b/b5f5bd6bda26f1e15cd3232b223892e4498e34ec70a7f4f11c401ac969f1/lazy_object_proxy-1.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8ee0d6027b760a11cc18281e702c0309dd92da458a74b4c15025d7fc490deede", size = 26746, upload-time = "2025-08-22T13:42:37.572Z" }, + { url = "https://files.pythonhosted.org/packages/55/64/314889b618075c2bfc19293ffa9153ce880ac6153aacfd0a52fcabf21a66/lazy_object_proxy-1.12.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:4ab2c584e3cc8be0dfca422e05ad30a9abe3555ce63e9ab7a559f62f8dbc6ff9", size = 71457, upload-time = "2025-08-22T13:42:38.743Z" }, + { url = "https://files.pythonhosted.org/packages/11/53/857fc2827fc1e13fbdfc0ba2629a7d2579645a06192d5461809540b78913/lazy_object_proxy-1.12.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:14e348185adbd03ec17d051e169ec45686dcd840a3779c9d4c10aabe2ca6e1c0", size = 71036, upload-time = "2025-08-22T13:42:40.184Z" }, + { url = "https://files.pythonhosted.org/packages/2b/24/e581ffed864cd33c1b445b5763d617448ebb880f48675fc9de0471a95cbc/lazy_object_proxy-1.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c4fcbe74fb85df8ba7825fa05eddca764138da752904b378f0ae5ab33a36c308", size = 69329, upload-time = "2025-08-22T13:42:41.311Z" }, + { url = "https://files.pythonhosted.org/packages/78/be/15f8f5a0b0b2e668e756a152257d26370132c97f2f1943329b08f057eff0/lazy_object_proxy-1.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:563d2ec8e4d4b68ee7848c5ab4d6057a6d703cb7963b342968bb8758dda33a23", size = 70690, upload-time = "2025-08-22T13:42:42.51Z" }, + { url = "https://files.pythonhosted.org/packages/5d/aa/f02be9bbfb270e13ee608c2b28b8771f20a5f64356c6d9317b20043c6129/lazy_object_proxy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:53c7fd99eb156bbb82cbc5d5188891d8fdd805ba6c1e3b92b90092da2a837073", size = 26563, upload-time = "2025-08-22T13:42:43.685Z" }, + { url = "https://files.pythonhosted.org/packages/f4/26/b74c791008841f8ad896c7f293415136c66cc27e7c7577de4ee68040c110/lazy_object_proxy-1.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:86fd61cb2ba249b9f436d789d1356deae69ad3231dc3c0f17293ac535162672e", size = 26745, upload-time = "2025-08-22T13:42:44.982Z" }, + { url = "https://files.pythonhosted.org/packages/9b/52/641870d309e5d1fb1ea7d462a818ca727e43bfa431d8c34b173eb090348c/lazy_object_proxy-1.12.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:81d1852fb30fab81696f93db1b1e55a5d1ff7940838191062f5f56987d5fcc3e", size = 71537, upload-time = "2025-08-22T13:42:46.141Z" }, + { url = "https://files.pythonhosted.org/packages/47/b6/919118e99d51c5e76e8bf5a27df406884921c0acf2c7b8a3b38d847ab3e9/lazy_object_proxy-1.12.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be9045646d83f6c2664c1330904b245ae2371b5c57a3195e4028aedc9f999655", size = 71141, upload-time = "2025-08-22T13:42:47.375Z" }, + { url = "https://files.pythonhosted.org/packages/e5/47/1d20e626567b41de085cf4d4fb3661a56c159feaa73c825917b3b4d4f806/lazy_object_proxy-1.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:67f07ab742f1adfb3966c40f630baaa7902be4222a17941f3d85fd1dae5565ff", size = 69449, upload-time = "2025-08-22T13:42:48.49Z" }, + { url = "https://files.pythonhosted.org/packages/58/8d/25c20ff1a1a8426d9af2d0b6f29f6388005fc8cd10d6ee71f48bff86fdd0/lazy_object_proxy-1.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:75ba769017b944fcacbf6a80c18b2761a1795b03f8899acdad1f1c39db4409be", size = 70744, upload-time = "2025-08-22T13:42:49.608Z" }, + { url = "https://files.pythonhosted.org/packages/c0/67/8ec9abe15c4f8a4bcc6e65160a2c667240d025cbb6591b879bea55625263/lazy_object_proxy-1.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:7b22c2bbfb155706b928ac4d74c1a63ac8552a55ba7fff4445155523ea4067e1", size = 26568, upload-time = "2025-08-22T13:42:57.719Z" }, + { url = "https://files.pythonhosted.org/packages/23/12/cd2235463f3469fd6c62d41d92b7f120e8134f76e52421413a0ad16d493e/lazy_object_proxy-1.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4a79b909aa16bde8ae606f06e6bbc9d3219d2e57fb3e0076e17879072b742c65", size = 27391, upload-time = "2025-08-22T13:42:50.62Z" }, + { url = "https://files.pythonhosted.org/packages/60/9e/f1c53e39bbebad2e8609c67d0830cc275f694d0ea23d78e8f6db526c12d3/lazy_object_proxy-1.12.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:338ab2f132276203e404951205fe80c3fd59429b3a724e7b662b2eb539bb1be9", size = 80552, upload-time = "2025-08-22T13:42:51.731Z" }, + { url = "https://files.pythonhosted.org/packages/4c/b6/6c513693448dcb317d9d8c91d91f47addc09553613379e504435b4cc8b3e/lazy_object_proxy-1.12.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8c40b3c9faee2e32bfce0df4ae63f4e73529766893258eca78548bac801c8f66", size = 82857, upload-time = "2025-08-22T13:42:53.225Z" }, + { url = "https://files.pythonhosted.org/packages/12/1c/d9c4aaa4c75da11eb7c22c43d7c90a53b4fca0e27784a5ab207768debea7/lazy_object_proxy-1.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:717484c309df78cedf48396e420fa57fc8a2b1f06ea889df7248fdd156e58847", size = 80833, upload-time = "2025-08-22T13:42:54.391Z" }, + { url = "https://files.pythonhosted.org/packages/0b/ae/29117275aac7d7d78ae4f5a4787f36ff33262499d486ac0bf3e0b97889f6/lazy_object_proxy-1.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:a6b7ea5ea1ffe15059eb44bcbcb258f97bcb40e139b88152c40d07b1a1dfc9ac", size = 79516, upload-time = "2025-08-22T13:42:55.812Z" }, + { url = "https://files.pythonhosted.org/packages/19/40/b4e48b2c38c69392ae702ae7afa7b6551e0ca5d38263198b7c79de8b3bdf/lazy_object_proxy-1.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:08c465fb5cd23527512f9bd7b4c7ba6cec33e28aad36fbbe46bf7b858f9f3f7f", size = 27656, upload-time = "2025-08-22T13:42:56.793Z" }, + { url = "https://files.pythonhosted.org/packages/ef/3a/277857b51ae419a1574557c0b12e0d06bf327b758ba94cafc664cb1e2f66/lazy_object_proxy-1.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c9defba70ab943f1df98a656247966d7729da2fe9c2d5d85346464bf320820a3", size = 26582, upload-time = "2025-08-22T13:49:49.366Z" }, + { url = "https://files.pythonhosted.org/packages/1a/b6/c5e0fa43535bb9c87880e0ba037cdb1c50e01850b0831e80eb4f4762f270/lazy_object_proxy-1.12.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6763941dbf97eea6b90f5b06eb4da9418cc088fce0e3883f5816090f9afcde4a", size = 71059, upload-time = "2025-08-22T13:49:50.488Z" }, + { url = "https://files.pythonhosted.org/packages/06/8a/7dcad19c685963c652624702f1a968ff10220b16bfcc442257038216bf55/lazy_object_proxy-1.12.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fdc70d81235fc586b9e3d1aeef7d1553259b62ecaae9db2167a5d2550dcc391a", size = 71034, upload-time = "2025-08-22T13:49:54.224Z" }, + { url = "https://files.pythonhosted.org/packages/12/ac/34cbfb433a10e28c7fd830f91c5a348462ba748413cbb950c7f259e67aa7/lazy_object_proxy-1.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:0a83c6f7a6b2bfc11ef3ed67f8cbe99f8ff500b05655d8e7df9aab993a6abc95", size = 69529, upload-time = "2025-08-22T13:49:55.29Z" }, + { url = "https://files.pythonhosted.org/packages/6f/6a/11ad7e349307c3ca4c0175db7a77d60ce42a41c60bcb11800aabd6a8acb8/lazy_object_proxy-1.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:256262384ebd2a77b023ad02fbcc9326282bcfd16484d5531154b02bc304f4c5", size = 70391, upload-time = "2025-08-22T13:49:56.35Z" }, + { url = "https://files.pythonhosted.org/packages/59/97/9b410ed8fbc6e79c1ee8b13f8777a80137d4bc189caf2c6202358e66192c/lazy_object_proxy-1.12.0-cp314-cp314-win_amd64.whl", hash = "sha256:7601ec171c7e8584f8ff3f4e440aa2eebf93e854f04639263875b8c2971f819f", size = 26988, upload-time = "2025-08-22T13:49:57.302Z" }, +] + [[package]] name = "mako" version = "1.3.9" @@ -1567,12 +1768,15 @@ wheels = [ ] [[package]] -name = "markdown" -version = "3.7" +name = "markdown-it-py" +version = "4.0.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/54/28/3af612670f82f4c056911fbbbb42760255801b3068c48de792d354ff4472/markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2", size = 357086, upload-time = "2024-08-16T15:55:17.812Z" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/08/83871f3c50fc983b88547c196d11cf8c3340e37c32d2e9d6152abe2c61f7/Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803", size = 106349, upload-time = "2024-08-16T15:55:16.176Z" }, + { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, ] [[package]] @@ -1662,6 +1866,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, ] +[[package]] +name = "mcp" +version = "1.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/a1/b1f328da3b153683d2ec34f849b4b6eac2790fb240e3aef06ff2fab3df9d/mcp-1.16.0.tar.gz", hash = "sha256:39b8ca25460c578ee2cdad33feeea122694cfdf73eef58bee76c42f6ef0589df", size = 472918, upload-time = "2025-10-02T16:58:20.631Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/0e/7cebc88e17daf94ebe28c95633af595ccb2864dc2ee7abd75542d98495cc/mcp-1.16.0-py3-none-any.whl", hash = "sha256:ec917be9a5d31b09ba331e1768aa576e0af45470d657a0319996a20a57d7d633", size = 167266, upload-time = "2025-10-02T16:58:19.039Z" }, +] + +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mistune" version = "3.1.3" @@ -1695,18 +1930,19 @@ wheels = [ [[package]] name = "mlflow" -version = "2.21.2" +version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alembic" }, + { name = "cryptography" }, { name = "docker" }, + { name = "fastmcp" }, { name = "flask" }, { name = "graphene" }, { name = "gunicorn", marker = "sys_platform != 'win32'" }, - { name = "jinja2" }, - { name = "markdown" }, { name = "matplotlib" }, { name = "mlflow-skinny" }, + { name = "mlflow-tracing" }, { name = "numpy" }, { name = "pandas" }, { name = "pyarrow" }, @@ -1715,14 +1951,14 @@ dependencies = [ { name = "sqlalchemy" }, { name = "waitress", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/63/ed/c6b0d24610b51fcea2200f5e091a7b33674d1351959478c398822afed255/mlflow-2.21.2.tar.gz", hash = "sha256:ea6beaaa6cdf296db0c961f7269e6be90d409f5d48be743ecdce4b608d88297e", size = 27620389, upload-time = "2025-03-26T15:47:25.587Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/6b/94e454bf1ff34eb503701c3cb20742a72abab33957392f1f2b3e9b4d5601/mlflow-3.4.0.tar.gz", hash = "sha256:a564f9296b860fe710c0574f9f309b53ae30662eb969994df2453b198fa4c3bb", size = 26061019, upload-time = "2025-09-17T06:24:29.411Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/72/9c/08620b82d822eca5ed6054e8cc6d0fa4c46b2e883db9f7ddd3ac470339fb/mlflow-2.21.2-py3-none-any.whl", hash = "sha256:e1edb4b44e124d5bcf526d9880c5f4a96817bf384fa04089859bb8fc06698951", size = 28235160, upload-time = "2025-03-26T15:47:21.908Z" }, + { url = "https://files.pythonhosted.org/packages/52/fe/1ed27f800cd1709a272c6e26b78ec3d77a5ba482171ea1b5bfbcf4c067c0/mlflow-3.4.0-py3-none-any.whl", hash = "sha256:065ca7f9acda7bdfbc01deefdcb31172c91ff954ad76405a9d1f9d67dea4c33c", size = 26726629, upload-time = "2025-09-17T06:24:26.457Z" }, ] [[package]] name = "mlflow-skinny" -version = "2.21.2" +version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -1733,19 +1969,49 @@ dependencies = [ { name = "gitpython" }, { name = "importlib-metadata" }, { name = "opentelemetry-api" }, + { name = "opentelemetry-proto" }, { name = "opentelemetry-sdk" }, { name = "packaging" }, { name = "protobuf" }, { name = "pydantic" }, + { name = "python-dotenv" }, { name = "pyyaml" }, { name = "requests" }, { name = "sqlparse" }, { name = "typing-extensions" }, { name = "uvicorn" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2b/a9/7c8e9f8e587aaba7ecccc8ec0570668389fc92b5b6f94d073299ea5cea9c/mlflow_skinny-2.21.2.tar.gz", hash = "sha256:6e0ee8be1b38a8b02a5a0547a0c36b175d06590f78e38a3ff127cbc3be8e27d6", size = 5771939, upload-time = "2025-03-26T15:29:03.283Z" } +sdist = { url = "https://files.pythonhosted.org/packages/95/90/ddfcfba5b64fb2a9a874998fcd0a1a6e4013b95744eaeeb7a0b8a78f25c5/mlflow_skinny-3.4.0.tar.gz", hash = "sha256:1730207e64811b00ebfa2d5b9b899212a7e6a06e8cd49eb3f90888ff7e7bc3a7", size = 1851246, upload-time = "2025-09-17T06:11:11.966Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e2/ec/adfbccfac31e0d85003d74e03694a30ba54b5e0e3c7783de64952e9c3b84/mlflow_skinny-2.21.2-py3-none-any.whl", hash = "sha256:8d735f7f15100d4b50834cc540e3e7da5833138a8147c9306cf03c4218b8afa3", size = 6143970, upload-time = "2025-03-26T15:29:01.052Z" }, + { url = "https://files.pythonhosted.org/packages/1b/94/7acd7c6970cc75da1fd3b550e43d8b99068032022f47b0ef224a137ec679/mlflow_skinny-3.4.0-py3-none-any.whl", hash = "sha256:51e06c1f717093501a9a1b2d5b7bea382bd1b7c3542a52f824c510263f86f0c7", size = 2221734, upload-time = "2025-09-17T06:11:09.89Z" }, +] + +[[package]] +name = "mlflow-tracing" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools" }, + { name = "databricks-sdk" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-proto" }, + { name = "opentelemetry-sdk" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/61/a2c17b64534728995302c5a3f7abe9fcfa848beeffdc8c069e0dbcafa30e/mlflow_tracing-3.4.0.tar.gz", hash = "sha256:805537d43387717c355bcc07c065941f1614ed037de75b73c168cdf60d5e6e08", size = 1011159, upload-time = "2025-09-17T06:13:41.108Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/96/403b1191ccf587f19a8c94085477600d6e6b3d61a7aff46f353b20b450f9/mlflow_tracing-3.4.0-py3-none-any.whl", hash = "sha256:06e4a423373c96507f3e40d00a564665a375e0d78856917e52dd78d8b833edf2", size = 1220253, upload-time = "2025-09-17T06:13:39.199Z" }, +] + +[[package]] +name = "more-itertools" +version = "10.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ea/5d/38b681d3fce7a266dd9ab73c66959406d565b3e85f21d5e66e1181d93721/more_itertools-10.8.0.tar.gz", hash = "sha256:f638ddf8a1a0d134181275fb5d58b086ead7c6a72429ad725c67503f13ba30bd", size = 137431, upload-time = "2025-09-02T15:23:11.018Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/8e/469e5a4a2f5855992e425f3cb33804cc07bf18d48f2db061aec61ce50270/more_itertools-10.8.0-py3-none-any.whl", hash = "sha256:52d4362373dcf7c52546bc4af9a86ee7c4579df9a8dc268be0a2f949d376cc9b", size = 69667, upload-time = "2025-09-02T15:23:09.635Z" }, ] [[package]] @@ -2271,6 +2537,67 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size = 79500, upload-time = "2022-12-08T20:59:19.686Z" }, ] +[[package]] +name = "openapi-core" +version = "0.19.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "isodate" }, + { name = "jsonschema" }, + { name = "jsonschema-path" }, + { name = "more-itertools" }, + { name = "openapi-schema-validator" }, + { name = "openapi-spec-validator" }, + { name = "parse" }, + { name = "typing-extensions" }, + { name = "werkzeug" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/35/1acaa5f2fcc6e54eded34a2ec74b479439c4e469fc4e8d0e803fda0234db/openapi_core-0.19.5.tar.gz", hash = "sha256:421e753da56c391704454e66afe4803a290108590ac8fa6f4a4487f4ec11f2d3", size = 103264, upload-time = "2025-03-20T20:17:28.193Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/6f/83ead0e2e30a90445ee4fc0135f43741aebc30cca5b43f20968b603e30b6/openapi_core-0.19.5-py3-none-any.whl", hash = "sha256:ef7210e83a59394f46ce282639d8d26ad6fc8094aa904c9c16eb1bac8908911f", size = 106595, upload-time = "2025-03-20T20:17:26.77Z" }, +] + +[[package]] +name = "openapi-pydantic" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/02/2e/58d83848dd1a79cb92ed8e63f6ba901ca282c5f09d04af9423ec26c56fd7/openapi_pydantic-0.5.1.tar.gz", hash = "sha256:ff6835af6bde7a459fb93eb93bb92b8749b754fc6e51b2f1590a19dc3005ee0d", size = 60892, upload-time = "2025-01-08T19:29:27.083Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/12/cf/03675d8bd8ecbf4445504d8071adab19f5f993676795708e36402ab38263/openapi_pydantic-0.5.1-py3-none-any.whl", hash = "sha256:a3a09ef4586f5bd760a8df7f43028b60cafb6d9f61de2acba9574766255ab146", size = 96381, upload-time = "2025-01-08T19:29:25.275Z" }, +] + +[[package]] +name = "openapi-schema-validator" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonschema" }, + { name = "jsonschema-specifications" }, + { name = "rfc3339-validator" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8b/f3/5507ad3325169347cd8ced61c232ff3df70e2b250c49f0fe140edb4973c6/openapi_schema_validator-0.6.3.tar.gz", hash = "sha256:f37bace4fc2a5d96692f4f8b31dc0f8d7400fd04f3a937798eaf880d425de6ee", size = 11550, upload-time = "2025-01-10T18:08:22.268Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/c6/ad0fba32775ae749016829dace42ed80f4407b171da41313d1a3a5f102e4/openapi_schema_validator-0.6.3-py3-none-any.whl", hash = "sha256:f3b9870f4e556b5a62a1c39da72a6b4b16f3ad9c73dc80084b1b11e74ba148a3", size = 8755, upload-time = "2025-01-10T18:08:19.758Z" }, +] + +[[package]] +name = "openapi-spec-validator" +version = "0.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonschema" }, + { name = "jsonschema-path" }, + { name = "lazy-object-proxy" }, + { name = "openapi-schema-validator" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/af/fe2d7618d6eae6fb3a82766a44ed87cd8d6d82b4564ed1c7cfb0f6378e91/openapi_spec_validator-0.7.2.tar.gz", hash = "sha256:cc029309b5c5dbc7859df0372d55e9d1ff43e96d678b9ba087f7c56fc586f734", size = 36855, upload-time = "2025-06-07T14:48:56.299Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/dd/b3fd642260cb17532f66cc1e8250f3507d1e580483e209dc1e9d13bd980d/openapi_spec_validator-0.7.2-py3-none-any.whl", hash = "sha256:4bbdc0894ec85f1d1bea1d6d9c8b2c3c8d7ccaa13577ef40da9c006c9fd0eb60", size = 39713, upload-time = "2025-06-07T14:48:54.077Z" }, +] + [[package]] name = "opentelemetry-api" version = "1.31.1" @@ -2284,6 +2611,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6c/c8/86557ff0da32f3817bc4face57ea35cfdc2f9d3bcefd42311ef860dcefb7/opentelemetry_api-1.31.1-py3-none-any.whl", hash = "sha256:1511a3f470c9c8a32eeea68d4ea37835880c0eed09dd1a0187acc8b1301da0a1", size = 65197, upload-time = "2025-03-20T14:43:57.518Z" }, ] +[[package]] +name = "opentelemetry-proto" +version = "1.37.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dd/ea/a75f36b463a36f3c5a10c0b5292c58b31dbdde74f6f905d3d0ab2313987b/opentelemetry_proto-1.37.0.tar.gz", hash = "sha256:30f5c494faf66f77faeaefa35ed4443c5edb3b0aa46dad073ed7210e1a789538", size = 46151, upload-time = "2025-09-11T10:29:11.04Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/25/f89ea66c59bd7687e218361826c969443c4fa15dfe89733f3bf1e2a9e971/opentelemetry_proto-1.37.0-py3-none-any.whl", hash = "sha256:8ed8c066ae8828bbf0c39229979bdf583a126981142378a9cbe9d6fd5701c6e2", size = 72534, upload-time = "2025-09-11T10:28:56.831Z" }, +] + [[package]] name = "opentelemetry-sdk" version = "1.31.1" @@ -2451,6 +2790,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/af/4fbc8cab944db5d21b7e2a5b8e9211a03a79852b1157e2c102fcc61ac440/pandocfilters-1.5.1-py2.py3-none-any.whl", hash = "sha256:93be382804a9cdb0a7267585f157e5d1731bbe5545a85b268d6f5fe6232de2bc", size = 8663, upload-time = "2024-01-18T20:08:11.28Z" }, ] +[[package]] +name = "parse" +version = "1.20.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4f/78/d9b09ba24bb36ef8b83b71be547e118d46214735b6dfb39e4bfde0e9b9dd/parse-1.20.2.tar.gz", hash = "sha256:b41d604d16503c79d81af5165155c0b20f6c8d6c559efa66b4b695c3e5a0a0ce", size = 29391, upload-time = "2024-06-11T04:41:57.34Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/31/ba45bf0b2aa7898d81cbbfac0e88c267befb59ad91a19e36e1bc5578ddb1/parse-1.20.2-py2.py3-none-any.whl", hash = "sha256:967095588cb802add9177d0c0b6133b5ba33b1ea9007ca800e526f42a85af558", size = 20126, upload-time = "2024-06-11T04:41:55.057Z" }, +] + [[package]] name = "parso" version = "0.8.4" @@ -2460,6 +2808,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650, upload-time = "2024-04-05T09:43:53.299Z" }, ] +[[package]] +name = "pathable" +version = "0.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/67/93/8f2c2075b180c12c1e9f6a09d1a985bc2036906b13dff1d8917e395f2048/pathable-0.4.4.tar.gz", hash = "sha256:6905a3cd17804edfac7875b5f6c9142a218c7caef78693c2dbbbfbac186d88b2", size = 8124, upload-time = "2025-01-10T18:43:13.247Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/eb/b6260b31b1a96386c0a880edebe26f89669098acea8e0318bff6adb378fd/pathable-0.4.4-py3-none-any.whl", hash = "sha256:5ae9e94793b6ef5a4cbe0a7ce9dbbefc1eec38df253763fd0aeeacf2762dbbc2", size = 9592, upload-time = "2025-01-10T18:43:11.88Z" }, +] + [[package]] name = "pbr" version = "6.1.1" @@ -2722,7 +3079,7 @@ wheels = [ [[package]] name = "pydantic" -version = "2.11.1" +version = "2.11.10" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-types" }, @@ -2730,51 +3087,70 @@ dependencies = [ { name = "typing-extensions" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/93/a3/698b87a4d4d303d7c5f62ea5fbf7a79cab236ccfbd0a17847b7f77f8163e/pydantic-2.11.1.tar.gz", hash = "sha256:442557d2910e75c991c39f4b4ab18963d57b9b55122c8b2a9cd176d8c29ce968", size = 782817, upload-time = "2025-03-28T21:14:58.347Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/54/ecab642b3bed45f7d5f59b38443dcb36ef50f85af192e6ece103dbfe9587/pydantic-2.11.10.tar.gz", hash = "sha256:dc280f0982fbda6c38fada4e476dc0a4f3aeaf9c6ad4c28df68a666ec3c61423", size = 788494, upload-time = "2025-10-04T10:40:41.338Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/12/f9221a949f2419e2e23847303c002476c26fbcfd62dc7f3d25d0bec5ca99/pydantic-2.11.1-py3-none-any.whl", hash = "sha256:5b6c415eee9f8123a14d859be0c84363fec6b1feb6b688d6435801230b56e0b8", size = 442648, upload-time = "2025-03-28T21:14:55.856Z" }, + { url = "https://files.pythonhosted.org/packages/bd/1f/73c53fcbfb0b5a78f91176df41945ca466e71e9d9d836e5c522abda39ee7/pydantic-2.11.10-py3-none-any.whl", hash = "sha256:802a655709d49bd004c31e865ef37da30b540786a46bfce02333e0e24b5fe29a", size = 444823, upload-time = "2025-10-04T10:40:39.055Z" }, +] + +[package.optional-dependencies] +email = [ + { name = "email-validator" }, ] [[package]] name = "pydantic-core" -version = "2.33.0" +version = "2.33.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b9/05/91ce14dfd5a3a99555fce436318cc0fd1f08c4daa32b3248ad63669ea8b4/pydantic_core-2.33.0.tar.gz", hash = "sha256:40eb8af662ba409c3cbf4a8150ad32ae73514cd7cb1f1a2113af39763dd616b3", size = 434080, upload-time = "2025-03-26T20:30:05.906Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a9/c4/c9381323cbdc1bb26d352bc184422ce77c4bc2f2312b782761093a59fafc/pydantic_core-2.33.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6c32a40712e3662bebe524abe8abb757f2fa2000028d64cc5a1006016c06af43", size = 2025127, upload-time = "2025-03-26T20:27:27.704Z" }, - { url = "https://files.pythonhosted.org/packages/6f/bd/af35278080716ecab8f57e84515c7dc535ed95d1c7f52c1c6f7b313a9dab/pydantic_core-2.33.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8ec86b5baa36f0a0bfb37db86c7d52652f8e8aa076ab745ef7725784183c3fdd", size = 1851687, upload-time = "2025-03-26T20:27:29.67Z" }, - { url = "https://files.pythonhosted.org/packages/12/e4/a01461225809c3533c23bd1916b1e8c2e21727f0fea60ab1acbffc4e2fca/pydantic_core-2.33.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4deac83a8cc1d09e40683be0bc6d1fa4cde8df0a9bf0cda5693f9b0569ac01b6", size = 1892232, upload-time = "2025-03-26T20:27:31.374Z" }, - { url = "https://files.pythonhosted.org/packages/51/17/3d53d62a328fb0a49911c2962036b9e7a4f781b7d15e9093c26299e5f76d/pydantic_core-2.33.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:175ab598fb457a9aee63206a1993874badf3ed9a456e0654273e56f00747bbd6", size = 1977896, upload-time = "2025-03-26T20:27:33.055Z" }, - { url = "https://files.pythonhosted.org/packages/30/98/01f9d86e02ec4a38f4b02086acf067f2c776b845d43f901bd1ee1c21bc4b/pydantic_core-2.33.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5f36afd0d56a6c42cf4e8465b6441cf546ed69d3a4ec92724cc9c8c61bd6ecf4", size = 2127717, upload-time = "2025-03-26T20:27:34.768Z" }, - { url = "https://files.pythonhosted.org/packages/3c/43/6f381575c61b7c58b0fd0b92134c5a1897deea4cdfc3d47567b3ff460a4e/pydantic_core-2.33.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0a98257451164666afafc7cbf5fb00d613e33f7e7ebb322fbcd99345695a9a61", size = 2680287, upload-time = "2025-03-26T20:27:36.826Z" }, - { url = "https://files.pythonhosted.org/packages/01/42/c0d10d1451d161a9a0da9bbef023b8005aa26e9993a8cc24dc9e3aa96c93/pydantic_core-2.33.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecc6d02d69b54a2eb83ebcc6f29df04957f734bcf309d346b4f83354d8376862", size = 2008276, upload-time = "2025-03-26T20:27:38.609Z" }, - { url = "https://files.pythonhosted.org/packages/20/ca/e08df9dba546905c70bae44ced9f3bea25432e34448d95618d41968f40b7/pydantic_core-2.33.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1a69b7596c6603afd049ce7f3835bcf57dd3892fc7279f0ddf987bebed8caa5a", size = 2115305, upload-time = "2025-03-26T20:27:41.717Z" }, - { url = "https://files.pythonhosted.org/packages/03/1f/9b01d990730a98833113581a78e595fd40ed4c20f9693f5a658fb5f91eff/pydantic_core-2.33.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ea30239c148b6ef41364c6f51d103c2988965b643d62e10b233b5efdca8c0099", size = 2068999, upload-time = "2025-03-26T20:27:43.42Z" }, - { url = "https://files.pythonhosted.org/packages/20/18/fe752476a709191148e8b1e1139147841ea5d2b22adcde6ee6abb6c8e7cf/pydantic_core-2.33.0-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:abfa44cf2f7f7d7a199be6c6ec141c9024063205545aa09304349781b9a125e6", size = 2241488, upload-time = "2025-03-26T20:27:46.744Z" }, - { url = "https://files.pythonhosted.org/packages/81/22/14738ad0a0bf484b928c9e52004f5e0b81dd8dabbdf23b843717b37a71d1/pydantic_core-2.33.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20d4275f3c4659d92048c70797e5fdc396c6e4446caf517ba5cad2db60cd39d3", size = 2248430, upload-time = "2025-03-26T20:27:48.458Z" }, - { url = "https://files.pythonhosted.org/packages/e8/27/be7571e215ac8d321712f2433c445b03dbcd645366a18f67b334df8912bc/pydantic_core-2.33.0-cp312-cp312-win32.whl", hash = "sha256:918f2013d7eadea1d88d1a35fd4a1e16aaf90343eb446f91cb091ce7f9b431a2", size = 1908353, upload-time = "2025-03-26T20:27:50.488Z" }, - { url = "https://files.pythonhosted.org/packages/be/3a/be78f28732f93128bd0e3944bdd4b3970b389a1fbd44907c97291c8dcdec/pydantic_core-2.33.0-cp312-cp312-win_amd64.whl", hash = "sha256:aec79acc183865bad120b0190afac467c20b15289050648b876b07777e67ea48", size = 1955956, upload-time = "2025-03-26T20:27:52.239Z" }, - { url = "https://files.pythonhosted.org/packages/21/26/b8911ac74faa994694b76ee6a22875cc7a4abea3c381fdba4edc6c6bef84/pydantic_core-2.33.0-cp312-cp312-win_arm64.whl", hash = "sha256:5461934e895968655225dfa8b3be79e7e927e95d4bd6c2d40edd2fa7052e71b6", size = 1903259, upload-time = "2025-03-26T20:27:54.06Z" }, - { url = "https://files.pythonhosted.org/packages/79/20/de2ad03ce8f5b3accf2196ea9b44f31b0cd16ac6e8cfc6b21976ed45ec35/pydantic_core-2.33.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f00e8b59e1fc8f09d05594aa7d2b726f1b277ca6155fc84c0396db1b373c4555", size = 2032214, upload-time = "2025-03-26T20:27:56.197Z" }, - { url = "https://files.pythonhosted.org/packages/f9/af/6817dfda9aac4958d8b516cbb94af507eb171c997ea66453d4d162ae8948/pydantic_core-2.33.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1a73be93ecef45786d7d95b0c5e9b294faf35629d03d5b145b09b81258c7cd6d", size = 1852338, upload-time = "2025-03-26T20:27:57.876Z" }, - { url = "https://files.pythonhosted.org/packages/44/f3/49193a312d9c49314f2b953fb55740b7c530710977cabe7183b8ef111b7f/pydantic_core-2.33.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff48a55be9da6930254565ff5238d71d5e9cd8c5487a191cb85df3bdb8c77365", size = 1896913, upload-time = "2025-03-26T20:27:59.719Z" }, - { url = "https://files.pythonhosted.org/packages/06/e0/c746677825b2e29a2fa02122a8991c83cdd5b4c5f638f0664d4e35edd4b2/pydantic_core-2.33.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:26a4ea04195638dcd8c53dadb545d70badba51735b1594810e9768c2c0b4a5da", size = 1986046, upload-time = "2025-03-26T20:28:01.583Z" }, - { url = "https://files.pythonhosted.org/packages/11/ec/44914e7ff78cef16afb5e5273d480c136725acd73d894affdbe2a1bbaad5/pydantic_core-2.33.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:41d698dcbe12b60661f0632b543dbb119e6ba088103b364ff65e951610cb7ce0", size = 2128097, upload-time = "2025-03-26T20:28:03.437Z" }, - { url = "https://files.pythonhosted.org/packages/fe/f5/c6247d424d01f605ed2e3802f338691cae17137cee6484dce9f1ac0b872b/pydantic_core-2.33.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ae62032ef513fe6281ef0009e30838a01057b832dc265da32c10469622613885", size = 2681062, upload-time = "2025-03-26T20:28:05.498Z" }, - { url = "https://files.pythonhosted.org/packages/f0/85/114a2113b126fdd7cf9a9443b1b1fe1b572e5bd259d50ba9d5d3e1927fa9/pydantic_core-2.33.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f225f3a3995dbbc26affc191d0443c6c4aa71b83358fd4c2b7d63e2f6f0336f9", size = 2007487, upload-time = "2025-03-26T20:28:07.879Z" }, - { url = "https://files.pythonhosted.org/packages/e6/40/3c05ed28d225c7a9acd2b34c5c8010c279683a870219b97e9f164a5a8af0/pydantic_core-2.33.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5bdd36b362f419c78d09630cbaebc64913f66f62bda6d42d5fbb08da8cc4f181", size = 2121382, upload-time = "2025-03-26T20:28:09.651Z" }, - { url = "https://files.pythonhosted.org/packages/8a/22/e70c086f41eebd323e6baa92cc906c3f38ddce7486007eb2bdb3b11c8f64/pydantic_core-2.33.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:2a0147c0bef783fd9abc9f016d66edb6cac466dc54a17ec5f5ada08ff65caf5d", size = 2072473, upload-time = "2025-03-26T20:28:11.69Z" }, - { url = "https://files.pythonhosted.org/packages/3e/84/d1614dedd8fe5114f6a0e348bcd1535f97d76c038d6102f271433cd1361d/pydantic_core-2.33.0-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:c860773a0f205926172c6644c394e02c25421dc9a456deff16f64c0e299487d3", size = 2249468, upload-time = "2025-03-26T20:28:13.651Z" }, - { url = "https://files.pythonhosted.org/packages/b0/c0/787061eef44135e00fddb4b56b387a06c303bfd3884a6df9bea5cb730230/pydantic_core-2.33.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:138d31e3f90087f42aa6286fb640f3c7a8eb7bdae829418265e7e7474bd2574b", size = 2254716, upload-time = "2025-03-26T20:28:16.105Z" }, - { url = "https://files.pythonhosted.org/packages/ae/e2/27262eb04963201e89f9c280f1e10c493a7a37bc877e023f31aa72d2f911/pydantic_core-2.33.0-cp313-cp313-win32.whl", hash = "sha256:d20cbb9d3e95114325780f3cfe990f3ecae24de7a2d75f978783878cce2ad585", size = 1916450, upload-time = "2025-03-26T20:28:18.252Z" }, - { url = "https://files.pythonhosted.org/packages/13/8d/25ff96f1e89b19e0b70b3cd607c9ea7ca27e1dcb810a9cd4255ed6abf869/pydantic_core-2.33.0-cp313-cp313-win_amd64.whl", hash = "sha256:ca1103d70306489e3d006b0f79db8ca5dd3c977f6f13b2c59ff745249431a606", size = 1956092, upload-time = "2025-03-26T20:28:20.129Z" }, - { url = "https://files.pythonhosted.org/packages/1b/64/66a2efeff657b04323ffcd7b898cb0354d36dae3a561049e092134a83e9c/pydantic_core-2.33.0-cp313-cp313-win_arm64.whl", hash = "sha256:6291797cad239285275558e0a27872da735b05c75d5237bbade8736f80e4c225", size = 1908367, upload-time = "2025-03-26T20:28:22.498Z" }, - { url = "https://files.pythonhosted.org/packages/52/54/295e38769133363d7ec4a5863a4d579f331728c71a6644ff1024ee529315/pydantic_core-2.33.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:7b79af799630af263eca9ec87db519426d8c9b3be35016eddad1832bac812d87", size = 1813331, upload-time = "2025-03-26T20:28:25.004Z" }, - { url = "https://files.pythonhosted.org/packages/4c/9c/0c8ea02db8d682aa1ef48938abae833c1d69bdfa6e5ec13b21734b01ae70/pydantic_core-2.33.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eabf946a4739b5237f4f56d77fa6668263bc466d06a8036c055587c130a46f7b", size = 1986653, upload-time = "2025-03-26T20:28:27.02Z" }, - { url = "https://files.pythonhosted.org/packages/8e/4f/3fb47d6cbc08c7e00f92300e64ba655428c05c56b8ab6723bd290bae6458/pydantic_core-2.33.0-cp313-cp313t-win_amd64.whl", hash = "sha256:8a1d581e8cdbb857b0e0e81df98603376c1a5c34dc5e54039dcc00f043df81e7", size = 1931234, upload-time = "2025-03-26T20:28:29.237Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000, upload-time = "2025-04-23T18:31:25.863Z" }, + { url = "https://files.pythonhosted.org/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996, upload-time = "2025-04-23T18:31:27.341Z" }, + { url = "https://files.pythonhosted.org/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957, upload-time = "2025-04-23T18:31:28.956Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199, upload-time = "2025-04-23T18:31:31.025Z" }, + { url = "https://files.pythonhosted.org/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296, upload-time = "2025-04-23T18:31:32.514Z" }, + { url = "https://files.pythonhosted.org/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109, upload-time = "2025-04-23T18:31:33.958Z" }, + { url = "https://files.pythonhosted.org/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028, upload-time = "2025-04-23T18:31:39.095Z" }, + { url = "https://files.pythonhosted.org/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044, upload-time = "2025-04-23T18:31:41.034Z" }, + { url = "https://files.pythonhosted.org/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881, upload-time = "2025-04-23T18:31:42.757Z" }, + { url = "https://files.pythonhosted.org/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034, upload-time = "2025-04-23T18:31:44.304Z" }, + { url = "https://files.pythonhosted.org/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187, upload-time = "2025-04-23T18:31:45.891Z" }, + { url = "https://files.pythonhosted.org/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628, upload-time = "2025-04-23T18:31:47.819Z" }, + { url = "https://files.pythonhosted.org/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866, upload-time = "2025-04-23T18:31:49.635Z" }, + { url = "https://files.pythonhosted.org/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894, upload-time = "2025-04-23T18:31:51.609Z" }, + { url = "https://files.pythonhosted.org/packages/46/8c/99040727b41f56616573a28771b1bfa08a3d3fe74d3d513f01251f79f172/pydantic_core-2.33.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1082dd3e2d7109ad8b7da48e1d4710c8d06c253cbc4a27c1cff4fbcaa97a9e3f", size = 2015688, upload-time = "2025-04-23T18:31:53.175Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cc/5999d1eb705a6cefc31f0b4a90e9f7fc400539b1a1030529700cc1b51838/pydantic_core-2.33.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f517ca031dfc037a9c07e748cefd8d96235088b83b4f4ba8939105d20fa1dcd6", size = 1844808, upload-time = "2025-04-23T18:31:54.79Z" }, + { url = "https://files.pythonhosted.org/packages/6f/5e/a0a7b8885c98889a18b6e376f344da1ef323d270b44edf8174d6bce4d622/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a9f2c9dd19656823cb8250b0724ee9c60a82f3cdf68a080979d13092a3b0fef", size = 1885580, upload-time = "2025-04-23T18:31:57.393Z" }, + { url = "https://files.pythonhosted.org/packages/3b/2a/953581f343c7d11a304581156618c3f592435523dd9d79865903272c256a/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b0a451c263b01acebe51895bfb0e1cc842a5c666efe06cdf13846c7418caa9a", size = 1973859, upload-time = "2025-04-23T18:31:59.065Z" }, + { url = "https://files.pythonhosted.org/packages/e6/55/f1a813904771c03a3f97f676c62cca0c0a4138654107c1b61f19c644868b/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ea40a64d23faa25e62a70ad163571c0b342b8bf66d5fa612ac0dec4f069d916", size = 2120810, upload-time = "2025-04-23T18:32:00.78Z" }, + { url = "https://files.pythonhosted.org/packages/aa/c3/053389835a996e18853ba107a63caae0b9deb4a276c6b472931ea9ae6e48/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb2d542b4d66f9470e8065c5469ec676978d625a8b7a363f07d9a501a9cb36a", size = 2676498, upload-time = "2025-04-23T18:32:02.418Z" }, + { url = "https://files.pythonhosted.org/packages/eb/3c/f4abd740877a35abade05e437245b192f9d0ffb48bbbbd708df33d3cda37/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdac5d6ffa1b5a83bca06ffe7583f5576555e6c8b3a91fbd25ea7780f825f7d", size = 2000611, upload-time = "2025-04-23T18:32:04.152Z" }, + { url = "https://files.pythonhosted.org/packages/59/a7/63ef2fed1837d1121a894d0ce88439fe3e3b3e48c7543b2a4479eb99c2bd/pydantic_core-2.33.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04a1a413977ab517154eebb2d326da71638271477d6ad87a769102f7c2488c56", size = 2107924, upload-time = "2025-04-23T18:32:06.129Z" }, + { url = "https://files.pythonhosted.org/packages/04/8f/2551964ef045669801675f1cfc3b0d74147f4901c3ffa42be2ddb1f0efc4/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c8e7af2f4e0194c22b5b37205bfb293d166a7344a5b0d0eaccebc376546d77d5", size = 2063196, upload-time = "2025-04-23T18:32:08.178Z" }, + { url = "https://files.pythonhosted.org/packages/26/bd/d9602777e77fc6dbb0c7db9ad356e9a985825547dce5ad1d30ee04903918/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:5c92edd15cd58b3c2d34873597a1e20f13094f59cf88068adb18947df5455b4e", size = 2236389, upload-time = "2025-04-23T18:32:10.242Z" }, + { url = "https://files.pythonhosted.org/packages/42/db/0e950daa7e2230423ab342ae918a794964b053bec24ba8af013fc7c94846/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:65132b7b4a1c0beded5e057324b7e16e10910c106d43675d9bd87d4f38dde162", size = 2239223, upload-time = "2025-04-23T18:32:12.382Z" }, + { url = "https://files.pythonhosted.org/packages/58/4d/4f937099c545a8a17eb52cb67fe0447fd9a373b348ccfa9a87f141eeb00f/pydantic_core-2.33.2-cp313-cp313-win32.whl", hash = "sha256:52fb90784e0a242bb96ec53f42196a17278855b0f31ac7c3cc6f5c1ec4811849", size = 1900473, upload-time = "2025-04-23T18:32:14.034Z" }, + { url = "https://files.pythonhosted.org/packages/a0/75/4a0a9bac998d78d889def5e4ef2b065acba8cae8c93696906c3a91f310ca/pydantic_core-2.33.2-cp313-cp313-win_amd64.whl", hash = "sha256:c083a3bdd5a93dfe480f1125926afcdbf2917ae714bdb80b36d34318b2bec5d9", size = 1955269, upload-time = "2025-04-23T18:32:15.783Z" }, + { url = "https://files.pythonhosted.org/packages/f9/86/1beda0576969592f1497b4ce8e7bc8cbdf614c352426271b1b10d5f0aa64/pydantic_core-2.33.2-cp313-cp313-win_arm64.whl", hash = "sha256:e80b087132752f6b3d714f041ccf74403799d3b23a72722ea2e6ba2e892555b9", size = 1893921, upload-time = "2025-04-23T18:32:18.473Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7d/e09391c2eebeab681df2b74bfe6c43422fffede8dc74187b2b0bf6fd7571/pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac", size = 1806162, upload-time = "2025-04-23T18:32:20.188Z" }, + { url = "https://files.pythonhosted.org/packages/f1/3d/847b6b1fed9f8ed3bb95a9ad04fbd0b212e832d4f0f50ff4d9ee5a9f15cf/pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5", size = 1981560, upload-time = "2025-04-23T18:32:22.354Z" }, + { url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777, upload-time = "2025-04-23T18:32:25.088Z" }, +] + +[[package]] +name = "pydantic-settings" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/20/c5/dbbc27b814c71676593d1c3f718e6cd7d4f00652cefa24b75f7aa3efb25e/pydantic_settings-2.11.0.tar.gz", hash = "sha256:d0e87a1c7d33593beb7194adb8470fc426e95ba02af83a0f23474a04c9a08180", size = 188394, upload-time = "2025-09-24T14:19:11.764Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/d6/887a1ff844e64aa823fb4905978d882a633cfe295c32eacad582b78a7d8b/pydantic_settings-2.11.0-py3-none-any.whl", hash = "sha256:fe2cea3413b9530d10f3a5875adffb17ada5c1e1bab0b2885546d7310415207c", size = 48608, upload-time = "2025-09-24T14:19:10.015Z" }, ] [[package]] @@ -2881,6 +3257,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/20/0f2523b9e50a8052bc6a8b732dfc8568abbdc42010aef03a2d750bdab3b2/python_json_logger-3.3.0-py3-none-any.whl", hash = "sha256:dd980fae8cffb24c13caf6e158d3d61c0d6d22342f932cb6e9deedab3d35eec7", size = 15163, upload-time = "2025-03-07T07:08:25.627Z" }, ] +[[package]] +name = "python-multipart" +version = "0.0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 37158, upload-time = "2024-12-16T19:45:46.972Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546, upload-time = "2024-12-16T19:45:44.423Z" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -3031,6 +3416,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/51/17023c0f8f1869d8806b979a2bffa3f861f26a3f1a66b094288323fba52f/rfc3986_validator-0.1.1-py2.py3-none-any.whl", hash = "sha256:2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9", size = 4242, upload-time = "2019-10-28T16:00:13.976Z" }, ] +[[package]] +name = "rich" +version = "14.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fe/75/af448d8e52bf1d8fa6a9d089ca6c07ff4453d86c65c145d0a300bb073b9b/rich-14.1.0.tar.gz", hash = "sha256:e497a48b844b0320d45007cdebfeaeed8db2a4f4bcf49f15e455cfc4af11eaa8", size = 224441, upload-time = "2025-07-25T07:32:58.125Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/30/3c4d035596d3cf444529e0b2953ad0466f6049528a879d27534700580395/rich-14.1.0-py3-none-any.whl", hash = "sha256:536f5f1785986d6dbdea3c75205c473f970777b4a0d6c6dd1b696aa05a3fa04f", size = 243368, upload-time = "2025-07-25T07:32:56.73Z" }, +] + +[[package]] +name = "rich-rst" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docutils" }, + { name = "rich" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/69/5514c3a87b5f10f09a34bb011bc0927bc12c596c8dae5915604e71abc386/rich_rst-1.3.1.tar.gz", hash = "sha256:fad46e3ba42785ea8c1785e2ceaa56e0ffa32dbe5410dec432f37e4107c4f383", size = 13839, upload-time = "2024-04-30T04:40:38.125Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/bc/cc4e3dbc5e7992398dcb7a8eda0cbcf4fb792a0cdb93f857b478bf3cf884/rich_rst-1.3.1-py3-none-any.whl", hash = "sha256:498a74e3896507ab04492d326e794c3ef76e7cda078703aa592d1853d91098c1", size = 11621, upload-time = "2024-04-30T04:40:32.619Z" }, +] + [[package]] name = "rpds-py" version = "0.24.0" @@ -3306,7 +3717,7 @@ requires-dist = [ { name = "jaxtyping", marker = "extra == 'dev'" }, { name = "jupyter" }, { name = "matplotlib" }, - { name = "mlflow" }, + { name = "mlflow", specifier = ">=3.0.0" }, { name = "nbqa", marker = "extra == 'dev'" }, { name = "optax" }, { name = "orbax-checkpoint" }, @@ -3396,6 +3807,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/5c/bfd6bd0bf979426d405cc6e71eceb8701b148b16c21d2dc3c261efc61c7b/sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca", size = 44415, upload-time = "2024-12-10T12:05:27.824Z" }, ] +[[package]] +name = "sse-starlette" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/6f/22ed6e33f8a9e76ca0a412405f31abb844b779d52c5f96660766edcd737c/sse_starlette-3.0.2.tar.gz", hash = "sha256:ccd60b5765ebb3584d0de2d7a6e4f745672581de4f5005ab31c3a25d10b52b3a", size = 20985, upload-time = "2025-07-27T09:07:44.565Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/10/c78f463b4ef22eef8491f218f692be838282cd65480f6e423d7730dfd1fb/sse_starlette-3.0.2-py3-none-any.whl", hash = "sha256:16b7cbfddbcd4eaca11f7b586f3b8a080f1afe952c15813455b162edea619e5a", size = 11297, upload-time = "2025-07-27T09:07:43.268Z" }, +] + [[package]] name = "stack-data" version = "0.6.3" @@ -3783,14 +4206,14 @@ wheels = [ [[package]] name = "werkzeug" -version = "3.1.3" +version = "3.1.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9f/69/83029f1f6300c5fb2471d621ab06f6ec6b3324685a2ce0f9777fd4a8b71e/werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746", size = 806925, upload-time = "2024-11-08T15:52:18.093Z" } +sdist = { url = "https://files.pythonhosted.org/packages/32/af/d4502dc713b4ccea7175d764718d5183caf8d0867a4f0190d5d4a45cea49/werkzeug-3.1.1.tar.gz", hash = "sha256:8cd39dfbdfc1e051965f156163e2974e52c210f130810e9ad36858f0fd3edad4", size = 806453, upload-time = "2024-11-01T16:40:45.462Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/52/24/ab44c871b0f07f491e5d2ad12c9bd7358e527510618cb1b803a88e986db1/werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e", size = 224498, upload-time = "2024-11-08T15:52:16.132Z" }, + { url = "https://files.pythonhosted.org/packages/ee/ea/c67e1dee1ba208ed22c06d1d547ae5e293374bfc43e0eb0ef5e262b68561/werkzeug-3.1.1-py3-none-any.whl", hash = "sha256:a71124d1ef06008baafa3d266c02f56e1836a5984afd6dd6c9230669d60d9fb5", size = 224371, upload-time = "2024-11-01T16:40:43.994Z" }, ] [[package]]