diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4092b8c..4e24e8b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_language_version: python: python3 repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.4 + rev: v0.15.0 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -18,11 +18,11 @@ repos: args: [--in-place, --black, --style=epytext] - repo: https://github.com/executablebooks/mdformat - rev: 0.7.22 + rev: 1.0.0 hooks: - id: mdformat additional_dependencies: - - mdformat-gfm==0.3.6 + - mdformat-gfm - repo: https://github.com/ComPWA/taplo-pre-commit rev: v0.9.3 @@ -31,7 +31,7 @@ repos: - id: taplo-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: check-docstring-first diff --git a/modelconverter/__main__.py b/modelconverter/__main__.py index aa82528..b7e6f95 100644 --- a/modelconverter/__main__.py +++ b/modelconverter/__main__.py @@ -45,7 +45,9 @@ app = App( name="Modelconverter", - version=lambda: f"ModelConverter v{importlib.metadata.version('modelconv')}", + version=lambda: ( + f"ModelConverter v{importlib.metadata.version('modelconv')}" + ), ) app.meta.group_parameters = Group("Global Parameters", sort_key=0) diff --git a/modelconverter/packages/rvc2/benchmark.py b/modelconverter/packages/rvc2/benchmark.py index 7c8f18e..e271d07 100644 --- a/modelconverter/packages/rvc2/benchmark.py +++ b/modelconverter/packages/rvc2/benchmark.py @@ -55,9 +55,7 @@ def _benchmark( model_path, platform=device.getPlatformAsString(), ), - apiKey=environ.HUBAI_API_KEY - if environ.HUBAI_API_KEY - else "", + apiKey=environ.HUBAI_API_KEY or "", ) ) elif ( diff --git a/modelconverter/packages/rvc4/analyze.py b/modelconverter/packages/rvc4/analyze.py index 80d3c94..d5c397a 100644 --- a/modelconverter/packages/rvc4/analyze.py +++ b/modelconverter/packages/rvc4/analyze.py @@ -435,9 +435,7 @@ def _process_diagview_csv(self, csv_path: str) -> None: pl.col("layer_name") .str.split(":") .list.first() - .map_elements( - lambda x: self._replace_bad_layer_name(x), return_dtype=pl.Utf8 - ) + .map_elements(self._replace_bad_layer_name, return_dtype=pl.Utf8) .alias("layer_name"), pl.col("time_mean") .mul(1 / total_time) diff --git a/modelconverter/packages/rvc4/benchmark.py b/modelconverter/packages/rvc4/benchmark.py index 907b0eb..16fa56f 100644 --- a/modelconverter/packages/rvc4/benchmark.py +++ b/modelconverter/packages/rvc4/benchmark.py @@ -334,7 +334,7 @@ def _benchmark_snpe( model_path, platform=dai.Platform.RVC4.name, ), - apiKey=environ.HUBAI_API_KEY if environ.HUBAI_API_KEY else "", + apiKey=environ.HUBAI_API_KEY or "", ) tmp_dir = Path(model_archive).parent / "tmp" shutil.unpack_archive(model_archive, tmp_dir) @@ -413,7 +413,7 @@ def _benchmark_dai( model_path, platform=device.getPlatformAsString(), ), - apiKey=environ.HUBAI_API_KEY if environ.HUBAI_API_KEY else "", + apiKey=environ.HUBAI_API_KEY or "", ) elif str(model_path).endswith(".tar.xz"): modelPath = str(model_path) diff --git a/modelconverter/packages/rvc4/visualize.py b/modelconverter/packages/rvc4/visualize.py index 9a2d487..1ce649d 100644 --- a/modelconverter/packages/rvc4/visualize.py +++ b/modelconverter/packages/rvc4/visualize.py @@ -223,7 +223,7 @@ def _visualize_layer_outputs(self) -> go.Figure: def _get_csv_paths( self, dir_path: Path, comparison_type: str = "layer_comparison" ) -> dict[str, str]: - dir_path = dir_path if dir_path else constants.OUTPUTS_DIR / "analysis" + dir_path = dir_path or constants.OUTPUTS_DIR / "analysis" csv_paths = {} for file in dir_path.glob(f"*{comparison_type}*.csv"): diff --git a/modelconverter/utils/docker_utils.py b/modelconverter/utils/docker_utils.py index 5205936..2673f36 100644 --- a/modelconverter/utils/docker_utils.py +++ b/modelconverter/utils/docker_utils.py @@ -12,12 +12,12 @@ import psutil import yaml +from docker.utils import parse_repository_tag from loguru import logger from luxonis_ml.utils import environ from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn import docker -from docker.utils import parse_repository_tag def get_docker_client_from_active_context() -> docker.DockerClient: diff --git a/modelconverter/utils/onnx_tools.py b/modelconverter/utils/onnx_tools.py index 55d120d..1d916d3 100644 --- a/modelconverter/utils/onnx_tools.py +++ b/modelconverter/utils/onnx_tools.py @@ -27,7 +27,9 @@ def get_extra_quant_tensors( output_configs: dict[str, OutputConfig], depth: int = 2, ) -> list[str]: - """Return unique tensor names that are inputs to producer nodes encountered when walking upstream from the selected graph outputs, up to depth producer hops. + """Return unique tensor names that are inputs to producer nodes + encountered when walking upstream from the selected graph outputs, + up to depth producer hops. - Starts from graph outputs whose names are keys in output_configs. - At each hop: for each tensor, find its producing node; add that node's @@ -1398,21 +1400,21 @@ def compare_outputs(self, from_modelproto: bool = False) -> bool: inputs = {} for input in ort_session_1.get_inputs(): - if input.type in ["tensor(float64)"]: + if input.type == "tensor(float64)": input_type = np.float64 - elif input.type in ["tensor(float32)", "tensor(float)"]: + elif input.type in {"tensor(float32)", "tensor(float)"}: input_type = np.float32 - elif input.type in ["tensor(float16)"]: + elif input.type == "tensor(float16)": input_type = np.float16 - elif input.type in ["tensor(int64)"]: + elif input.type == "tensor(int64)": input_type = np.int64 - elif input.type in ["tensor(int32)"]: + elif input.type == "tensor(int32)": input_type = np.int32 - elif input.type in ["tensor(int16)"]: + elif input.type == "tensor(int16)": input_type = np.int16 - elif input.type in ["tensor(int8)"]: + elif input.type == "tensor(int8)": input_type = np.int8 - elif input.type in ["tensor(bool)"]: + elif input.type == "tensor(bool)": input_type = "bool" inputs[input.name] = np.random.rand(*input.shape).astype( diff --git a/modelconverter/utils/subprocess.py b/modelconverter/utils/subprocess.py index b22bcb7..339af50 100644 --- a/modelconverter/utils/subprocess.py +++ b/modelconverter/utils/subprocess.py @@ -11,6 +11,7 @@ import psutil from loguru import logger +from typing_extensions import Self from .exceptions import SubprocessException @@ -161,7 +162,7 @@ def wait(self, timeout: float | None = None) -> int: """ return self.proc.wait(timeout=self.timeout or timeout) - def __enter__(self) -> "SubprocessHandle": + def __enter__(self) -> Self: if shutil.which(self.cmd_name) is None: raise SubprocessException( f"Command `{self.cmd_name}` not found. Ensure it is in PATH."