From 48d99a9c914da2d360353a17a2c69d6a7238420a Mon Sep 17 00:00:00 2001 From: Leonardo Holtz Date: Sun, 26 Jan 2025 17:44:30 -0300 Subject: [PATCH] feat: Start of registry logic --- .gitignore | 3 +- config.py | 15 ----- configs/mnist_simple_config.py | 37 +++++++++++++ configs/mnist_simple_config.yaml | 25 +++++++++ pyproject.toml | 10 ++++ requirements.txt | 5 ++ src/lightning_codebase/__init__.py | 0 src/lightning_codebase/datasets/__init__.py | 1 + .../lightning_codebase/datasets/mnist.py | 0 src/lightning_codebase/models/__init__.py | 1 + .../models/mnist_simple_model.py | 0 src/lightning_codebase/trainers/__init__.py | 1 + src/lightning_codebase/trainers/creator.py | 26 +++++++++ src/lightning_codebase/utils/__init__.py | 2 + src/lightning_codebase/utils/config.py | 7 +++ src/lightning_codebase/utils/registry.py | 55 +++++++++++++++++++ train_and_eval.py => tools/train_and_eval.py | 24 ++++---- tools/train_eval_with_registry.py | 35 ++++++++++++ 18 files changed, 220 insertions(+), 27 deletions(-) delete mode 100644 config.py create mode 100644 configs/mnist_simple_config.py create mode 100644 configs/mnist_simple_config.yaml create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 src/lightning_codebase/__init__.py create mode 100644 src/lightning_codebase/datasets/__init__.py rename dataset.py => src/lightning_codebase/datasets/mnist.py (100%) create mode 100644 src/lightning_codebase/models/__init__.py rename model.py => src/lightning_codebase/models/mnist_simple_model.py (100%) create mode 100644 src/lightning_codebase/trainers/__init__.py create mode 100644 src/lightning_codebase/trainers/creator.py create mode 100644 src/lightning_codebase/utils/__init__.py create mode 100644 src/lightning_codebase/utils/config.py create mode 100644 src/lightning_codebase/utils/registry.py rename train_and_eval.py => tools/train_and_eval.py (57%) create mode 100644 tools/train_eval_with_registry.py diff --git a/.gitignore b/.gitignore 
index eb6ae3b..b0246b3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ lightning_logs/ dataset/ -__pycache__/ \ No newline at end of file +__pycache__/ +*.egg-info/ \ No newline at end of file diff --git a/config.py b/config.py deleted file mode 100644 index 47a28a9..0000000 --- a/config.py +++ /dev/null @@ -1,15 +0,0 @@ -# Training Hyperparameters -INPUT_SIZE = 784 -NUM_CLASSES = 10 -LEARNING_RATE = 0.001 -NUM_EPOCHS = 3 - -# Dataset -DATA_DIR = "dataset/" -NUM_WORKERS = 4 -BATCH_SIZE = 64 - -# Compute related -ACCELERATOR = "gpu" -DEVICES = [0] -PRECISION = 32 diff --git a/configs/mnist_simple_config.py b/configs/mnist_simple_config.py new file mode 100644 index 0000000..11e9230 --- /dev/null +++ b/configs/mnist_simple_config.py @@ -0,0 +1,37 @@ +# Training Hyperparameters +learning_rate = 0.001 +num_epochs = 3 + +# Dataset +data_dir = "dataset/" +num_workers = 4 +batch_size = 64 +input_size = 784 +num_classes = 10 + +# Compute related +accelerator = "gpu" +devices = [0] +precision = 32 + +# Trainer callbacks +callbacks = [ + dict( + type="RichProgressBar", + leave=True + ), + dict( + type="RichModelSummary", + ) +] + +# default Lightning Trainer +trainer = dict( + type="Trainer", + accelerator=accelerator, + devices=devices, + min_epochs=1, + max_epochs=num_epochs, + precision=precision, + callbacks=callbacks +) \ No newline at end of file diff --git a/configs/mnist_simple_config.yaml b/configs/mnist_simple_config.yaml new file mode 100644 index 0000000..037ab6e --- /dev/null +++ b/configs/mnist_simple_config.yaml @@ -0,0 +1,25 @@ +# Training Hyperparameters +learning_rate: 0.001 + +# Dataset +data_dir: "dataset/" +batch_size: 64 +num_workers: 4 +input_size: 784 +num_classes: 10 + +# Trainer callbacks +callbacks: &callbacks + - type: "RichProgressBar" + leave: True + - type: "RichModelSummary" + +# default Lightning Trainer +trainer: + type: "Trainer" # L.Trainer + accelerator: "gpu" + devices: [0] + min_epochs: 1 + max_epochs: 3 + precision: 32 + 
callbacks: *callbacks diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b04e380 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +[tool.poetry] +name = "lightning-codebase" +version = "0.1" +description = "A PyTorch Lightning codebase for fast training setups" + +[tool.poetry.dependencies] + lightning = "*" + torchvision = "*" + munch = "*" + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8aba6ea --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +lightning +torchvision +munch + +-e . \ No newline at end of file diff --git a/src/lightning_codebase/__init__.py b/src/lightning_codebase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/lightning_codebase/datasets/__init__.py b/src/lightning_codebase/datasets/__init__.py new file mode 100644 index 0000000..fe34b11 --- /dev/null +++ b/src/lightning_codebase/datasets/__init__.py @@ -0,0 +1 @@ +from .mnist import * \ No newline at end of file diff --git a/dataset.py b/src/lightning_codebase/datasets/mnist.py similarity index 100% rename from dataset.py rename to src/lightning_codebase/datasets/mnist.py diff --git a/src/lightning_codebase/models/__init__.py b/src/lightning_codebase/models/__init__.py new file mode 100644 index 0000000..1ea88e4 --- /dev/null +++ b/src/lightning_codebase/models/__init__.py @@ -0,0 +1 @@ +from .mnist_simple_model import * \ No newline at end of file diff --git a/model.py b/src/lightning_codebase/models/mnist_simple_model.py similarity index 100% rename from model.py rename to src/lightning_codebase/models/mnist_simple_model.py diff --git a/src/lightning_codebase/trainers/__init__.py b/src/lightning_codebase/trainers/__init__.py new file mode 100644 index 0000000..8bb28b6 --- /dev/null +++ b/src/lightning_codebase/trainers/__init__.py @@ -0,0 +1 @@ +from .creator import * \ No newline at end of file diff --git a/src/lightning_codebase/trainers/creator.py
b/src/lightning_codebase/trainers/creator.py new file mode 100644 index 0000000..a77674a --- /dev/null +++ b/src/lightning_codebase/trainers/creator.py @@ -0,0 +1,26 @@ +from ..utils.registry import Registry +import lightning as L +from lightning.pytorch.callbacks import RichProgressBar, RichModelSummary + +TRAINERS = Registry("trainers") +CALLBACKS = Registry("callbacks") + +TRAINERS.register(module=L.Trainer) + +CALLBACKS.register(module=RichProgressBar) +CALLBACKS.register(module=RichModelSummary) + +def create_callbacks(callbacks_cfg): + def create_callback(cfg): + return CALLBACKS.create(cfg) + + callbacks = [] + for callback in callbacks_cfg: + callbacks.append(create_callback(callback)) + return callbacks + +def create_trainer(cfg): + cfg_callbacks_dict = cfg.get("callbacks", None) + if cfg_callbacks_dict: + cfg["callbacks"] = create_callbacks(cfg_callbacks_dict) + return TRAINERS.create(cfg) diff --git a/src/lightning_codebase/utils/__init__.py b/src/lightning_codebase/utils/__init__.py new file mode 100644 index 0000000..5098778 --- /dev/null +++ b/src/lightning_codebase/utils/__init__.py @@ -0,0 +1,2 @@ +from .registry import * +from .config import * \ No newline at end of file diff --git a/src/lightning_codebase/utils/config.py b/src/lightning_codebase/utils/config.py new file mode 100644 index 0000000..e7fc778 --- /dev/null +++ b/src/lightning_codebase/utils/config.py @@ -0,0 +1,7 @@ +from munch import Munch +import yaml + +def load_config(path): + with open(path, 'r') as file: + config_dict = yaml.safe_load(file) + return Munch(config_dict) \ No newline at end of file diff --git a/src/lightning_codebase/utils/registry.py b/src/lightning_codebase/utils/registry.py new file mode 100644 index 0000000..bb2d757 --- /dev/null +++ b/src/lightning_codebase/utils/registry.py @@ -0,0 +1,55 @@ + +class Registry(): + """ + Generates a mapping of strings to classes + Useful to write configs that can use multiple classes as fields, + like models, trainers, 
schedulers etc. + """ + def __init__(self, name): + self.name = name + self._data_dict = dict() + + def register(self, module=None): + + # For existent module classes + if module is not None: + self._data_dict[module.__name__] = module + + # For customized classes (used as decorator) + def _register(custom_class): + self._data_dict[custom_class.__name__] = custom_class + + return _register + + def get(self, key): + creator = self._data_dict.get(key, None) + if creator is None: + raise KeyError(f"Registry does not contain the key {key}") + return creator + + def create(self, cfg): + """ + Creates an object of the class defined by the 'type' + field in a configuration dict if the class is registred + + Args: + cfg (dict): configuration containing the data to build a class object + registry (Registry): Registry where the class is mapped + """ + + if "type" not in cfg: + raise AttributeError("'cfg' object does not have an 'type' attribute.") + + # Finds the class in registry + args = cfg.deepcopy() + class_name = args.pop("type") + creator = self.get(class_name) + + # Creates object from class + try: + # By removing the type from the args, the rest of them are used to create the object + return creator(**args) + except: + raise AttributeError( + f"'cfg' object contains the wrong arguments to create an object of {class_name}" + ) \ No newline at end of file diff --git a/train_and_eval.py b/tools/train_and_eval.py similarity index 57% rename from train_and_eval.py rename to tools/train_and_eval.py index 043349f..4f87364 100644 --- a/train_and_eval.py +++ b/tools/train_and_eval.py @@ -4,9 +4,11 @@ from lightning.pytorch.callbacks import RichProgressBar, RichModelSummary from pytorch_lightning.loggers import TensorBoardLogger -from model import MnistSimpleModel -from dataset import MnistDataModule -import config +from lightning_codebase.models import MnistSimpleModel +from lightning_codebase.datasets import MnistDataModule +import configs.mnist_simple_config as 
mnist_simple_config + + def main(): torch.set_float32_matmul_precision("medium") @@ -14,23 +16,23 @@ def main(): logger = TensorBoardLogger("tb_logs, ") # Datamodule datamodule = MnistDataModule( - data_dir=config.DATA_DIR, - batch_size=config.BATCH_SIZE, - num_workers=config.NUM_WORKERS, + data_dir=mnist_simple_config.data_dir, + batch_size=mnist_simple_config.batch_size, + num_workers=mnist_simple_config.num_workers, ) # Initialize network model = MnistSimpleModel( - input_size=config.INPUT_SIZE, num_classes=config.NUM_CLASSES + input_size=mnist_simple_config.input_size, num_classes=mnist_simple_config.num_classes, learning_rate=mnist_simple_config.learning_rate ).to(device) # Trainer trainer = L.Trainer( - accelerator=config.ACCELERATOR, - devices=config.DEVICES, + accelerator=mnist_simple_config.accelerator, + devices=mnist_simple_config.devices, min_epochs=1, - max_epochs=config.NUM_EPOCHS, - precision=config.PRECISION, + max_epochs=mnist_simple_config.num_epochs, + precision=mnist_simple_config.precision, callbacks=[RichProgressBar(leave=True), RichModelSummary()], ) diff --git a/tools/train_eval_with_registry.py b/tools/train_eval_with_registry.py new file mode 100644 index 0000000..2e09600 --- /dev/null +++ b/tools/train_eval_with_registry.py @@ -0,0 +1,35 @@ +import argparse +from lightning_codebase.utils import load_config + +import torch +from lightning_codebase.models import create_model +from lightning_codebase.datasets import create_datamodule +from lightning_codebase.trainers import create_trainer + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('config', type=str, help='path to the YAML training configuration file') + args = parser.parse_args() + return args + +def main(): + args = get_args() + cfg = load_config(args.config) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.set_float32_matmul_precision("medium") + + datamodule = create_datamodule(cfg.datamodule) + model = 
create_model(cfg.model).to(device) + trainer = create_trainer(cfg.trainer) + + # Training and evaluation + trainer.fit(model, datamodule) + trainer.validate(model, datamodule) + trainer.test(model, datamodule) + + # TODO: Qualitative prediction results + + +if __name__ == "__main__": + main()