def optimize_c()

in agents/bc_w.py
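Trains the critic as a Wasserstein-style discriminator between policy actions (`action_pi`, sampled from the current actor) and behavior actions from the dataset (`action_b`). The critic is updated to maximize the expected score gap between the two, while a one-sided gradient penalty on actions interpolated between them, weighted by `w_gamma`, softly enforces the 1-Lipschitz constraint required by the Wasserstein-1 dual. `EPS` is a small module-level constant that keeps the gradient-norm square root differentiable at zero.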


    def optimize_c(self, state, action_b):
        # Sample policy actions; detach so this update trains only the critic.
        action_pi = self.actor.sample(state).detach()

        # Interpolate between policy and behavior actions with per-sample
        # random weights, and track gradients through the interpolants.
        batch_size = state.shape[0]
        alpha = torch.rand((batch_size, 1), device=self.device)
        a_intpl = (action_pi + alpha * (action_b - action_pi)).requires_grad_(True)
        # Gradient penalty at the interpolated actions. Summing the critic
        # outputs (rather than taking the mean) keeps each sample's gradient
        # unscaled by the batch size; `only_inputs` is deprecated and defaults
        # to True, so it is dropped here.
        grads = torch.autograd.grad(outputs=self.critic(state, a_intpl).sum(),
                                    inputs=a_intpl, create_graph=True)[0]
        # EPS keeps the sqrt differentiable at zero; the one-sided penalty
        # only punishes gradient norms above the 1-Lipschitz bound.
        slope = (grads.square().sum(dim=-1) + EPS).sqrt()
        gradient_penalty = torch.max(slope - 1.0, torch.zeros_like(slope)).square().mean()

        # Wasserstein-style objective: maximize the expected score gap between
        # policy and behavior actions (negated here for gradient descent),
        # plus the weighted gradient penalty.
        logits_p = self.critic(state, action_pi)
        logits_b = self.critic(state, action_b)
        logits_diff = logits_p - logits_b
        critic_loss = -logits_diff.mean() + gradient_penalty * self.w_gamma

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        return critic_loss.item()
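
A minimal usage sketch, assuming the method lives on an agent exposing `actor`, `critic`, `critic_optimizer`, `device`, and `w_gamma` as above; the `agent` object, the batch dimensions, and the random tensors below are hypothetical stand-ins for a real sampled batch:

    import torch

    # Hypothetical shapes: 256 transitions, 17-dim states, 6-dim actions.
    state = torch.randn(256, 17, device=agent.device)     # batch of states
    action_b = torch.randn(256, 6, device=agent.device)   # dataset (behavior) actions

    loss = agent.optimize_c(state, action_b)              # one critic update step
    print(f"critic loss: {loss:.4f}")

Note the penalty is one-sided: only gradient norms above 1 are penalized, a common relaxation of the original two-sided WGAN-GP term that leaves the critic free to be flatter than 1-Lipschitz where that fits the data.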