def train()

in agents/ed_pcq.py
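
Per-iteration update loop for the ed_pcq agent. Each gradient step samples a batch from the replay buffer, regresses the Q ensemble toward a TD target (optionally a max-Q backup over repeated next-state samples), adds an ensemble-diversification penalty on the members' action-gradients when q_eta > 0, and updates the actor with a behavior-cloning loss plus a Q-guidance term normalized by the current Q magnitude. The EMA copy of the actor and the target critic are updated on their own schedules, and the losses from the final step are logged. The listing assumes the module-level imports of torch, torch.nn.functional as F, and the project's logger, which are not shown here.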


    def train(self, replay_buffer, iterations, batch_size=100):
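        """Run the requested number of critic/actor update steps on batches drawn from replay_buffer."""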

        for step in range(iterations):
            # Sample a batch of transitions from the replay buffer
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            """ Q Training """
            current_qs = self.critic(state, action)
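            # current_qs: one Q estimate per ensemble member for the sampled batch.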

            if not self.max_q_backup:
                # Standard backup: bootstrap from a single next action sampled by the
                # EMA copy of the actor.
                next_action = self.ema_model(next_state)
                target_q = self.critic_target.sample(next_state, next_action)
            else:
                # Max-Q backup: repeat each next state 10 times, re-sample an action for
                # each copy, and back up the maximum target Q over those samples.
                next_state_rpt = torch.repeat_interleave(next_state, repeats=10, dim=0)
                next_action_rpt = self.ema_model(next_state_rpt)
                target_q = self.critic_target.sample(next_state_rpt, next_action_rpt)
                target_q = target_q.view(batch_size, 10).max(dim=1, keepdim=True)[0]

            # Broadcast the shared TD target across the ensemble dimension of current_qs.
            target_q = (reward + not_done * self.discount * target_q).detach().unsqueeze(0)

            # Per-member squared error, averaged over the batch and summed across members.
            critic_loss = F.mse_loss(current_qs, target_q, reduction='none')
            critic_loss = critic_loss.mean(dim=(1, 2)).sum()

            if self.q_eta > 0:
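                # Ensemble-diversification (ED) penalty: take each member's Q-gradient with
                # respect to the action, L2-normalize it, and penalize pairwise similarity so
                # the members' action-gradients stay diverse.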
                state_tile = state.unsqueeze(0).repeat(self.num_qs, 1, 1)
                action_tile = action.unsqueeze(0).repeat(self.num_qs, 1, 1).requires_grad_(True)
                qs_preds_tile = self.critic(state_tile, action_tile)
                qs_pred_grads, = torch.autograd.grad(qs_preds_tile.sum(), action_tile, retain_graph=True,
                                                     create_graph=True)
                qs_pred_grads = qs_pred_grads / (torch.norm(qs_pred_grads, p=2, dim=2).unsqueeze(-1) + 1e-10)
                qs_pred_grads = qs_pred_grads.transpose(0, 1)

                # Pairwise inner products (cosine similarities, since the gradients were
                # normalized) between members; zero the diagonal so a member is not
                # compared with itself, then reduce the off-diagonal similarity into grad_loss.
                qs_pred_grads = torch.einsum('bik,bjk->bij', qs_pred_grads, qs_pred_grads)
                masks = torch.eye(self.num_qs, device=self.device).unsqueeze(dim=0).repeat(qs_pred_grads.size(0), 1, 1)
                qs_pred_grads = (1 - masks) * qs_pred_grads
                grad_loss = torch.mean(torch.sum(qs_pred_grads, dim=(1, 2))) / (self.num_qs - 1)

                critic_loss += self.q_eta * grad_loss

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            """ Policy Training """
            # Behavior-cloning loss from the actor on the dataset actions; below it is paired
            # with a Q-guidance term whose weight lmbda is normalized by the current Q scale.
            bc_loss = self.actor.loss(action, state)

            new_action = self.actor(state)
            q_new_action = self.critic.sample(state, new_action)
            lmbda = self.eta / q_new_action.abs().mean().detach()
            q_loss = - lmbda * q_new_action.mean()

            # The two backward calls accumulate gradients on the actor parameters, which
            # is equivalent to backpropagating bc_loss + q_loss in one pass.
            self.actor_optimizer.zero_grad()
            bc_loss.backward()
            q_loss.backward()
            self.actor_optimizer.step()

            # Periodically refresh the EMA copy of the actor used for target actions.
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            # Soft (Polyak) update of the target critic.
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.step += 1

        # Logging (values from the last iteration of this call)
        logger.record_tabular('BC Loss', bc_loss.item())
        logger.record_tabular('QL Loss', q_loss.item())
        logger.record_tabular('Critic Loss', critic_loss.item())
        if self.q_eta > 0:
            logger.record_tabular('ED Loss', grad_loss.item())
        logger.record_tabular('Target_Q Mean', target_q.mean().item())
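
train() only touches the buffer through sample(batch_size), which must return (state, action, next_state, reward, not_done) tensors in that order. Below is a minimal sketch of a buffer satisfying that interface; it is inferred from how the batch is unpacked above, and the class name, storage layout, and device handling are illustrative assumptions rather than the repository's actual replay buffer.

    import numpy as np
    import torch


    class SimpleReplayBuffer:
        """Minimal buffer exposing the sample() interface used by train() above."""

        def __init__(self, state_dim, action_dim, max_size, device='cpu'):
            self.device = device
            self.max_size = max_size
            self.ptr, self.size = 0, 0
            self.state = np.zeros((max_size, state_dim), dtype=np.float32)
            self.action = np.zeros((max_size, action_dim), dtype=np.float32)
            self.next_state = np.zeros((max_size, state_dim), dtype=np.float32)
            self.reward = np.zeros((max_size, 1), dtype=np.float32)
            self.not_done = np.zeros((max_size, 1), dtype=np.float32)

        def add(self, state, action, next_state, reward, done):
            # Store one transition; not_done = 1 - done matches the TD target above.
            self.state[self.ptr] = state
            self.action[self.ptr] = action
            self.next_state[self.ptr] = next_state
            self.reward[self.ptr] = reward
            self.not_done[self.ptr] = 1.0 - float(done)
            self.ptr = (self.ptr + 1) % self.max_size
            self.size = min(self.size + 1, self.max_size)

        def sample(self, batch_size):
            # Uniformly sample a batch and return it as float tensors on the agent's device.
            idx = np.random.randint(0, self.size, size=batch_size)

            def to_tensor(arr):
                return torch.as_tensor(arr[idx], device=self.device)

            return (to_tensor(self.state), to_tensor(self.action), to_tensor(self.next_state),
                    to_tensor(self.reward), to_tensor(self.not_done))

A driver would fill such a buffer (or load an offline dataset into it), call agent.train(buffer, iterations, batch_size) once per epoch, and then flush the tabular logger with whatever dump call the project's logger provides.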