in agents/diffusion.py [0:0]
def p_sample_loop(self, state, shape, verbose=False, return_diffusion=False):
    """Run the full reverse-diffusion chain starting from Gaussian noise.

    Draws ``x`` from a standard normal of the given ``shape`` on the same
    device as ``self.betas``. It then calls ``self.p_sample`` once per
    timestep, from ``self.n_timesteps - 1`` down to ``0``, feeding each
    result into the next call.

    Args:
        state: Conditioning input, forwarded unchanged to every
            ``self.p_sample`` call.
        shape: Shape of the sampled tensor. ``shape[0]`` is used as the
            batch size of the per-step timestep tensor.
        verbose: If True, report each step through ``Progress``. Otherwise
            a ``Silent`` stand-in receives the same calls.
        return_diffusion: If True, also return every intermediate sample.

    Returns:
        The final sample ``x``. When ``return_diffusion`` is True, returns
        a tuple ``(x, chain)``. ``chain`` stacks the initial noise followed
        by the output of every step along ``dim=1``
        (``self.n_timesteps + 1`` entries).
    """
    device = self.betas.device
    batch_size = shape[0]

    x = torch.randn(shape, device=device)
    # Collects the initial noise plus one entry per denoising step; None when
    # the caller did not ask for the intermediate chain.
    diffusion = [x] if return_diffusion else None

    progress = Progress(self.n_timesteps) if verbose else Silent()
    try:
        # Walk the timesteps backwards: T-1, T-2, ..., 0.
        for i in reversed(range(self.n_timesteps)):
            # Every batch element is denoised at the same timestep i.
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.p_sample(x, timesteps, state)
            progress.update({'t': i})
            if diffusion is not None:
                diffusion.append(x)
    finally:
        # Close the progress display even if a sampling step raises, so it
        # is not left open after an error.
        progress.close()

    if diffusion is not None:
        return x, torch.stack(diffusion, dim=1)
    return x