def train()

in agents/ql_diffusion.py [0:0]
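
train() performs iterations gradient steps over a fixed offline dataset: a TD update for the twin critics (using an EMA copy of the diffusion actor as the target policy), an actor update that combines the diffusion behavior-cloning loss with a Q-maximization term, and finally EMA and soft target-network updates. The only interface it requires of replay_buffer is a sample(batch_size) method returning five tensors. Below is a minimal sketch of such a buffer for pre-loaded offline data; the class name, shapes, and device handling are illustrative assumptions, not the project's actual buffer:

    import torch

    class SimpleReplayBuffer:
        """Illustrative offline buffer; expects transitions pre-loaded as array-likes."""

        def __init__(self, state, action, next_state, reward, not_done, device='cpu'):
            self.device = device
            self.state = torch.as_tensor(state, dtype=torch.float32)
            self.action = torch.as_tensor(action, dtype=torch.float32)
            self.next_state = torch.as_tensor(next_state, dtype=torch.float32)
            # Keep reward / not_done as column vectors so the TD target broadcasts per sample.
            self.reward = torch.as_tensor(reward, dtype=torch.float32).reshape(-1, 1)
            self.not_done = torch.as_tensor(not_done, dtype=torch.float32).reshape(-1, 1)

        def sample(self, batch_size):
            idx = torch.randint(0, self.state.shape[0], (batch_size,))
            return (self.state[idx].to(self.device),
                    self.action[idx].to(self.device),
                    self.next_state[idx].to(self.device),
                    self.reward[idx].to(self.device),
                    self.not_done[idx].to(self.device))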


    def train(self, replay_buffer, iterations, batch_size=100):

        for _ in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
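            # Assumed shapes: state/next_state (B, state_dim), action (B, action_dim),
            # reward and not_done (B, 1), so the TD target below broadcasts per sample.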

            """ Q Training """
            current_q1, current_q2 = self.critic(state, action)

            if not self.max_q_backup:
                next_action = self.ema_model(next_state)
                target_q1, target_q2 = self.critic_target(next_state, next_action)
                target_q = torch.min(target_q1, target_q2)
            else:
                next_state_rpt = torch.repeat_interleave(next_state, repeats=10, dim=0)
                next_action_rpt = self.ema_model(next_state_rpt)
                target_q1, target_q2 = self.critic_target(next_state_rpt, next_action_rpt)
                target_q1 = target_q1.view(batch_size, 10).max(dim=1, keepdim=True)[0]
                target_q2 = target_q2.view(batch_size, 10).max(dim=1, keepdim=True)[0]
                target_q = torch.min(target_q1, target_q2)
            target_q = (reward + not_done * self.discount * target_q).detach()

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            """ Policy Training """
            bc_loss = self.actor.loss(action, state)

            if self.mode == 'whole_grad':
                new_action = self.actor(state)
            elif self.mode == 'last_few':
                new_action = self.actor.sample_last_few(state)
            else:
                raise ValueError(f"Unsupported mode: {self.mode}")

            q1_new_action, q2_new_action = self.critic(state, new_action)
            if np.random.uniform() > 0.5:
                lmbda = self.eta / q2_new_action.abs().mean().detach()
                q_loss = - lmbda * q1_new_action.mean()
            else:
                lmbda = self.eta / q1_new_action.abs().mean().detach()
                q_loss = - lmbda * q2_new_action.mean()
            # q_new_action = self.critic.q_min(state, new_action)
            # lmbda = self.eta / q_new_action.abs().mean().detach()
            # q_loss = - lmbda * q_new_action.mean()
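            # eta sets the Q-maximization weight relative to the BC loss; dividing by the
            # detached |Q| of the other head keeps that trade-off roughly scale-invariant,
            # and the head that receives gradient is chosen at random each step. The
            # commented-out block above is the non-randomized variant using q_min.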

            self.actor_optimizer.zero_grad()
            bc_loss.backward()
            q_loss.backward()
            self.actor_optimizer.step()
            self.actor.step_frozen()

            if self.step % self.update_ema_every == 0:
                self.step_ema()

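            # Polyak (soft) update of the target critics toward the online critics.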
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.step += 1
            if self.lr_decay:
                self.actor_lr_scheduler.step()
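
self.ema_model and self.step_ema() are defined elsewhere in the class. The snippet below sketches the kind of exponential-moving-average update step_ema() is assumed to perform; the function name and the decay value are illustrative, not taken from the repository:

    import torch

    @torch.no_grad()
    def ema_update(ema_model, model, decay=0.995):
        # ema <- decay * ema + (1 - decay) * online, parameter by parameter
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1.0 - decay)

Under this reading, update_ema_every throttles how often the EMA copy is refreshed, and using the EMA actor rather than the online actor to produce next actions for the TD target gives a slower-moving, more stable target policy.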