Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,37 +1,41 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: end-of-file-fixer
- id: fix-encoding-pragma
args: [ --remove ]
- id: mixed-line-ending
- id: trailing-whitespace
- id: check-yaml

- repo: https://github.com/asottile/pyupgrade
rev: v3.21.0
hooks:
- id: pyupgrade
args: [ --py310-plus ]

- repo: https://github.com/ikamensh/flynt/
rev: '1.0.1'
rev: '1.0.6'
hooks:
- id: flynt

- repo: https://github.com/psf/black
rev: 25.1.0
rev: 25.9.0
hooks:
- id: black
exclude: (.*)/migrations

- repo: https://github.com/pycqa/flake8
rev: 7.2.0
rev: 7.3.0
hooks:
- id: flake8

- repo: https://github.com/pycqa/isort
rev: '6.0.1'
rev: '7.0.0'
hooks:
- id: isort

- repo: https://github.com/PyCQA/bandit
rev: 1.8.3
rev: 1.8.6
hooks:
- id: bandit
args: [ "-c", "pyproject.toml" ]
Expand All @@ -40,7 +44,7 @@ repos:
- repo: https://github.com/PyCQA/pylint
# Configuration help can be found here:
# https://pylint.pycqa.org/en/latest/user_guide/installation/pre-commit-integration.html
rev: v3.3.6
rev: v4.0.2
hooks:
- id: pylint
alias: pylint-with-spelling
Expand All @@ -55,7 +59,7 @@ repos:
)$

- repo: https://github.com/commitizen-tools/commitizen
rev: v4.6.0
rev: v4.9.1
hooks:
- id: commitizen
stages: [ commit-msg ]
Expand All @@ -66,6 +70,6 @@ repos:
- id: nb-clean

- repo: https://github.com/pycqa/doc8
rev: v1.1.2
rev: v2.0.0
hooks:
- id: doc8
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
'orbax-checkpoint',
'e3nn-jax',
"equinox",
"reax>=0.2.0",
"reax>=0.2,<0.6",
"tensorial>=0.4.2",
"pymatgen",
]
Expand Down
4 changes: 2 additions & 2 deletions src/e3response/_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import shutil
import tempfile
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING

from typing_extensions import override

Expand All @@ -12,7 +12,7 @@

class MlflowHandler(logging.Handler):
def __init__(
self, client: "mlflow.tracking.MlflowClient", run_id: str, level: Union[int, str] = 0
self, client: "mlflow.tracking.MlflowClient", run_id: str, level: int | str = 0
) -> None:
super().__init__(level)
self._tempfile = None
Expand Down
17 changes: 9 additions & 8 deletions src/e3response/data/barium_titanate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections.abc import Callable, Iterable, Sequence
import functools
import logging
import pathlib
import re
import tarfile
from typing import Any, Callable, Final, Iterable, Optional, Sequence, Union
from typing import Any, Final

import ase
import ase.io
Expand Down Expand Up @@ -33,13 +34,13 @@ class BtoDataModule(reax.DataModule):
def __init__(
self,
r_max: float,
data_dir: Union[str, pathlib.Path] = "data/bto/",
data_dir: str | pathlib.Path = "data/bto/",
archives: Sequence[str] = (
"BTO_Pm-3m_5atoms_400K_3x3x3_ensemble.tar.gz",
"BTO_Pm-3m_5atoms_800K_3x3x3.tar.gz",
),
tensors: tuple[str] = ("raman_tensors", "born_charges", "dielectric"),
train_val_test_split: Sequence[Union[int, float]] = (0.8, 0.1, 0.1),
train_val_test_split: Sequence[int | float] = (0.8, 0.1, 0.1),
batch_size: int = 64,
) -> None:
"""Initialize a `SiliconDataModule`.
Expand All @@ -55,14 +56,14 @@ def __init__(
self._data_dir: Final[str] = str(data_dir)
self._archives: Final[tuple[str, ...]] = tuple(archives)
self._tensors = tensors
self._train_val_test_split: Final[Sequence[Union[int, float]]] = train_val_test_split
self._train_val_test_split: Final[Sequence[int | float]] = train_val_test_split
self._batch_size: Final[int] = batch_size

# State
self.batch_size_per_device = batch_size
self.data_train: Optional[reax.data.Dataset] = None
self.data_val: Optional[reax.data.Dataset] = None
self.data_test: Optional[reax.data.Dataset] = None
self.data_train: reax.data.Dataset | None = None
self.data_val: reax.data.Dataset | None = None
self.data_test: reax.data.Dataset | None = None

@override
def setup(self, stage: "reax.Stage", /) -> None:
Expand Down Expand Up @@ -235,7 +236,7 @@ def get_structures(root_dir: pathlib.Path, tensors: Iterable[str]) -> list[ase.A


def read_scf(filename) -> ase.Atoms:
with open(filename, "r", encoding="utf-8") as fileobj:
with open(filename, encoding="utf-8") as fileobj:
_data, card_lines = espresso.read_fortran_namelist(fileobj)

cell, _ = espresso.get_cell_parameters(card_lines)
Expand Down
45 changes: 23 additions & 22 deletions src/e3response/data/qm9_nmr.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import collections
from collections.abc import Callable, Sequence
import functools
from functools import lru_cache
import logging
import os
import pathlib
import re
import tempfile
from typing import Any, Callable, Final, Optional, Sequence, Union
from typing import Any, Final
import urllib.error
import urllib.request
import zipfile

import ase
import jraph
import numpy as np
from pymatgen.io import gaussian # type: ignore
import pymatgen.io.ase # type: ignore
from pymatgen.io import gaussian
import pymatgen.io.ase
import reax
from tensorial import gcnn
import tqdm
Expand Down Expand Up @@ -58,9 +59,9 @@ def __init__(
self,
r_max: float = 5,
data_dir: str = "data/qm9_nmr/",
dataset: Union[str, Sequence[str]] = "gasphase",
atom_keys: Optional[Union[str, Sequence[str]]] = None,
limit: Optional[int] = None,
dataset: str | Sequence[str] = "gasphase",
atom_keys: str | Sequence[str] | None = None,
limit: int | None = None,
) -> None:
"""
Initialize the QM9-NMR dataset.
Expand Down Expand Up @@ -131,7 +132,7 @@ def __init__(
try:
with zipfile.ZipFile(archive_path, "r") as zip_ref:
zip_ref.testzip()
except (zipfile.BadZipFile, zipfile.LargeZipFile, IOError) as e:
except (zipfile.BadZipFile, zipfile.LargeZipFile, OSError) as e:
_LOGGER.warning(
"%s is corrupted or unreadable: %s, removing corrupted archive ...",
archive_name,
Expand Down Expand Up @@ -189,7 +190,7 @@ def reporthook(_block_num, block_size, total_size):
except OSError as e:
_LOGGER.error("Filesystem error while writing %s: %s", path, e)

def _extract_archive_zip(self, zip_path: str, limit: Optional[int] = None) -> list:
def _extract_archive_zip(self, zip_path: str, limit: int | None = None) -> list:

structures = []

Expand Down Expand Up @@ -229,7 +230,7 @@ def _create_molecule_data(log_file):
structure = gaussian_output.final_structure

# extraction of data from .log file
with open(log_file, "r", encoding="utf-8") as file:
with open(log_file, encoding="utf-8") as file:
log_data = file.read()

shielding_pattern = (
Expand Down Expand Up @@ -294,12 +295,12 @@ def _create_molecule_data(log_file):
_LOGGER.error("Error in file %s: %s", log_file, e)
raise

except (IOError, OSError) as e:
except OSError as e:
_LOGGER.error("File system error while processing %s: %s", log_file, e)
raise


def get_structure_and_data_from_log(log_path: pathlib.Path) -> Optional[ase.Atoms]:
def get_structure_and_data_from_log(log_path: pathlib.Path) -> ase.Atoms | None:
# _LOGGER.info("Parsing Gaussian .log file: %s", log_path)

try:
Expand Down Expand Up @@ -333,7 +334,7 @@ def get_structure_and_data_from_log(log_path: pathlib.Path) -> Optional[ase.Atom

return atoms

except (ValueError, IOError) as e:
except (ValueError, OSError) as e:
_LOGGER.error("Parsing error for %s: %s", log_path, e)
return None

Expand All @@ -350,10 +351,10 @@ def __init__(
self,
r_max: float = 5,
data_dir: str = "data/qm9_nmr/",
dataset: Union[str, Sequence[str]] = "gasphase",
atom_keys: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
train_val_test_split: Sequence[Union[int, float]] = (0.85, 0.05, 0.1),
dataset: str | Sequence[str] = "gasphase",
atom_keys: Sequence[str] | None = None,
limit: int | None = None,
train_val_test_split: Sequence[int | float] = (0.85, 0.05, 0.1),
batch_size: int = 64,
) -> None:
"""Initialize a QM9-NMR data module.
Expand All @@ -371,19 +372,19 @@ def __init__(

# Params
self._data_dir: Final[str] = data_dir
self._dataset: Union[str, Sequence[str]] = dataset
self.dataset: Optional[Qm9NmrDataset] = None
self._dataset: str | Sequence[str] = dataset
self.dataset: Qm9NmrDataset | None = None
self._rmax = r_max
self._atom_keys = atom_keys
self._limit = limit
self._train_val_test_split: Final[Sequence[Union[int, float]]] = train_val_test_split
self._train_val_test_split: Final[Sequence[int | float]] = train_val_test_split
self._batch_size: Final[int] = batch_size

# State
self.batch_size_per_device = batch_size
self.data_train: Optional[reax.data.Dataset] = None
self.data_val: Optional[reax.data.Dataset] = None
self.data_test: Optional[reax.data.Dataset] = None
self.data_train: reax.data.Dataset | None = None
self.data_val: reax.data.Dataset | None = None
self.data_test: reax.data.Dataset | None = None

@override
def setup(self, stage: "reax.Stage", /) -> None:
Expand Down
3 changes: 1 addition & 2 deletions src/e3response/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import pathlib
from typing import Optional

import hydra
from hydra.core import hydra_config
Expand Down Expand Up @@ -102,7 +101,7 @@ def train(cfg: omegaconf.DictConfig):
return metric_dict, object_dict


def main(cfg: omegaconf.DictConfig) -> Optional[float]:
def main(cfg: omegaconf.DictConfig) -> float | None:
"""Main entry point for training.

:param cfg: DictConfig configuration composed by Hydra.
Expand Down
4 changes: 2 additions & 2 deletions src/e3response/utils/logging_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any

import jax
from lightning_utilities.core.rank_zero import rank_zero_only
Expand All @@ -12,7 +12,7 @@


@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
def log_hyperparameters(object_dict: dict[str, Any]) -> None:
"""Controls which config parts are saved by Lightning loggers.

Additionally, it saves:
Expand Down
6 changes: 3 additions & 3 deletions src/e3response/utils/pylogger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Mapping
import logging
from typing import Mapping, Optional

from lightning_utilities.core import rank_zero

Expand All @@ -13,7 +13,7 @@ def __init__(
self,
name: str = __name__,
rank_zero_only: bool = False,
extra: Optional[Mapping[str, object]] = None,
extra: Mapping[str, object] | None = None,
) -> None:
"""Initializes a multi-GPU-friendly python command line logger that logs on all processes
with their rank prefixed in the log message.
Expand All @@ -28,7 +28,7 @@ def __init__(
super().__init__(logger=logger, extra=extra)
self.rank_zero_only = rank_zero_only

def log(self, level: int, msg: str, *args, rank: Optional[int] = None, **kwargs) -> None:
def log(self, level: int, msg: str, *args, rank: int | None = None, **kwargs) -> None:
"""Delegate a log call to the underlying logger, after prefixing its message with the rank
of the process it's being logged from. If `'rank'` is provided, then the log will only
occur on that rank/process.
Expand Down
2 changes: 1 addition & 1 deletion src/e3response/utils/rich_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence
from pathlib import Path
from typing import Sequence

from hydra.core.hydra_config import HydraConfig
from lightning_utilities.core.rank_zero import rank_zero_only
Expand Down
5 changes: 3 additions & 2 deletions src/e3response/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Callable
from importlib.util import find_spec
from typing import Any, Callable, Optional
from typing import Any
import warnings

from omegaconf import DictConfig
Expand Down Expand Up @@ -99,7 +100,7 @@ def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
return wrap


def get_metric_value(metric_dict: dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
def get_metric_value(metric_dict: dict[str, Any], metric_name: str | None) -> float | None:
"""Safely retrieves value of the metric logged in reax.Module.

:param metric_dict: A dict containing metric values.
Expand Down
Loading