diff --git a/batchflow/models/torch/base.py b/batchflow/models/torch/base.py index e0d1f47e3..2334fd927 100755 --- a/batchflow/models/torch/base.py +++ b/batchflow/models/torch/base.py @@ -683,7 +683,7 @@ def make_infrastructure(self): self.make_loss() self.make_optimizer() self.make_decay() - self.scaler = torch.cuda.amp.GradScaler() + self.scaler = torch.GradScaler("cuda") self.setup_gradient_clipping() self.setup_weights_averaging() diff --git a/batchflow/plotter/plot.py b/batchflow/plotter/plot.py index 653196104..d05d3fb3d 100644 --- a/batchflow/plotter/plot.py +++ b/batchflow/plotter/plot.py @@ -855,10 +855,10 @@ def clear(self): self.annotations = {} - self.ax.clear() for layer in self.layers: for obj in layer.objects: obj.remove() + self.ax.clear() self.layers = []