in agents/diffusion.py [0:0]
def guided_sample(self, state, q_fun, start=0.2, verbose=False, return_diffusion=False):
    """Reverse-diffusion sampling where the final `start` fraction of steps is guided by the Q-function `q_fun`."""
    device = self.betas.device
    batch_size = state.shape[0]
    shape = (batch_size, self.action_dim)
    x = torch.randn(shape, device=device)
    # guidance kicks in once the (reversed) step index drops to this threshold
    i_start = int(self.n_timesteps * start)
    if return_diffusion: diffusion = [x]
    def guided_p_sample(x, t, s, fun):
        b, *_, device = *x.shape, x.device
        with torch.no_grad():
            model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, s=s)
        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        # Q-function guidance: shift the model mean along the gradient of Q w.r.t. the action
        a = model_mean.clone().requires_grad_(True)
        q_value = fun(s, a)
        # q_value = q_value / q_value.abs().mean().detach()  # normalize q
        # sum over the batch so autograd.grad accepts a per-sample Q tensor; per-sample gradients are unchanged
        grads = torch.autograd.grad(outputs=q_value.sum(), inputs=a, create_graph=True, only_inputs=True)[0].detach()
        # guided mean (model_mean shifted by the scaled Q gradient) plus the usual noise term
        return (model_mean + model_log_variance * grads) + nonzero_mask * (0.5 * model_log_variance).exp() * noise
    progress = Progress(self.n_timesteps) if verbose else Silent()
    for i in reversed(range(0, self.n_timesteps)):
        timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
        if i <= i_start:
            x = guided_p_sample(x, timesteps, state, q_fun)
        else:
            with torch.no_grad():
                x = self.p_sample(x, timesteps, state)
        progress.update({'t': i})
        if return_diffusion: diffusion.append(x)

    progress.close()
    x = x.clamp_(-self.max_action, self.max_action)
    if return_diffusion:
        return x, torch.stack(diffusion, dim=1).clamp_(-self.max_action, self.max_action)
    else:
        return x
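
A minimal usage sketch, not from the repo: it assumes `actor` is the Diffusion policy object this method belongs to (with `betas`, `n_timesteps`, `action_dim`, `max_action` set), `critic.q1` is a Q-head taking (state, action) batches, and `obs_dim` is a placeholder observation size.

# Hedged usage sketch; `actor`, `critic`, and `obs_dim` are hypothetical names, not from this file.
import torch

obs_dim = 17                                                   # assumed observation dimension
state = torch.randn(32, obs_dim, device=actor.betas.device)    # batch of observations
q_fun = lambda s, a: critic.q1(s, a)                           # per-sample Q-values, e.g. shape (32, 1)

# guide only the final 20% of reverse-diffusion steps with the Q gradient
action = actor.guided_sample(state, q_fun, start=0.2)

# optionally keep the whole denoising trajectory for inspection
action, traj = actor.guided_sample(state, q_fun, start=0.2, return_diffusion=True)
# action: (32, action_dim); traj: (32, n_timesteps + 1, action_dim)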