in agents/diffusion.py [0:0]
def p_losses(self, x_start, state, t, weights=1.0):
    """Compute the training loss for one noised sample of ``x_start``.

    Fresh standard-normal noise with the same shape as ``x_start`` is drawn,
    handed to ``self.q_sample`` together with ``x_start`` and ``t`` to obtain
    a noised input, and that input is passed through ``self.model`` along
    with ``t`` and ``state``. The model output is then scored by
    ``self.loss_fn`` against a target that depends on
    ``self.predict_epsilon``: the drawn noise when it is truthy, otherwise
    ``x_start`` itself.

    Args:
        x_start: Clean sample tensor; the noise is drawn to match its shape.
        state: Conditioning input forwarded unchanged to ``self.model``.
        t: Forwarded unchanged to ``self.q_sample`` and ``self.model``
            (presumably the diffusion timestep indices -- confirm against
            ``q_sample``).
        weights: Forwarded unchanged as the third argument of
            ``self.loss_fn``. Defaults to ``1.0``.

    Returns:
        Whatever ``self.loss_fn`` returns for the chosen prediction/target
        pair.

    Raises:
        AssertionError: If the model output's shape differs from the shape
            of the drawn noise (and therefore of ``x_start``).
    """
    eps = torch.randn_like(x_start)
    perturbed = self.q_sample(x_start=x_start, t=t, noise=eps)
    prediction = self.model(perturbed, t, state)

    # The model must emit something shaped like the regression target.
    assert eps.shape == prediction.shape

    # Regress onto the injected noise or onto the clean sample, depending on
    # how this instance is configured.
    target = eps if self.predict_epsilon else x_start
    return self.loss_fn(prediction, target, weights)