def train()

in agents/adw_bc_diffusion.py [0:0]


    def train(self, replay_buffer, iterations, batch_size=100):
        """Run `iterations` gradient steps on minibatches drawn from `replay_buffer`."""

        for it in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            # Value Function Training (expectile regression toward the clipped target-critic Q)
            with torch.no_grad():
                q1, q2 = self.critic_target(state, action)
                q = torch.min(q1, q2)  # Clipped Double Q-learning
            v = self.value_fun(state)
            value_loss = expectile_reg_loss(q - v, self.quantile).mean()
            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()

            # Critic Training (TD target bootstraps from the value function rather than the target critic)
            current_q1, current_q2 = self.critic(state, action)
            target_q = (reward + not_done * self.discount * self.value_fun(next_state)).detach()
            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Policy Training (advantage-weighted behavior cloning)
            v = self.value_fun(state)
            # Exponential advantage weight, normalized by the batch-mean |Q| and clamped
            # for numerical stability; detached so no gradients flow back through Q or V.
            weight = torch.exp((q - v) / q.abs().mean() * self.temp).clamp_max(100.0).detach()
            loss = self.actor.loss(action, state, weight)

            self.actor_optimizer.zero_grad()
            loss.backward()
            self.actor_optimizer.step()

            # Soft-update the target critic network (Polyak averaging)
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        # Logging (losses from the final iteration)
        logger.record_tabular('ADW_BC Loss', loss.item())
        logger.record_tabular('Critic Loss', critic_loss.item())
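
expectile_reg_loss is referenced above but not defined in this excerpt. A minimal sketch, assuming it is the standard asymmetric squared error used for expectile regression, with self.quantile passed in as the expectile parameter:

    import torch

    def expectile_reg_loss(diff, expectile=0.7):
        # Asymmetric L2 loss: positive residuals (q > v) are weighted by `expectile`,
        # negative residuals by (1 - expectile).
        weight = torch.abs(expectile - (diff < 0).float())
        return weight * diff.pow(2)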
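
self.actor.loss(action, state, weight) is the advantage-weighted behavior-cloning objective of the diffusion policy. The sketch below is a hypothetical illustration of such a weighted denoising loss, not the repository's implementation; the noise-prediction network (model), the precomputed schedule tensors, and the timestep count are all assumptions:

    import torch
    import torch.nn.functional as F

    def weighted_diffusion_bc_loss(model, action, state, weight,
                                   sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
                                   n_timesteps):
        # Sample a diffusion timestep and Gaussian noise for each transition
        # (schedule tensors are hypothetical precomputed 1-D buffers of length n_timesteps).
        t = torch.randint(0, n_timesteps, (action.shape[0],), device=action.device)
        noise = torch.randn_like(action)

        # Forward-diffuse the dataset action to timestep t.
        noisy_action = (sqrt_alphas_cumprod[t].unsqueeze(-1) * action
                        + sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1) * noise)

        # Predict the injected noise, conditioned on the state.
        pred_noise = model(noisy_action, t, state)

        # Per-sample denoising error, scaled by the (detached) advantage weight.
        per_sample = F.mse_loss(pred_noise, noise, reduction='none').mean(dim=-1)
        return (weight.squeeze(-1) * per_sample).mean()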
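
train() only requires replay_buffer to expose sample(batch_size) returning (state, action, next_state, reward, not_done) tensors. A minimal sketch of such a buffer backed by NumPy arrays; the class name and constructor are assumptions, not the repository's data loader:

    import numpy as np
    import torch

    class SimpleReplayBuffer:
        # Minimal offline buffer; only the sample() interface used by train() above is provided.
        def __init__(self, states, actions, next_states, rewards, not_dones, device='cpu'):
            self.states = states
            self.actions = actions
            self.next_states = next_states
            self.rewards = rewards.reshape(-1, 1)
            self.not_dones = not_dones.reshape(-1, 1)
            self.size = len(states)
            self.device = device

        def sample(self, batch_size):
            idx = np.random.randint(0, self.size, size=batch_size)

            def as_tensor(x):
                return torch.as_tensor(x[idx], dtype=torch.float32, device=self.device)

            return (as_tensor(self.states), as_tensor(self.actions),
                    as_tensor(self.next_states), as_tensor(self.rewards),
                    as_tensor(self.not_dones))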