Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,22 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
python-version: ["3.12"]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install mypy pytest pytest-cov
pip install -r requirements.txt
- name: Install uv and set the python version
uses: astral-sh/setup-uv@v6
- name: mypy
run: |
mypy --install-types --non-interactive .
uv run mypy --install-types --non-interactive .
- name: pytest
if: always()
run: |
pytest --cov=ai_toolkit --cov-report= --durations=0 -k "not test_one_epoch"
uv run pytest --cov=ai_toolkit --cov-report= --durations=0 -k "not test_one_epoch"
- name: codecov
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v5
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@ ci:
skip: [mypy, pytest]
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.13.2
hooks:
- id: ruff
- id: ruff-check
args: [--fix]
- id: ruff-format

- repo: local
hooks:
- id: mypy
name: mypy
entry: mypy
entry: uv run mypy
language: python
types: [python]
require_serial: true

- id: pytest
name: pytest
entry: pytest --cov=ai_toolkit --cov-report=html --durations=0
entry: uv run pytest --cov=ai_toolkit --cov-report=html --durations=0
language: python
types: [python]
always_run: true
Expand Down
6 changes: 5 additions & 1 deletion ai_toolkit/datasets/dataset_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ class DatasetRNN(DatasetLoader):
def pad_collate(batch):
(xx, yy) = zip(*batch, strict=False)
x_lens = torch.tensor([len(x) for x in xx])
xx_pad = pad_sequence(xx, batch_first=True, padding_value=0) # type: ignore[arg-type]
xx_pad = pad_sequence(
xx,
batch_first=True,
padding_value=0,
)
yy_pad = torch.stack(yy)
return xx_pad, yy_pad, x_lens

Expand Down
3 changes: 3 additions & 0 deletions ai_toolkit/metric_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def __eq__(self, other: object) -> bool:
return True
return NotImplemented

def __hash__(self) -> int:
return 0

def __repr__(self) -> str:
return str(self.json_repr())

Expand Down
2 changes: 1 addition & 1 deletion ai_toolkit/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class Accuracy(Metric):
def __repr__(self) -> str:
return f"{self.name}: {100. * self.value:.2f}%"
return f"{self.name}: {100.0 * self.value:.2f}%"

@staticmethod
def calculate_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
Expand Down
13 changes: 7 additions & 6 deletions ai_toolkit/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

import sys
from typing import cast

from torch import nn
from typing import TYPE_CHECKING, cast

from .cnn import BasicCNN
from .dense import DenseNet
Expand All @@ -14,19 +12,22 @@
# from .unet import UNet
# from .efficient_net import EfficientNet

if TYPE_CHECKING:
from torch import nn


def get_model_initializer(model_name: str) -> type[nn.Module]:
"""Retrieves class initializer from its string name."""
if not hasattr(sys.modules[__name__], model_name):
raise RuntimeError(f"Model class {model_name} not found in models/")
return cast(type[nn.Module], getattr(sys.modules[__name__], model_name))
return cast("type[nn.Module]", getattr(sys.modules[__name__], model_name))


__all__ = (
"BasicCNN",
"DenseNet",
"BasicLSTM",
"MaskRCNN",
"BasicRNN",
"DenseNet",
"MaskRCNN",
"get_model_initializer",
)
2 changes: 1 addition & 1 deletion ai_toolkit/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_model(
test_loss /= test_len
print(
f"\nTest set: Average loss: {test_loss:.4f},",
f"Accuracy: {correct}/{test_len} ({100. * correct / test_len:.2f}%)\n",
f"Accuracy: {correct}/{test_len} ({100.0 * correct / test_len:.2f}%)\n",
)


Expand Down
18 changes: 9 additions & 9 deletions ai_toolkit/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import numpy as np
import torch
from torch import nn, optim
from torch.optim import lr_scheduler
from torch import nn
from torch.optim import AdamW, lr_scheduler

from ai_toolkit import util
from ai_toolkit.args import Arguments, init_pipeline
Expand All @@ -22,6 +22,8 @@
if TYPE_CHECKING:
from collections.abc import Iterator

from torch.optim.optimizer import Optimizer

tqdm: Any
if "google.colab" in sys.modules:
from tqdm import tqdm_notebook as tqdm
Expand All @@ -33,7 +35,7 @@ def train_and_validate(
args: Arguments,
model: nn.Module,
loader: TensorDataLoader,
optimizer: optim.Optimizer | None,
optimizer: Optimizer | None,
criterion: nn.Module,
metrics: MetricTracker,
mode: Mode,
Expand Down Expand Up @@ -80,14 +82,12 @@ def train_and_validate(
metrics.epoch_update(mode)


def get_optimizer(args: Arguments, model: nn.Module) -> optim.Optimizer:
def get_optimizer(args: Arguments, model: nn.Module) -> Optimizer:
params = filter(lambda p: p.requires_grad, model.parameters())
return optim.AdamW(params, lr=args.lr)
return AdamW(params, lr=args.lr)


def get_scheduler(
args: Arguments, optimizer: optim.Optimizer
) -> lr_scheduler.LRScheduler:
def get_scheduler(args: Arguments, optimizer: Optimizer) -> lr_scheduler.LRScheduler:
return lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.gamma)


Expand All @@ -96,7 +96,7 @@ def load_model(
device: torch.device,
init_params: tuple[Any, ...],
loader: Iterator[Any],
) -> tuple[nn.Module, nn.Module, optim.Optimizer, lr_scheduler.LRScheduler | None]:
) -> tuple[nn.Module, nn.Module, Optimizer, lr_scheduler.LRScheduler | None]:
criterion = get_loss_initializer(args.loss)()
model = get_model_initializer(args.model)(*init_params).to(device)
optimizer = get_optimizer(args, model)
Expand Down
5 changes: 3 additions & 2 deletions ai_toolkit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@

import numpy as np
import torch
from torch import nn, optim
from torch import nn
from torch.utils.data import DataLoader

if TYPE_CHECKING:
from collections.abc import Iterator

from torch.optim import lr_scheduler
from torch.optim.optimizer import Optimizer

# Redefining here to avoid circular import
TensorDataLoader = DataLoader[tuple[torch.Tensor, ...]]
Expand Down Expand Up @@ -69,7 +70,7 @@ def load_checkpoint(checkpoint_path: Path, use_best: bool = False) -> dict[str,
def load_state_dict(
checkpoint: dict[str, Any],
model: nn.Module,
optimizer: optim.Optimizer | None = None,
optimizer: Optimizer | None = None,
scheduler: lr_scheduler.LRScheduler | None = None,
) -> None:
"""
Expand Down
11 changes: 6 additions & 5 deletions ai_toolkit/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import torch
import torchinfo
from torch import nn, optim
from torch import nn
from torch.optim.optimizer import Optimizer

from ai_toolkit.args import Arguments

Expand All @@ -20,7 +21,7 @@ def verify_model(
args: Arguments,
model: nn.Module,
loader: Iterator[Any],
optimizer: optim.Optimizer,
optimizer: Optimizer,
criterion: nn.Module,
device: torch.device,
) -> None:
Expand All @@ -47,7 +48,7 @@ def model_summary(args: Arguments, model: nn.Module, loader: Iterator[Any]) -> N
def check_batch_dimension(
model: nn.Module,
loader: Iterator[Any],
optimizer: optim.Optimizer,
optimizer: Optimizer,
test_val: int = 2,
) -> None:
"""
Expand Down Expand Up @@ -82,7 +83,7 @@ def check_batch_dimension(
def overfit_example(
model: nn.Module,
loader: Iterator[Any],
optimizer: optim.Optimizer,
optimizer: Optimizer,
criterion: nn.Module,
device: torch.device,
batch_dim: int = 0,
Expand Down Expand Up @@ -147,7 +148,7 @@ def detect_nan_tensors(model: nn.Module) -> None:
def check_all_layers_training(
model: nn.Module,
loader: Iterator[Any],
optimizer: optim.Optimizer,
optimizer: Optimizer,
criterion: nn.Module,
) -> None:
"""
Expand Down
98 changes: 91 additions & 7 deletions ruff.toml → pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,95 @@
target-version = "py312"
[project]
name = "ai_toolkit"
version = "0.0.2"
description = "AI Toolkit"
authors = [
{ name = "Tyler Yep", email = "tyep@cs.stanford.edu" },
]
readme = "README.md"
license = "MIT"
classifiers = [
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.12",
]
requires-python = ">=3.12"
dependencies = [
"efficientnet_pytorch",
"matplotlib",
"numpy",
"pandas",
"pillow",
"protobuf",
"tensorboard",
"torch==2.2",
"torchinfo",
"torchvision",
"tqdm",
"wget",
]

[dependency-groups]
dev = [
"mypy",
"pre-commit",
"pytest",
"pytest-cov",
"ruff",
"types-setuptools",
"types-tqdm",
]

[project.urls]
Homepage = "https://github.com/tyleryep/ai-toolkit"

[build-system]
requires = [
"setuptools>=61.2",
]
build-backend = "setuptools.build_meta"

[tool.setuptools]
include-package-data = true

[tool.setuptools.packages.find]
namespaces = false

[tool.setuptools.package-data]
ai_toolkit = [
"py.typed",
]

[tool.mypy]
strict = true
warn_unreachable = true
disallow_any_unimported = true
extra_checks = true
enable_error_code = "ignore-without-code"
warn_return_any = false

[[tool.mypy.overrides]]
module = [
"ai_toolkit.datasets.*",
]
allow_untyped_defs = true
allow_untyped_calls = true

[[tool.mypy.overrides]]
module = [
"ai_toolkit.models.*",
]
allow_untyped_defs = true

[tool.pytest.ini_options]
filterwarnings = [
"ignore:Call to deprecated create function",
]

[tool.ruff]
target-version = "py313"
lint.select = ["ALL"]
lint.ignore = [
"ANN101", # Missing type annotation for `self` in method
"ANN102", # Missing type annotation for `cls` in classmethod
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"C901", # function is too complex (12 > 10)
"COM812", # Trailing comma missing
Expand All @@ -14,7 +101,6 @@ lint.ignore = [
"FBT002", # Boolean default value in function definition
"FBT003", # Boolean positional value in function call
"FIX002", # Line contains TODO
"ISC001", # Isort
"PLR0911", # Too many return statements (11 > 6)
"PLR2004", # Magic value used in comparison, consider replacing 2 with a constant variable
"PLR0912", # Too many branches
Expand All @@ -24,6 +110,7 @@ lint.ignore = [
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"T201", # print() found
"T203", # pprint() found
"TC006", # Add quotes to type expression in `typing.cast()`
"TD002", # Missing author in TODO; try: `# TODO(<author_name>): ...`
"TD003", # Missing issue link on the line following this TODO
"TD005", # Missing issue description after `TODO`
Expand All @@ -40,6 +127,3 @@ lint.ignore = [
"N812", # Lowercase `functional` imported as non-lowercase `F`
"NPY002", # Replace legacy `np.random.randn` call with `np.random.Generator`
]

[lint.flake8-pytest-style]
fixture-parentheses = false
Loading