From 8e16f39fc267bd9bae5ad196a90e62714630b65c Mon Sep 17 00:00:00 2001 From: Nix Goldowsky-Dill Date: Tue, 23 Apr 2024 16:40:03 +0000 Subject: [PATCH] use tempfile lib for wandb download --- e2e_sae/models/transformers.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/e2e_sae/models/transformers.py b/e2e_sae/models/transformers.py index 71f2059..c454d61 100644 --- a/e2e_sae/models/transformers.py +++ b/e2e_sae/models/transformers.py @@ -1,3 +1,4 @@ +import tempfile from functools import partial from pathlib import Path from typing import Any, Literal, cast @@ -376,22 +377,23 @@ def from_wandb(cls, wandb_project_run_id: str) -> "SAETransformer": file for file in run.files() if file.name.endswith("final_config.yaml") ][0] - train_config_file = train_config_file_remote.download( - exist_ok=True, replace=True, root="/tmp/" - ).name - checkpoints = [file for file in run.files() if file.name.endswith(".pt")] latest_checkpoint_remote = sorted( checkpoints, key=lambda x: int(x.name.split(".pt")[0].split("_")[-1]) )[-1] - latest_checkpoint_file = latest_checkpoint_remote.download( - exist_ok=True, replace=True, root="/tmp/" - ).name - assert latest_checkpoint_file is not None, "Failed to download the latest checkpoint." - return cls.from_checkpoint( - checkpoint_file=latest_checkpoint_file, config_file=train_config_file - ) + with tempfile.TemporaryDirectory() as temp_dir: + train_config_file = train_config_file_remote.download( + exist_ok=True, replace=True, root=temp_dir + ).name + latest_checkpoint_file = latest_checkpoint_remote.download( + exist_ok=True, replace=True, root=temp_dir + ).name + assert latest_checkpoint_file is not None, "Failed to download the latest checkpoint." + + return cls.from_checkpoint( + checkpoint_file=latest_checkpoint_file, config_file=train_config_file + ) @classmethod def from_checkpoint(