diff --git a/README.md b/README.md index d9c46db..1d0ce68 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ python test.py ./confs/SRFlow_DF2K_4X.yml # Diverse Images 4X (Dataset Incl python test.py ./confs/SRFlow_DF2K_8X.yml # Diverse Images 8X (Dataset Included) python test.py ./confs/SRFlow_CelebA_8X.yml # Faces 8X ``` - +For testing, we apply SRFlow to the full images on CPU. # Our paper explains diff --git a/code/models/SRFlow_model.py b/code/models/SRFlow_model.py index 865df21..8d5683b 100644 --- a/code/models/SRFlow_model.py +++ b/code/models/SRFlow_model.py @@ -127,7 +127,7 @@ def get_z(self, heat, seed=None, batch_size=1, lr_shape=None, y_label=None): fac = 2 ** (L - 3) z_size = int(self.lr_size // (2 ** (L - 3))) z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)) - return z + return z.to(self.device) def get_current_log(self): return self.log_dict diff --git a/requirements.txt b/requirements.txt index 073fbbe..3445516 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,7 +38,7 @@ nbconvert==6.0.7 nbformat==5.0.8 nest-asyncio==1.4.2 networkx==2.5 -notebook==6.1.4 +notebook==6.1.5 numpy==1.19.4 opencv-python==4.4.0.46 packaging==20.4