in agents/bc_w.py [0:0]
def optimize_c(self, state, action_b):
    """One critic update: Wasserstein-style loss with a one-sided gradient penalty."""
    # Sample policy actions; detach so the actor is not updated through the critic loss.
    action_pi = self.actor.sample(state).detach()
    batch_size = state.shape[0]
    # Interpolate between policy and behavior actions (WGAN-GP style) and track
    # gradients with respect to the interpolated actions.
    alpha = torch.rand((batch_size, 1)).to(self.device)
    a_intpl = (action_pi + alpha * (action_b - action_pi)).requires_grad_(True)
    # Gradient of the critic output with respect to the interpolated actions.
    grads = torch.autograd.grad(outputs=self.critic(state, a_intpl).mean(), inputs=a_intpl,
                                create_graph=True, only_inputs=True)[0]
    # Gradient norm (EPS for numerical stability); penalize only norms above 1.
    slope = (grads.square().sum(dim=-1) + EPS).sqrt()
    gradient_penalty = torch.max(slope - 1.0, torch.zeros_like(slope)).square().mean()
    # Critic scores for policy and behavior actions; the loss widens their gap,
    # with the gradient penalty weighted by w_gamma.
    logits_p = self.critic(state, action_pi)
    logits_b = self.critic(state, action_b)
    logits_diff = logits_p - logits_b
    critic_loss = -logits_diff.mean() + gradient_penalty * self.w_gamma
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()
    return critic_loss.item()
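
The gradient-penalty computation above is self-contained enough to isolate. Below is a minimal, runnable sketch of the same one-sided penalty on interpolated actions; `ToyCritic`, the tensor shapes, and the `EPS` value are assumptions chosen for illustration, not the repository's classes or hyperparameters.

```python
import torch
import torch.nn as nn

EPS = 1e-8  # small constant for numerical stability (assumed value)


class ToyCritic(nn.Module):
    """Minimal state-action critic used only for this sketch."""

    def __init__(self, state_dim, action_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, state, action):
        return self.net(torch.cat([state, action], dim=-1))


def gradient_penalty(critic, state, action_pi, action_b, eps=EPS):
    """One-sided penalty on the critic's gradient norm along interpolated actions."""
    batch_size = state.shape[0]
    alpha = torch.rand(batch_size, 1, device=state.device)
    a_intpl = (action_pi + alpha * (action_b - action_pi)).requires_grad_(True)
    grads = torch.autograd.grad(
        outputs=critic(state, a_intpl).mean(),
        inputs=a_intpl,
        create_graph=True,
    )[0]
    slope = (grads.square().sum(dim=-1) + eps).sqrt()
    # Penalize only gradient norms that exceed 1.
    return torch.clamp(slope - 1.0, min=0.0).square().mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    critic = ToyCritic(state_dim=4, action_dim=2)
    state = torch.randn(8, 4)
    action_pi = torch.randn(8, 2)  # stand-in for policy actions
    action_b = torch.randn(8, 2)   # stand-in for behavior (dataset) actions
    print(gradient_penalty(critic, state, action_pi, action_b).item())
```

The sketch uses `torch.clamp(..., min=0.0)` in place of the `torch.max(..., torch.zeros_like(...))` call above; the two are equivalent for this one-sided penalty.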