diff --git a/torchbenchmark/models/moco/moco/builder.py b/torchbenchmark/models/moco/moco/builder.py index 7d80fe996c..295e22a7e7 100644 --- a/torchbenchmark/models/moco/moco/builder.py +++ b/torchbenchmark/models/moco/moco/builder.py @@ -47,7 +47,7 @@ def _momentum_update_key_encoder(self): Momentum update of the key encoder """ for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): - param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) + param_k.mul_(self.m).add_(param_q.mul(1. - self.m)) @torch.no_grad() def _dequeue_and_enqueue(self, keys):