From 99bc28dd835fa6a7d81ead1a9c789701b84d1e4a Mon Sep 17 00:00:00 2001 From: StrongAdam Date: Mon, 12 May 2025 14:18:02 +1000 Subject: [PATCH] fix: dtype --- cycling_utils/saving.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cycling_utils/saving.py b/cycling_utils/saving.py index 17d994a..495b214 100644 --- a/cycling_utils/saving.py +++ b/cycling_utils/saving.py @@ -129,10 +129,10 @@ def __init__( {os.environ['RANK']} was passed '{strategy}'." local_strategy_tensor = torch.tensor( - strategy_int, requires_grad=False, device="cuda" + strategy_int, dtype=torch.int64, requires_grad=False, device="cuda" ) global_strategy_list = [ - torch.zeros(1, requires_grad=False, device="cuda") + torch.zeros(1, dtype=torch.int64, requires_grad=False, device="cuda") for _ in range(int(os.environ["WORLD_SIZE"])) ] all_gather(global_strategy_list, local_strategy_tensor)