diff --git a/cycling_utils/saving.py b/cycling_utils/saving.py index 495b214..93e0c28 100644 --- a/cycling_utils/saving.py +++ b/cycling_utils/saving.py @@ -256,7 +256,7 @@ def prepare_checkpoint_directory(self, force_save=False): if self.strategy in ["sync_any", "sync_all"]: global_force = torch.tensor( 1 if force_save else 0, - dtype=torch.int16, + dtype=torch.int64, requires_grad=False, device="cuda", )