diff --git a/batchflow/models/torch/base.py b/batchflow/models/torch/base.py
index 6196517e9..17826698a 100755
--- a/batchflow/models/torch/base.py
+++ b/batchflow/models/torch/base.py
@@ -815,8 +815,9 @@ def setup_gradient_clipping(self):
             raise ValueError(f'gradient_clipping must be int, float or callable but it is{type(gradient_clipping)}')
 
         for p in self.model.parameters():
-            hook = p.register_hook(function)
-            self._hooks.append(hook)
+            if p.requires_grad:
+                hook = p.register_hook(function)
+                self._hooks.append(hook)
 
     def setup_weights_averaging(self):
         """ Prepare WA-model: check all required keys and store copy on CPU. """
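
The added `requires_grad` guard matters because PyTorch refuses to register backward hooks on tensors that do not require gradients, so the old loop would crash on any model with frozen parameters. Below is a minimal standalone sketch of the pattern the patched loop follows; the toy model, `clip_value`, and the clamp lambda are illustrative stand-ins, not BatchFlow's actual `self.model` or `function`:

```python
import torch
from torch import nn

# Toy model with one frozen layer, standing in for self.model.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[0].weight.requires_grad_(False)  # frozen parameter

clip_value = 1.0  # analogue of an int/float `gradient_clipping` config value
hooks = []
for p in model.parameters():
    if p.requires_grad:
        # The hook receives the gradient and returns the clipped replacement.
        hooks.append(p.register_hook(lambda grad: grad.clamp(-clip_value, clip_value)))
    # Without the guard, p.register_hook(...) on the frozen weight raises:
    # RuntimeError: cannot register a hook on a tensor that doesn't require gradient

model(torch.randn(3, 4)).sum().backward()  # hooks fire during backward
```

Keeping the returned handles (as the patch does in `self._hooks`) allows the hooks to be detached later via `handle.remove()`.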