Hi GenPercept authors, thanks for releasing the code and models.
I’m trying to understand the “single-step diffusion” setting described in the paper / docs, but I’m confused by the training implementation vs. the provided configs.
In the training code, timesteps are sampled like this:

```python
if 'fix_timesteps' in self.cfg.model.keys():
    timesteps = torch.tensor([self.cfg.model.fix_timesteps]).long().repeat(rgb.shape[0]).to(self.unet.device)
else:
    timesteps = torch.randint(
        0,
        self.scheduler_timesteps,
        (batch_size,),
        device=device,
        generator=rand_num_generator,
    ).long()
```
From my understanding, if the method is truly "single-step diffusion", the timestep should be fixed (e.g., always using a specific t), which seems to be what `model.fix_timesteps` supports. However, in the README / released configs, I don't see `fix_timesteps` being set anywhere, which means training would sample random timesteps instead.
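To make sure I'm reading the branch correctly, here is a torch-free sketch of the selection logic as I understand it (function name and dict-based config are my own placeholders, not from the repo):

```python
import random

def sample_timesteps(cfg_model, batch_size, num_train_timesteps=1000, rng=None):
    """Sketch of the timestep-sampling branch as I read it:
    if 'fix_timesteps' is set, every sample in the batch gets that
    same timestep; otherwise timesteps are drawn uniformly at random."""
    if 'fix_timesteps' in cfg_model:
        # Fixed-t branch ("single-step" style): one constant timestep
        return [cfg_model['fix_timesteps']] * batch_size
    rng = rng or random.Random()
    # Default branch: uniform random t in [0, num_train_timesteps)
    return [rng.randrange(num_train_timesteps) for _ in range(batch_size)]

print(sample_timesteps({'fix_timesteps': 999}, 4))  # [999, 999, 999, 999]
print(sample_timesteps({}, 4))                      # four random t values
```

If that reading is right, then without `fix_timesteps` in the config the model is always trained on random t, even though inference is single-step.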
Also, I noticed you provide an ablation config that explicitly fixes the timestep (i.e., sets `model.fix_timesteps`). This made me wonder about the intended default behavior: is random timestep sampling (no `fix_timesteps`) expected to perform better than a fixed timestep in the main method?

Thanks for your time!