Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions mthd/decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from functools import wraps
from typing import Callable, Optional, cast

from pydantic import BaseModel
from rich.console import Console
from rich.padding import Padding

from mthd.domain.commit import CommitMessage, StageStrategy
from mthd.error import MthdError


def commit(
fn: Optional[Callable] = None,
hypers: str = "hypers",
strategy: StageStrategy = StageStrategy.ALL,
) -> Callable:
"""Decorator to auto-commit experimental code with scientific metadata.

Can be used as @commit or @commit(message="Custom message")
"""

def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
# di = DI()
hyperparameters = cast(BaseModel, kwargs.get(hypers, None))
if not hyperparameters:
raise MthdError(
"Hyperparameters must be provided in the function call."
)
# git_service = di[GitService]
# codebase_service = di[CodebaseService]

# Generate commit message
commit_msg = CommitMessage(
summary="exp: foo bar baz",
hyperparameters=hyperparameters.model_dump(),
# annotations=codebase_service.get_all_annotations(),
)
# print(hyperparameters.model_dump_json(indent=2))
# print(commit_msg.format())

# Run experiment
result = func(*args, **kwargs)

# Commit changes
console = Console()
console.print("Generating commit with message:\n")
console.print(
Padding(commit_msg.format(), pad=(0, 0, 0, 4))
) # Indent by 4 spaces.
# if git_service.should_commit(strategy):
# git_service.stage_and_commit(commit_msg)

return result

return wrapper

if fn is None:
return decorator
return decorator(fn)


if __name__ == "__main__":

class Hyperparameters(BaseModel):
a: int
b: float
c: str

@commit(hypers="hypers")
def test(hypers: Hyperparameters):
print("<Experiment goes here>\n")

test(hypers=Hyperparameters(a=1, b=2.0, c="3"))
18 changes: 18 additions & 0 deletions mthd/domain/commit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from enum import Enum, auto

from mthd.util.model import Model


class CommitMessage(Model):
summary: str
hyperparameters: dict
# annotations: set[Annotation] # @todo: fix anot

def format(self) -> str:
return (
f"{self.summary}\n\n{self.model_dump_json(indent=2, exclude={'summary'})}"
)


class StageStrategy(Enum):
ALL = auto()
Empty file added mthd/domain/experiment.py
Empty file.
4 changes: 4 additions & 0 deletions mthd/domain/repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from mthd.util.model import Model


class Repository(Model): ...
2 changes: 2 additions & 0 deletions mthd/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class MthdError(Exception):
"""Something went wrong."""
Empty file added mthd/service/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions mthd/service/codebase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from anot import Annotation


class CodebaseService:
def get_all_annotations(self) -> set[Annotation]: ...
1 change: 1 addition & 0 deletions mthd/service/experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
class ExperimentService: ...
29 changes: 29 additions & 0 deletions mthd/service/git.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import git

from mthd.domain.commit import CommitMessage, StageStrategy


class GitService:
def __init__(self, repo: git.Repo):
self._repo = repo

def stage_and_commit(self, message: CommitMessage):
"""Stage all changes and create a commit with the given message.

Args:
message: CommitMessage object containing commit metadata
"""
# Stage all changes
self._repo.git.add(A=True)

# Create commit with formatted message
self._repo.index.commit(message.format())

def should_commit(self, strategy: StageStrategy) -> bool:
"""Determine if the repo state can be staged and committed

Returns:
@todo: decide if the unstaged files are suitable for committing
"""
if strategy == StageStrategy.ALL:
return True
Empty file added mthd/util/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions mthd/util/di.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Type, TypeVar

from dishka import Provider, Scope, make_container, provide
from git import Repo

from mthd.service.codebase import CodebaseService
from mthd.service.experiment import ExperimentService
from mthd.service.git import GitService

T = TypeVar("T")


class GitProvider(Provider):
@provide(scope=Scope.APP)
def provide_repo(self) -> Repo:
try:
return Repo()
except Exception as e:
raise RuntimeError(f"Failed to initialize Git repository: {e}")


class DI:
def __init__(self):
self._container = make_container(self.services, self.git)

@property
def container(self):
return self._container

@property
def core(self) -> Provider: ...

@property
def git(self) -> Provider:
return GitProvider()

@property
def services(self) -> Provider:
provider = Provider(scope=Scope.APP)
provider.provide(GitService)
provider.provide(ExperimentService)
provider.provide(CodebaseService)

return provider

def __getitem__(self, item: Type[T]) -> T:
return self._container.get(item)
5 changes: 5 additions & 0 deletions mthd/util/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pydantic import BaseModel, ConfigDict


class Model(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
24 changes: 20 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@ description = "Git-based experiment tracking with semantic metadata."
authors = [
{ name = "Rory Byrne", email = "rory@rory.bio" }
]
requires-python = ">=3.9"
dependencies = []
requires-python = ">=3.10"
dependencies = [
"anot>=0.0.6",
"dishka>=1.4.2",
"gitpython>=3.1.44",
"pydantic>=2.10.5",
"rich>=13.9.4",
]
classifiers = [
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand Down Expand Up @@ -43,8 +48,19 @@ source = "vcs"
dev-dependencies = [
"coverage>=7.6.9",
"pre-commit>=4.0.1",
"pyright>=1.1.390",
"pytest-mock>=3.14.0",
"pytest>=8.3.4",
"ruff>=0.8.2",
"pyright>=1.1.392.post0",
]

[tool.ruff]
src = ["mthd"]

[tool.ruff.lint]
extend-select = ["I"]

[tool.ruff.lint.isort]
known-first-party = ["mthd"]
relative-imports-order = "closest-to-furthest"
lines-between-types = 1
17 changes: 17 additions & 0 deletions tests/unit/domain/commit_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pydantic import BaseModel

from mthd.domain.commit import CommitMessage


def test_commitmessage_format_success():
class Hypers(BaseModel):
a: int
b: float
c: str

msg = CommitMessage(
summary="test",
hyperparameters=Hypers(a=1, b=2.0, c="3").model_dump(),
)

print(msg.format())
Loading
Loading