Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
lightning_logs/
dataset/
__pycache__/
__pycache__/
*.egg-info/
15 changes: 0 additions & 15 deletions config.py

This file was deleted.

37 changes: 37 additions & 0 deletions configs/mnist_simple_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Training Hyperparameters
learning_rate = 0.001
num_epochs = 3

# Dataset
data_dir = "dataset/"  # NOTE(review): relative path, resolved against the CWD at runtime
num_workers = 4
batch_size = 64
input_size = 784  # presumably 28x28 MNIST images, flattened — confirm against the model
num_classes = 10

# Compute related
accelerator = "gpu"
devices = [0]  # GPU indices to use
precision = 32

# Trainer callbacks
# Each dict is a registry-style spec: "type" names a registered class and
# the remaining keys are forwarded to its constructor (see utils/registry.py).
callbacks = [
dict(
type="RichProgressBar",
leave=True
),
dict(
type="RichModelSummary",
)
]

# default Lightning Trainer
# Registry-style spec: "type" is popped and the rest are passed as
# keyword arguments to the registered Trainer class.
trainer = dict(
type="Trainer",
accelerator=accelerator,
devices=devices,
min_epochs=1,
max_epochs=num_epochs,
precision=precision,
callbacks=callbacks
)
25 changes: 25 additions & 0 deletions configs/mnist_simple_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Training Hyperparameters
learning_rate: 0.001

# Dataset
data_dir: "dataset/"
batch_size: 64
num_workers: 4
input_size: 784
num_classes: 10

# Trainer callbacks, anchored (&callbacks) so the trainer section can
# reuse them via the *callbacks alias.
# Each item is a registry-style spec: "type" names a registered class and
# the remaining keys become constructor keyword arguments.
callbacks: &callbacks
  - type: "RichProgressBar"
    leave: True
  - type: "RichModelSummary"

# default Lightning Trainer
trainer:
  type: "Trainer" # L.Trainer
  accelerator: "gpu"
  devices: [0]
  min_epochs: 1
  max_epochs: 3
  precision: 32
  callbacks: *callbacks
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[tool.poetry]
name = "lightning-codebase"
version = "0.1"
description = "A PyTorch Lightning codebase for fast training setups"
# NOTE(review): poetry requires an authors list — fill in the real author.
authors = ["TODO <todo@example.com>"]
# src layout: tell poetry where the importable package lives.
packages = [{ include = "lightning_codebase", from = "src" }]

[tool.poetry.dependencies]
# Poetry requires an explicit python constraint in the dependency table.
python = ">=3.8"
lightning = "*"
torchvision = "*"
munch = "*"

# Without a [build-system] table, `pip install -e .` cannot build the project.
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
lightning
torchvision
munch

# install this project itself in editable mode (`pip install -e .` is a shell
# command, not a valid requirements line)
-e .
Empty file.
1 change: 1 addition & 0 deletions src/lightning_codebase/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mnist import *
File renamed without changes.
1 change: 1 addition & 0 deletions src/lightning_codebase/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mnist_simple_model import *
File renamed without changes.
1 change: 1 addition & 0 deletions src/lightning_codebase/trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .creator import *
26 changes: 26 additions & 0 deletions src/lightning_codebase/trainers/creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from ..utils.registry import Registry
import lightning as L
from lightning.pytorch.callbacks import RichProgressBar, RichModelSummary

# Registries mapping class names ("type" strings in configs) to classes.
TRAINERS = Registry("trainers")
CALLBACKS = Registry("callbacks")

# Lightning's stock Trainer is available under the name "Trainer".
TRAINERS.register(module=L.Trainer)

# Progress-bar / model-summary callbacks available by name in config files.
CALLBACKS.register(module=RichProgressBar)
CALLBACKS.register(module=RichModelSummary)

def create_callbacks(callbacks_cfg):
    """
    Instantiate trainer callbacks from a list of registry-style specs.

    Args:
        callbacks_cfg (list[dict]): each dict carries a "type" key naming a
            class registered in ``CALLBACKS``; remaining keys are passed to
            that class's constructor.

    Returns:
        list: the instantiated callback objects, in input order.
    """
    # The original wrapped CALLBACKS.create in a redundant inner function;
    # a comprehension expresses the same mapping directly.
    return [CALLBACKS.create(callback_cfg) for callback_cfg in callbacks_cfg]

def create_trainer(cfg):
    """
    Build a trainer from a registry-style config dict.

    Args:
        cfg (dict): must contain a "type" key naming a class registered in
            ``TRAINERS``; an optional "callbacks" key holds a list of
            callback specs that are instantiated first.

    Returns:
        The instantiated trainer object.
    """
    # Work on a shallow copy so the caller's config is not mutated when the
    # callback spec dicts are replaced by callback instances (the original
    # wrote the instances back into the caller's dict).
    cfg = dict(cfg)
    callback_cfgs = cfg.get("callbacks", None)
    if callback_cfgs:
        cfg["callbacks"] = create_callbacks(callback_cfgs)
    return TRAINERS.create(cfg)
2 changes: 2 additions & 0 deletions src/lightning_codebase/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .registry import *
from .config import *
7 changes: 7 additions & 0 deletions src/lightning_codebase/utils/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from munch import Munch
import yaml

def load_config(path):
    """
    Load a YAML config file into an attribute-accessible mapping.

    Args:
        path (str): path to the YAML configuration file.

    Returns:
        Munch: the parsed config. ``Munch.fromDict`` converts nested dicts
        recursively, so nested access like ``cfg.trainer.type`` works;
        plain ``Munch(config_dict)`` only converted the top level.
    """
    with open(path, 'r') as file:
        config_dict = yaml.safe_load(file)
    return Munch.fromDict(config_dict)
55 changes: 55 additions & 0 deletions src/lightning_codebase/utils/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@

class Registry():
    """
    Generates a mapping of strings (class names) to classes.

    Useful to write configs that can use multiple classes as fields,
    like models, trainers, schedulers etc.
    """

    def __init__(self, name):
        # Human-readable registry name, used only for identification.
        self.name = name
        self._data_dict = dict()

    def register(self, module=None):
        """
        Register a class under its ``__name__``.

        Either call directly for existing classes
        (``registry.register(module=SomeClass)``) or use as a decorator
        (``@registry.register()``) for custom classes.
        """
        # For existent module classes
        if module is not None:
            self._data_dict[module.__name__] = module

        # For customized classes (used as decorator)
        def _register(custom_class):
            self._data_dict[custom_class.__name__] = custom_class
            # Return the class: a decorator that returns None would replace
            # the decorated class with None at its definition site.
            return custom_class

        return _register

    def get(self, key):
        """Return the class registered under ``key``; raise KeyError if absent."""
        creator = self._data_dict.get(key, None)
        if creator is None:
            raise KeyError(f"Registry does not contain the key {key}")
        return creator

    def create(self, cfg):
        """
        Creates an object of the class defined by the 'type'
        field in a configuration dict if the class is registered

        Args:
            cfg (dict): configuration containing the data to build a class object

        Returns:
            An instance of the registered class, built from the remaining
            keys of ``cfg`` as keyword arguments.
        """
        if "type" not in cfg:
            raise AttributeError("'cfg' object does not have an 'type' attribute.")

        # Copy so popping "type" does not mutate the caller's config.
        # (The original called cfg.deepcopy(), which plain dicts do not
        # have and therefore raised AttributeError at runtime.)
        args = dict(cfg)
        class_name = args.pop("type")
        creator = self.get(class_name)

        # Creates object from class
        try:
            # By removing the type from the args, the rest of them are used to create the object
            return creator(**args)
        except TypeError as err:
            # Only wrap constructor-signature errors; the original bare
            # `except:` swallowed every exception type and hid the cause.
            raise AttributeError(
                f"'cfg' object contains the wrong arguments to create an object of {class_name}"
            ) from err
24 changes: 13 additions & 11 deletions train_and_eval.py → tools/train_and_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,35 @@
from lightning.pytorch.callbacks import RichProgressBar, RichModelSummary
from pytorch_lightning.loggers import TensorBoardLogger

from model import MnistSimpleModel
from dataset import MnistDataModule
import config
from lightning_codebase.models import MnistSimpleModel
from lightning_codebase.datasets import MnistDataModule
import configs.mnist_simple_config as mnist_simple_config



def main():
torch.set_float32_matmul_precision("medium")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger = TensorBoardLogger("tb_logs, ")
# Datamodule
datamodule = MnistDataModule(
data_dir=config.DATA_DIR,
batch_size=config.BATCH_SIZE,
num_workers=config.NUM_WORKERS,
data_dir=mnist_simple_config.data_dir,
batch_size=mnist_simple_config.batch_size,
num_workers=mnist_simple_config.num_workers,
)

# Initialize network
model = MnistSimpleModel(
input_size=config.INPUT_SIZE, num_classes=config.NUM_CLASSES
input_size=mnist_simple_config.input_size, num_classes=mnist_simple_config.num_classes, learning_rate=mnist_simple_config.learning_rate
).to(device)

# Trainer
trainer = L.Trainer(
accelerator=config.ACCELERATOR,
devices=config.DEVICES,
accelerator=mnist_simple_config.accelerator,
devices=mnist_simple_config.devices,
min_epochs=1,
max_epochs=config.NUM_EPOCHS,
precision=config.PRECISION,
max_epochs=mnist_simple_config.num_epochs,
precision=mnist_simple_config.precision,
callbacks=[RichProgressBar(leave=True), RichModelSummary()],
)

Expand Down
35 changes: 35 additions & 0 deletions tools/train_eval_with_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import argparse
from lightning_codebase.utils import load_config

import torch
from lightning_codebase.models import create_model
from lightning_codebase.datasets import create_datamodule
from lightning_codebase.trainers import create_trainer

def get_args():
    """Parse the command line: one positional argument, the config path."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('config', type=str, help='path to the YAML training configuration file')
    return arg_parser.parse_args()

def main():
    """Build datamodule, model and trainer from a YAML config, then run fit/validate/test."""
    cfg = load_config(get_args().config)

    torch.set_float32_matmul_precision("medium")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Every component is created by name from its registry-style config section.
    datamodule = create_datamodule(cfg.datamodule)
    model = create_model(cfg.model).to(device)
    trainer = create_trainer(cfg.trainer)

    # Training and evaluation
    trainer.fit(model, datamodule)
    trainer.validate(model, datamodule)
    trainer.test(model, datamodule)

    # TODO: Qualitative prediction results


if __name__ == "__main__":
main()