Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
faeee50
feat: Add Run object for manual hyperparameters and metrics tracking
rorybyrne Feb 10, 2025
e916d1a
refactor: Enhance commit decorator with optional context and improved…
rorybyrne Feb 10, 2025
0465b41
refactor: Modify commit decorator to support context-based and direct…
rorybyrne Feb 10, 2025
411d262
refactor: Rename Run to Context and update method signatures
rorybyrne Feb 10, 2025
8f2f0cb
refactor: Update context injection in decorator to improve UX
rorybyrne Feb 10, 2025
7128a19
refactor: Update context usage to run in commit decorator
rorybyrne Feb 10, 2025
5e0c131
refactor: Improve type hints for commit decorator with ParamSpec
rorybyrne Feb 10, 2025
539fb6b
refactor: Fix ParamSpec typing and decorator signature for commit fun…
rorybyrne Feb 10, 2025
7bf073d
refactor: Fix context injection and type hints in decorator
rorybyrne Feb 10, 2025
4c45440
feat: Add overloaded commit decorator with improved type hints and UX
rorybyrne Feb 10, 2025
ecfc83b
refactor: Improve type hints and formatting in commit decorator
rorybyrne Feb 10, 2025
b8e80d1
fix: Correct type hints for commit decorator to resolve pyright error
rorybyrne Feb 10, 2025
4252e46
feat: Add e2e test for context-based experiment tracking with metrics…
rorybyrne Feb 14, 2025
94f8d43
refactor: Update test2 function signature to include foo parameter
rorybyrne Feb 14, 2025
4a08032
feat: Improve type hints for commit decorator with Concatenate
rorybyrne Feb 14, 2025
5fa9725
fix: Simplify decorator typing and remove conflicting function signature
rorybyrne Feb 14, 2025
c0eae37
refactor: Improve type hints for commit decorator using ParamSpec and…
rorybyrne Feb 14, 2025
77ce27b
refactor: rename Context class to Run and improve decorator API seman…
rorybyrne Feb 14, 2025
26d6002
docs: update README with new Run API and improve code example clarity
rorybyrne Feb 14, 2025
4bbc56f
docs: simplify example code in README by removing unnecessary parameters
rorybyrne Feb 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Every time you run an experiment, your code will be auto-committed with metadata
* Query your scientific log, e.g. `mthd query metrics.accuracy < 0.8`.

```python
from mthd import commit
from mthd import commit, Run
from pydantic import BaseModel

class Hypers(BaseModel):
Expand All @@ -39,15 +39,18 @@ class Metrics(BaseModel):
accuracy: float


@commit(hypers="hypers")
def my_experiment(hypers: Hypers) -> Metrics:
@commit
def my_experiment(run: Run) -> Metrics:
...
# experiment

run.set_hyperparameters({ ... })

...

metrics = Metrics(...)
run.set_metrics({ ... })

return metrics
if __name__ == "__main__":
my_experiment()
```

Then run your experiment:
Expand Down
119 changes: 100 additions & 19 deletions mthd/decorator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os

from dataclasses import dataclass
from functools import wraps
from typing import Callable, Optional, cast
from typing import Callable, Concatenate, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload

from pydantic import BaseModel
from rich.console import Console
Expand All @@ -14,39 +15,110 @@
from mthd.util.di import DI


@dataclass
class Run:
"""Tracks experiment hyperparameters and metrics."""

_hypers: Optional[dict] = None
_metrics: Optional[dict] = None

def set_hyperparameters(self, hypers: dict) -> None:
"""Set hyperparameters manually."""
self._hypers = hypers

def set_metrics(self, metrics: dict) -> None:
"""Set metrics manually."""
self._metrics = metrics

@property
def hyperparameters(self) -> Optional[dict]:
return self._hypers

@property
def metrics(self) -> Optional[dict]:
return self._metrics


P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


@overload
def commit(
fn: None = None,
*,
hypers: str = "hypers",
template: str = "run {experiment}",
strategy: StageStrategy = StageStrategy.ALL,
implicit: Literal[False] = False,
) -> Callable[[Callable[Concatenate[Run, P], R]], Callable[P, R]]: ...


@overload
def commit(
fn: Optional[Callable] = None,
fn: None = None,
*,
hypers: str = "hypers",
template: str = "run {experiment}",
strategy: StageStrategy = StageStrategy.ALL,
) -> Callable:
implicit: Literal[True],
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...


@overload
def commit(
fn: Callable[P, R],
) -> Callable[Concatenate[Run, P], R]: ...


def commit(
fn: Optional[Callable[..., R]] = None,
*,
hypers: str = "hypers",
template: str = "run {experiment}",
strategy: StageStrategy = StageStrategy.ALL,
implicit: bool = False,
) -> Union[Callable[[Callable[..., R]], Callable[..., R]], Callable[..., R]]:
"""Decorator to auto-commit experimental code with scientific metadata.

Can be used as @commit or @commit(message="Custom message")
"""

def decorator(func: Callable) -> Callable:
def decorator(func: Callable[..., R]) -> Callable[..., R]:
@wraps(func)
def wrapper(*args, **kwargs):
di = DI()
# @todo: handle this better (eg. positional args)
hyperparameters = cast(BaseModel, kwargs.get(hypers, None))
if not hyperparameters:
raise MthdError("Hyperparameters must be provided in the function call.")
git_service = di[GitService]
# codebase_service = di[CodebaseService]

# Generate commit message
if not implicit:
context = Run()
args = [context] + list(args)
metrics = func(*args, **kwargs)

if context.hyperparameters is None:
raise MthdError("When using context, hyperparameters must be set via the Context object")
if context.metrics is None:
raise MthdError("When using context, metrics must be set via the Context object")

# Run experiment
metrics = func(*args, **kwargs)
hyper_dict = context.hyperparameters
metric_dict = context.metrics
else:
metrics = func(*args, **kwargs)

hyperparameters = cast(BaseModel, kwargs.get(hypers, None))
if not hyperparameters:
raise MthdError("When not using context, hyperparameters must be provided as function arguments")
if not isinstance(metrics, BaseModel):
raise MthdError("When not using context, metrics must be returned as a BaseModel")

hyper_dict = hyperparameters.model_dump()
metric_dict = metrics.model_dump()

experiment = ExperimentRun(
experiment=func.__name__,
hyperparameters=hyperparameters.model_dump(),
metrics=metrics.model_dump(),
# annotations=codebase_service.get_all_annotations(),
hyperparameters=hyper_dict,
metrics=metric_dict,
)
message = experiment.as_commit_message(template=template)

Expand Down Expand Up @@ -82,9 +154,18 @@ class Metrics(BaseModel):
b: float
c: str

@commit(hypers="hypers", template="run {experiment} at {timestamp}")
def test(hypers: Hyperparameters):
print("\n<Experiment goes here>\n")
# Example using function arguments/return
@commit(hypers="hypers", implicit=True, template="run {experiment} at {timestamp}")
def test1(hypers: Hyperparameters):
print("\n<Experiment 1 goes here>\n")
return Metrics(a=1, b=2.0, c="3")

test(hypers=Hyperparameters(a=1, b=2.0, c="3"))
# Example using Context object
@commit()
def test2(run: Run, foo: int):
print("\n<Experiment 2 goes here>\n")
run.set_hyperparameters({"a": 1, "b": 2.0, "c": "3"})
run.set_metrics({"a": 1, "b": 2.0, "c": "3"})

test1(hypers=Hyperparameters(a=1, b=2.0, c="3"))
test2(5) # No need to pass context
102 changes: 100 additions & 2 deletions tests/e2e/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Metrics(BaseModel):
accuracy: float
loss: float

@commit(hypers="hypers")
@commit(hypers="hypers", implicit=True)
def train_model(hypers: Hyperparameters) -> Metrics:
# Simulate training
accuracy = 0.75 + ({iteration} * 0.05) # Gradually improve accuracy
Expand All @@ -86,6 +86,39 @@ def train_model(hypers: Hyperparameters) -> Metrics:
f.write(content)


def create_run_experiment_file(temp_dir: Path, iteration: int) -> None:
"""Create or update the experiment file using run-based API."""
content = f"""
from mthd import commit
from mthd.decorator import Run

@commit
def train_model(run: Run):
# Set hyperparameters
run.set_hyperparameters({{
"learning_rate": 0.001 * ({iteration} + 1),
"batch_size": 32 * ({iteration} + 1),
"epochs": 10 * ({iteration} + 1)
}})

# Simulate training
accuracy = 0.75 + ({iteration} * 0.05) # Gradually improve accuracy
loss = 0.5 - ({iteration} * 0.1) # Gradually decrease loss

# Set metrics
run.set_metrics({{
"accuracy": accuracy,
"loss": max(0.1, loss)
}})

if __name__ == "__main__":
train_model()
"""

with open(temp_dir / "experiment.py", "w") as f:
f.write(content)


def test_multiple_experiments(temp_dir: Path):
"""Test running multiple experiments and creating commits."""
# Get the path to the Python executable in the virtual environment
Expand Down Expand Up @@ -128,7 +161,72 @@ def test_multiple_experiments(temp_dir: Path):
text=True,
cwd=temp_dir,
)
print(result.stdout)

assert result.returncode == 0

# Should find 2 commits (iterations 2 and 4 have accuracy > 0.8)
output_lines = result.stdout.strip().split("\n")
assert len(output_lines) > 0
assert "Found 2 commit(s)" in output_lines[0]

# Test querying for experiments with low loss
result = subprocess.run(
[str(mthd_path), "query", "metrics.loss < 0.5"],
capture_output=True,
text=True,
cwd=temp_dir,
)

assert result.returncode == 0

# Should find 3 commits (iterations 2, 3, and 4 have loss < 0.5)
output_lines = result.stdout.strip().split("\n")
assert len(output_lines) > 0
assert "Found 3 commit(s)" in output_lines[0]


def test_multiple_run_experiments(temp_dir: Path):
"""Test running multiple experiments using run-based API."""
# Get the path to the Python executable in the virtual environment
if os.name == "nt": # Windows
python_path = temp_dir / ".venv" / "Scripts" / "python.exe"
mthd_path = temp_dir / ".venv" / "Scripts" / "mthd.exe"
else: # Unix-like
python_path = temp_dir / ".venv" / "bin" / "python"
mthd_path = temp_dir / ".venv" / "bin" / "mthd"

# Run multiple iterations of the experiment
for i in range(4):
create_run_experiment_file(temp_dir, i)

# Run the experiment using the virtualenv python
subprocess.run(
[str(python_path), str(temp_dir / "experiment.py")],
check=True,
)

# Verify the commits
repo = git.Repo(temp_dir)
commits = list(repo.iter_commits())

# Should have 4 commits
assert len(commits) == 4

# All commits should be experiment commits
for commit in commits:
message = commit.message if isinstance(commit.message, str) else commit.message.decode("utf-8")

assert message.startswith("exp: ")
assert "metrics" in message
assert "hyperparameters" in message

# Test querying for experiments with high accuracy
result = subprocess.run(
[str(mthd_path), "query", "metrics.accuracy > 0.8"],
capture_output=True,
text=True,
cwd=temp_dir,
)

assert result.returncode == 0

Expand Down