diff --git a/.gitignore b/.gitignore index 8015bd0..099e661 100644 --- a/.gitignore +++ b/.gitignore @@ -147,4 +147,6 @@ shared_with_container/configs/_* tests/data/test_packages tests/data/test_utils/hubai_models +.build + .DS_Store diff --git a/modelconverter/utils/docker_utils.py b/modelconverter/utils/docker_utils.py index 2673f36..a332785 100644 --- a/modelconverter/utils/docker_utils.py +++ b/modelconverter/utils/docker_utils.py @@ -4,6 +4,7 @@ import subprocess import sys import tempfile +import zipfile from pathlib import Path from typing import Literal from urllib.error import HTTPError, URLError @@ -11,6 +12,7 @@ from urllib.request import Request, urlopen import psutil +import semver import yaml from docker.utils import parse_repository_tag from loguru import logger @@ -18,6 +20,7 @@ from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn import docker +from modelconverter import __version__ def get_docker_client_from_active_context() -> docker.DockerClient: @@ -151,7 +154,9 @@ def docker_build( tag_version = rvc4_tag_version(version) if target == "rvc4" else version if target == "rvc4": - ensure_snpe_archive(version) + build_dir = prepare_build_environemnt(target, version) + else: + build_dir = Path() tag = f"{tag_version}-{bare_tag}" @@ -166,11 +171,11 @@ def docker_build( docker_bin(), "build", "-f", - f"docker/{target}/Dockerfile", + str(build_dir / "docker" / target / "Dockerfile"), "-t", image, "--load", - ".", + str(build_dir), ] if version is not None: args += ["--build-arg", f"VERSION={version}"] @@ -180,8 +185,43 @@ def docker_build( return image -def ensure_snpe_archive(version: str) -> Path: - archive_path = Path("docker/extra_packages") / f"snpe-{version}.zip" +def prepare_build_environemnt( + target: Literal["rvc2", "rvc3", "rvc4", "hailo"], version: str +) -> Path: + if target != "rvc4": + raise NotImplementedError( + "Fully automatic docker build is only implemented for RVC4" + ) + + build_path = Path(".build", target) + build_path.mkdir(parents=True, exist_ok=True) + _download_file( + f"https://github.com/luxonis/modelconverter/archive/refs/tags/v{__version__}-beta.zip", + build_path / f"modelconverter-{__version__}-beta.zip", + fallback_url="https://github.com/luxonis/modelconverter/archive/refs/heads/main.zip", + ) + with zipfile.ZipFile( + build_path / f"modelconverter-{__version__}-beta.zip" + ) as z: + z.extractall(build_path) + + if (p := Path("docker", "extra_packages", f"snpe-{version}.zip")).exists(): + shutil.copy(p, build_path / f"modelconverter-{__version__}-beta" / p) + + elif not (build_path / f"modelconverter-{__version__}-beta" / p).exists(): + download_snpe_archive( + version, + build_path + / f"modelconverter-{__version__}-beta" + / "docker" + / "extra_packages", + ) + + return build_path / f"modelconverter-{__version__}-beta" + + +def download_snpe_archive(version: str, dest: Path) -> Path: + archive_path = dest / f"snpe-{semver.finalize_version(version)}.zip" if archive_path.exists(): return archive_path @@ -197,19 +237,22 @@ def ensure_snpe_archive(version: str) -> Path: ) try: _download_file(url, archive_path) - except (HTTPError, URLError, RuntimeError) as exc: + except (HTTPError, URLError, RuntimeError) as e: msg = ( - f"Failed to download SNPE archive from {url}: {exc}. " + f"Failed to download SNPE archive from {url}: {e}. " "Download it manually from " "https://softwarecenter.qualcomm.com/catalog/item/" "Qualcomm_AI_Runtime_Community and save it as " f"{archive_path}." ) - raise RuntimeError(msg) from exc + raise RuntimeError(msg) from e + return archive_path -def _download_file(url: str, dest: Path) -> None: +def _download_file( + url: str, dest: Path, *, fallback_url: str | None = None +) -> None: parsed = urlparse(url) if parsed.scheme != "https": raise RuntimeError(f"Refusing to download from non-HTTPS URL: {url}") @@ -223,7 +266,7 @@ def _download_file(url: str, dest: Path) -> None: f"HTTP {response.status} while downloading {url}" ) total = None - getheader = getattr(response, "getheader", None) + getheader: str | None = getattr(response, "getheader", None) if callable(getheader): length = getheader("Content-Length") if length and length.isdigit(): @@ -248,6 +291,14 @@ def _download_file(url: str, dest: Path) -> None: tmp_file.write(chunk) progress.update(task, advance=len(chunk)) tmp_path.replace(dest) + except Exception as e: + if fallback_url: + logger.warning( + f"Failed to download from {url}: {e}. Attempting fallback URL {fallback_url}..." + ) + _download_file(fallback_url, dest) + else: + raise finally: if tmp_path is not None and tmp_path.exists() and not dest.exists(): tmp_path.unlink(missing_ok=True)