From 1aa59e671cc62af523c9b9cdf1f5aade92b76248 Mon Sep 17 00:00:00 2001 From: Benjamin Bennett <125957+benbennett@users.noreply.github.com> Date: Tue, 23 Sep 2025 23:01:31 -0500 Subject: [PATCH] Fix DiscPool dtype mismatch --- util/image_pool.py | 223 +++++++++++++++++++++++---------------------- 1 file changed, 114 insertions(+), 109 deletions(-) diff --git a/util/image_pool.py b/util/image_pool.py index 6a2af7c..1805b34 100644 --- a/util/image_pool.py +++ b/util/image_pool.py @@ -1,109 +1,114 @@ -import random -import torch -from torch.utils.data import Dataset - - -class ImagePool(): - """This class implements an image buffer that stores previously generated images. - - This buffer enables us to update discriminators using a history of generated images - rather than the ones produced by the latest generators. - """ - - def __init__(self, pool_size): - """Initialize the ImagePool class - - Parameters: - pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created - """ - self.pool_size = pool_size - if self.pool_size > 0: # create an empty pool - self.num_imgs = 0 - self.images = [] - - def query(self, images): - """Return an image from the pool. - - Parameters: - images: the latest generated images from the generator - - Returns images from the buffer. - - By 50/100, the buffer will return input images. - By 50/100, the buffer will return images previously stored in the buffer, - and insert the current images to the buffer. - """ - if self.pool_size == 0: # if the buffer size is 0, do nothing - return images - return_images = [] - for image in images: - image = torch.unsqueeze(image.data, 0) - if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer - self.num_imgs = self.num_imgs + 1 - self.images.append(image) - return_images.append(image) - else: - p = random.uniform(0, 1) - if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer - random_id = random.randint(0, self.pool_size - 1) # randint is inclusive - tmp = self.images[random_id].clone() - self.images[random_id] = image - return_images.append(tmp) - else: # by another 50% chance, the buffer will return the current image - return_images.append(image) - return_images = torch.cat(return_images, 0) # collect all the images and return - return return_images - - -class DiscPool(Dataset): - """This class implements a buffer that stores the previous discriminator map for each image in the dataset. - - This buffer enables us to recall the outputs of the discriminator in the previous epoch - """ - - def __init__(self, opt, device, dataset, isTrain=True, disc_out_size=30): - """Initialize the DiscPool class - - Parameters: - opt: stores all the experiment flags; needs to be a subclass of BaseOptions - device: the device used - isTrain: whether this class is instanced during the train or test phase - disc_out_size: the size of the ouput tensor of the discriminator - """ - self.dataset_len = len(dataset) - - if isTrain: - # Initially the discriminator doesn't know real/fake because is not trained yet - self.disc_out = torch.rand((self.dataset_len, 1, disc_out_size, disc_out_size), dtype=torch.float32) - # self.disc_out = torch.rand((self.dataset_len, 1, 1, disc_out_size, disc_out_size), dtype=torch.float32) #TODO: need to understand why there is missing 1 dimension - else: - # At the end, the discriminator is expected to be near its maximum entropy state (D_i = 1/2) - self.disc_out = 0.5 + 0.001 * torch.randn((self.dataset_len, 1, disc_out_size, disc_out_size), - dtype=torch.float32) - - self.disc_out = self.disc_out.to(device) - - def __getitem__(self, _): - raise NotImplementedError('DiscPool does not support this operation') - - def __len__(self): - return self.dataset_len - - def query(self, img_idx): - """Return the last discriminator map from the pool, corresponding to given image indices. - - Parameters: - img_idx: indices of the images that the discriminator just processed - - Returns discriminator map from the buffer. - """ - return self.disc_out[img_idx] - - def insert(self, disc_out, img_idx): - """Insert the last discriminator map in the pool, corresponding to given image index. - - Parameters: - disc_out: output from the discriminator in the backward pass of generator - img_idx: indices of the images that the discriminator just processed - """ - self.disc_out[img_idx] = disc_out \ No newline at end of file +import random +import torch +from torch.utils.data import Dataset + + +class ImagePool(): + """This class implements an image buffer that stores previously generated images. + + This buffer enables us to update discriminators using a history of generated images + rather than the ones produced by the latest generators. + """ + + def __init__(self, pool_size): + """Initialize the ImagePool class + + Parameters: + pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created + """ + self.pool_size = pool_size + if self.pool_size > 0: # create an empty pool + self.num_imgs = 0 + self.images = [] + + def query(self, images): + """Return an image from the pool. + + Parameters: + images: the latest generated images from the generator + + Returns images from the buffer. + + By 50/100, the buffer will return input images. + By 50/100, the buffer will return images previously stored in the buffer, + and insert the current images to the buffer. + """ + if self.pool_size == 0: # if the buffer size is 0, do nothing + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: # by another 50% chance, the buffer will return the current image + return_images.append(image) + return_images = torch.cat(return_images, 0) # collect all the images and return + return return_images + + +class DiscPool(Dataset): + """This class implements a buffer that stores the previous discriminator map for each image in the dataset. + + This buffer enables us to recall the outputs of the discriminator in the previous epoch + """ + + def __init__(self, opt, device, dataset, isTrain=True, disc_out_size=30): + """Initialize the DiscPool class + + Parameters: + opt: stores all the experiment flags; needs to be a subclass of BaseOptions + device: the device used + isTrain: whether this class is instanced during the train or test phase + disc_out_size: the size of the ouput tensor of the discriminator + """ + self.dataset_len = len(dataset) + + if isTrain: + # Initially the discriminator doesn't know real/fake because is not trained yet + self.disc_out = torch.rand((self.dataset_len, 1, disc_out_size, disc_out_size), dtype=torch.float32) + # self.disc_out = torch.rand((self.dataset_len, 1, 1, disc_out_size, disc_out_size), dtype=torch.float32) #TODO: need to understand why there is missing 1 dimension + else: + # At the end, the discriminator is expected to be near its maximum entropy state (D_i = 1/2) + self.disc_out = 0.5 + 0.001 * torch.randn((self.dataset_len, 1, disc_out_size, disc_out_size), + dtype=torch.float32) + + self.disc_out = self.disc_out.to(device) + + def __getitem__(self, _): + raise NotImplementedError('DiscPool does not support this operation') + + def __len__(self): + return self.dataset_len + + def query(self, img_idx): + """Return the last discriminator map from the pool, corresponding to given image indices. + + Parameters: + img_idx: indices of the images that the discriminator just processed + + Returns discriminator map from the buffer. + """ + return self.disc_out[img_idx] + + def insert(self, disc_out, img_idx): + """Insert the last discriminator map in the pool, corresponding to given image index. + + Parameters: + disc_out: output from the discriminator in the backward pass of generator + img_idx: indices of the images that the discriminator just processed + """ + # torch.amp.autocast may produce half-precision discriminator outputs while the + # pool was initialized with float32 tensors. Mixing these dtypes triggers a + # runtime error when updating the buffer, so we ensure both device and dtype + # match the pool storage before the assignment. + disc_out = disc_out.to(self.disc_out) + self.disc_out[img_idx] = disc_out