diff --git a/01_Diffusion_Models_Tutorial/Diffusion Model.ipynb b/01_Diffusion_Models_Tutorial/Diffusion Model.ipynb index 8078f38..4c32132 100644 --- a/01_Diffusion_Models_Tutorial/Diffusion Model.ipynb +++ b/01_Diffusion_Models_Tutorial/Diffusion Model.ipynb @@ -907,6 +907,7 @@ " t = torch.full((1,), i, dtype=torch.long, device=device)\n", " labels = torch.tensor([c] * NUM_DISPLAY_IMAGES).resize(NUM_DISPLAY_IMAGES, 1).float().to(device)\n", " imgs = diffusion_model.backward(x=imgs, t=t, model=unet.eval().to(device), labels = labels)\n", + " imgs = imgs.clamp(-1, 1)\n", " for idx, img in enumerate(imgs):\n", " ax[c][idx].imshow(reverse_transform(img))\n", " ax[c][idx].set_title(f\"Class: {classes[c]}\", fontsize = 100)\n",