diff --git a/README.md b/README.md index d09b338..4da627d 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Every time you run an experiment, your code will be auto-committed with metadata * Query your scientific log, e.g. `mthd query metrics.accuracy < 0.8`. ```python -from mthd import commit +from mthd import commit, Run from pydantic import BaseModel class Hypers(BaseModel): @@ -39,15 +39,18 @@ class Metrics(BaseModel): accuracy: float -@commit(hypers="hypers") -def my_experiment(hypers: Hypers) -> Metrics: +@commit +def my_experiment(run: Run) -> Metrics: ... - # experiment + + run.set_hyperparameters({ ... }) + ... - metrics = Metrics(...) + run.set_metrics({ ... }) - return metrics +if __name__ == "__main__": + my_experiment() ``` Then run your experiment: diff --git a/mthd/decorator.py b/mthd/decorator.py index 7180e50..ba1fd1f 100644 --- a/mthd/decorator.py +++ b/mthd/decorator.py @@ -1,7 +1,8 @@ import os +from dataclasses import dataclass from functools import wraps -from typing import Callable, Optional, cast +from typing import Callable, Concatenate, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload from pydantic import BaseModel from rich.console import Console @@ -14,39 +15,110 @@ from mthd.util.di import DI +@dataclass +class Run: + """Tracks experiment hyperparameters and metrics.""" + + _hypers: Optional[dict] = None + _metrics: Optional[dict] = None + + def set_hyperparameters(self, hypers: dict) -> None: + """Set hyperparameters manually.""" + self._hypers = hypers + + def set_metrics(self, metrics: dict) -> None: + """Set metrics manually.""" + self._metrics = metrics + + @property + def hyperparameters(self) -> Optional[dict]: + return self._hypers + + @property + def metrics(self) -> Optional[dict]: + return self._metrics + + +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T") + + +@overload +def commit( + fn: None = None, + *, + hypers: str = "hypers", + template: str = "run {experiment}", + strategy: StageStrategy = StageStrategy.ALL, + implicit: Literal[False] = False, +) -> Callable[[Callable[Concatenate[Run, P], R]], Callable[P, R]]: ... + + +@overload def commit( - fn: Optional[Callable] = None, + fn: None = None, *, hypers: str = "hypers", template: str = "run {experiment}", strategy: StageStrategy = StageStrategy.ALL, -) -> Callable: + implicit: Literal[True], +) -> Callable[[Callable[P, R]], Callable[P, R]]: ... + + +@overload +def commit( + fn: Callable[P, R], +) -> Callable[Concatenate[Run, P], R]: ... + + +def commit( + fn: Optional[Callable[..., R]] = None, + *, + hypers: str = "hypers", + template: str = "run {experiment}", + strategy: StageStrategy = StageStrategy.ALL, + implicit: bool = False, +) -> Union[Callable[[Callable[..., R]], Callable[..., R]], Callable[..., R]]: """Decorator to auto-commit experimental code with scientific metadata. Can be used as @commit or @commit(message="Custom message") """ - def decorator(func: Callable) -> Callable: + def decorator(func: Callable[..., R]) -> Callable[..., R]: @wraps(func) def wrapper(*args, **kwargs): di = DI() - # @todo: handle this better (eg. positional args) - hyperparameters = cast(BaseModel, kwargs.get(hypers, None)) - if not hyperparameters: - raise MthdError("Hyperparameters must be provided in the function call.") git_service = di[GitService] - # codebase_service = di[CodebaseService] - # Generate commit message + if not implicit: + context = Run() + args = [context] + list(args) + metrics = func(*args, **kwargs) + + if context.hyperparameters is None: + raise MthdError("When using context, hyperparameters must be set via the Context object") + if context.metrics is None: + raise MthdError("When using context, metrics must be set via the Context object") - # Run experiment - metrics = func(*args, **kwargs) + hyper_dict = context.hyperparameters + metric_dict = context.metrics + else: + metrics = func(*args, **kwargs) + + hyperparameters = cast(BaseModel, kwargs.get(hypers, None)) + if not hyperparameters: + raise MthdError("When not using context, hyperparameters must be provided as function arguments") + if not isinstance(metrics, BaseModel): + raise MthdError("When not using context, metrics must be returned as a BaseModel") + + hyper_dict = hyperparameters.model_dump() + metric_dict = metrics.model_dump() experiment = ExperimentRun( experiment=func.__name__, - hyperparameters=hyperparameters.model_dump(), - metrics=metrics.model_dump(), - # annotations=codebase_service.get_all_annotations(), + hyperparameters=hyper_dict, + metrics=metric_dict, ) message = experiment.as_commit_message(template=template) @@ -82,9 +154,18 @@ class Metrics(BaseModel): b: float c: str - @commit(hypers="hypers", template="run {experiment} at {timestamp}") - def test(hypers: Hyperparameters): - print("\n\n") + # Example using function arguments/return + @commit(hypers="hypers", implicit=True, template="run {experiment} at {timestamp}") + def test1(hypers: Hyperparameters): + print("\n\n") return Metrics(a=1, b=2.0, c="3") - test(hypers=Hyperparameters(a=1, b=2.0, c="3")) + # Example using Context object + @commit() + def test2(run: Run, foo: int): + print("\n\n") + run.set_hyperparameters({"a": 1, "b": 2.0, "c": "3"}) + run.set_metrics({"a": 1, "b": 2.0, "c": "3"}) + + test1(hypers=Hyperparameters(a=1, b=2.0, c="3")) + test2(5) # No need to pass context diff --git a/tests/e2e/e2e_test.py b/tests/e2e/e2e_test.py index bf6a688..9cd58bc 100644 --- a/tests/e2e/e2e_test.py +++ b/tests/e2e/e2e_test.py @@ -62,7 +62,7 @@ class Metrics(BaseModel): accuracy: float loss: float -@commit(hypers="hypers") +@commit(hypers="hypers", implicit=True) def train_model(hypers: Hyperparameters) -> Metrics: # Simulate training accuracy = 0.75 + ({iteration} * 0.05) # Gradually improve accuracy @@ -86,6 +86,39 @@ def train_model(hypers: Hyperparameters) -> Metrics: f.write(content) +def create_run_experiment_file(temp_dir: Path, iteration: int) -> None: + """Create or update the experiment file using run-based API.""" + content = f""" +from mthd import commit +from mthd.decorator import Run + +@commit +def train_model(run: Run): + # Set hyperparameters + run.set_hyperparameters({{ + "learning_rate": 0.001 * ({iteration} + 1), + "batch_size": 32 * ({iteration} + 1), + "epochs": 10 * ({iteration} + 1) + }}) + + # Simulate training + accuracy = 0.75 + ({iteration} * 0.05) # Gradually improve accuracy + loss = 0.5 - ({iteration} * 0.1) # Gradually decrease loss + + # Set metrics + run.set_metrics({{ + "accuracy": accuracy, + "loss": max(0.1, loss) + }}) + +if __name__ == "__main__": + train_model() +""" + + with open(temp_dir / "experiment.py", "w") as f: + f.write(content) + + def test_multiple_experiments(temp_dir: Path): """Test running multiple experiments and creating commits.""" # Get the path to the Python executable in the virtual environment @@ -128,7 +161,72 @@ def test_multiple_experiments(temp_dir: Path): text=True, cwd=temp_dir, ) - print(result.stdout) + + assert result.returncode == 0 + + # Should find 2 commits (iterations 2 and 4 have accuracy > 0.8) + output_lines = result.stdout.strip().split("\n") + assert len(output_lines) > 0 + assert "Found 2 commit(s)" in output_lines[0] + + # Test querying for experiments with low loss + result = subprocess.run( + [str(mthd_path), "query", "metrics.loss < 0.5"], + capture_output=True, + text=True, + cwd=temp_dir, + ) + + assert result.returncode == 0 + + # Should find 3 commits (iterations 2, 3, and 4 have loss < 0.5) + output_lines = result.stdout.strip().split("\n") + assert len(output_lines) > 0 + assert "Found 3 commit(s)" in output_lines[0] + + +def test_multiple_run_experiments(temp_dir: Path): + """Test running multiple experiments using run-based API.""" + # Get the path to the Python executable in the virtual environment + if os.name == "nt": # Windows + python_path = temp_dir / ".venv" / "Scripts" / "python.exe" + mthd_path = temp_dir / ".venv" / "Scripts" / "mthd.exe" + else: # Unix-like + python_path = temp_dir / ".venv" / "bin" / "python" + mthd_path = temp_dir / ".venv" / "bin" / "mthd" + + # Run multiple iterations of the experiment + for i in range(4): + create_run_experiment_file(temp_dir, i) + + # Run the experiment using the virtualenv python + subprocess.run( + [str(python_path), str(temp_dir / "experiment.py")], + check=True, + ) + + # Verify the commits + repo = git.Repo(temp_dir) + commits = list(repo.iter_commits()) + + # Should have 4 commits + assert len(commits) == 4 + + # All commits should be experiment commits + for commit in commits: + message = commit.message if isinstance(commit.message, str) else commit.message.decode("utf-8") + + assert message.startswith("exp: ") + assert "metrics" in message + assert "hyperparameters" in message + + # Test querying for experiments with high accuracy + result = subprocess.run( + [str(mthd_path), "query", "metrics.accuracy > 0.8"], + capture_output=True, + text=True, + cwd=temp_dir, + ) assert result.returncode == 0