def guided_sample()

in agents/diffusion.py [0:0]
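
Reverse-diffusion action sampler with Q-function guidance: starting from Gaussian noise, an action is denoised conditioned on `state`, and for the last `start` fraction of timesteps the predicted mean is shifted along the gradient of `q_fun`, biasing samples toward higher Q-values.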


    def guided_sample(self, state, q_fun, start=0.2, verbose=False, return_diffusion=False):
        device = self.betas.device
        batch_size = state.shape[0]
        shape = (batch_size, self.action_dim)
        x = torch.randn(shape, device=device)
        # guide only the last `start` fraction of the reverse steps (low-noise regime)
        i_start = self.n_timesteps * start

        if return_diffusion: diffusion = [x]

        def guided_p_sample(x, t, s, fun):
            # one reverse-diffusion step whose mean is nudged by the gradient of fun(s, a)
            b, *_, device = *x.shape, x.device
            with torch.no_grad():
                model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, s=s)
            noise = torch.randn_like(x)
            # no noise when t == 0
            nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

            # Q-function guidance: differentiate Q(s, a) with respect to the action
            a = model_mean.clone().requires_grad_(True)
            q_value = fun(s, a)
            # q_value = q_value / q_value.abs().mean().detach()  # normalize q
            # sum over the batch so autograd.grad receives a scalar; per-sample gradients
            # are unaffected because each sample's Q-value depends only on its own action
            grads = torch.autograd.grad(outputs=q_value.sum(), inputs=a, only_inputs=True)[0].detach()
            # shift the mean along the Q gradient (scaled by the model log-variance),
            # then add the usual reverse-diffusion noise (none at t == 0)
            return (model_mean + model_log_variance * grads) + nonzero_mask * (0.5 * model_log_variance).exp() * noise

        progress = Progress(self.n_timesteps) if verbose else Silent()
        for i in reversed(range(0, self.n_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            if i <= i_start:
                # late (low-noise) steps: Q-guided reverse step
                x = guided_p_sample(x, timesteps, state, q_fun)
            else:
                # early (high-noise) steps: standard unguided reverse step
                with torch.no_grad():
                    x = self.p_sample(x, timesteps, state)

            progress.update({'t': i})

            if return_diffusion: diffusion.append(x)

        progress.close()

        # clip the final action to the valid range
        x = x.clamp_(-self.max_action, self.max_action)

        if return_diffusion:
            return x, torch.stack(diffusion, dim=1).clamp_(-self.max_action, self.max_action)
        else:
            return x
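
A minimal usage sketch, assuming a trained instance `policy` of the class that defines this method and a critic module `critic(state, action)` returning per-sample Q-values; `policy`, `critic`, and the dimensions below are hypothetical, not part of this file:

    import torch

    batch_size, state_dim = 256, 17                           # example dimensions
    q_fun = lambda s, a: critic(s, a)                         # any differentiable Q(s, a)
    states = torch.randn(batch_size, state_dim).to(policy.betas.device)
    actions = policy.guided_sample(states, q_fun, start=0.2)  # guidance on the last 20% of steps
    # actions: (batch_size, action_dim), clipped to [-max_action, max_action]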