diff --git a/point_e/models/download.py b/point_e/models/download.py index b1760e1..a621735 100644 --- a/point_e/models/download.py +++ b/point_e/models/download.py @@ -8,6 +8,7 @@ import requests import torch +import shutil from filelock import FileLock from tqdm.auto import tqdm @@ -39,26 +40,36 @@ def fetch_file_cached( """ if cache_dir is None: cache_dir = default_cache_dir() + os.makedirs(cache_dir, exist_ok=True) local_path = os.path.join(cache_dir, url.split("/")[-1]) + if os.path.exists(local_path): return local_path response = requests.get(url, stream=True) size = int(response.headers.get("content-length", "0")) + with FileLock(local_path + ".lock"): if progress: pbar = tqdm(total=size, unit="iB", unit_scale=True) + tmp_path = local_path + ".tmp" + with open(tmp_path, "wb") as f: for chunk in response.iter_content(chunk_size): if progress: pbar.update(len(chunk)) f.write(chunk) - os.rename(tmp_path, local_path) - if progress: - pbar.close() - return local_path + + shutil.copyfile(tmp_path, local_path) + + os.remove(tmp_path) + + if progress: + pbar.close() + + return local_path def load_checkpoint(