Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions e2e_sae/models/transformers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import tempfile
from functools import partial
from pathlib import Path
from typing import Any, Literal, cast
Expand Down Expand Up @@ -376,22 +377,23 @@ def from_wandb(cls, wandb_project_run_id: str) -> "SAETransformer":
file for file in run.files() if file.name.endswith("final_config.yaml")
][0]

train_config_file = train_config_file_remote.download(
exist_ok=True, replace=True, root="/tmp/"
).name

checkpoints = [file for file in run.files() if file.name.endswith(".pt")]
latest_checkpoint_remote = sorted(
checkpoints, key=lambda x: int(x.name.split(".pt")[0].split("_")[-1])
)[-1]
latest_checkpoint_file = latest_checkpoint_remote.download(
exist_ok=True, replace=True, root="/tmp/"
).name
assert latest_checkpoint_file is not None, "Failed to download the latest checkpoint."

return cls.from_checkpoint(
checkpoint_file=latest_checkpoint_file, config_file=train_config_file
)
with tempfile.TemporaryDirectory() as temp_dir:
train_config_file = train_config_file_remote.download(
exist_ok=True, replace=True, root=temp_dir
).name
latest_checkpoint_file = latest_checkpoint_remote.download(
exist_ok=True, replace=True, root=temp_dir
).name
assert latest_checkpoint_file is not None, "Failed to download the latest checkpoint."

return cls.from_checkpoint(
checkpoint_file=latest_checkpoint_file, config_file=train_config_file
)

@classmethod
def from_checkpoint(
Expand Down