def train()

in agents/bc_gan2.py

Runs a fixed number of adversarial behavior-cloning updates on batches drawn from the replay buffer: the actor (generator) is updated to fool a state-action discriminator, the discriminator is updated to separate buffer actions from actor actions, and the target networks are soft-updated each step.

    def train(self, replay_buffer, iterations, batch_size=100):

        self.actor.train()

        for it in range(iterations):
            # Sample a batch of transitions from the replay buffer
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            """
            Generator Training
            """
            new_action = self.actor(state)
            gen_logits = self.discriminator(state, new_action)
            generator_loss = nn.functional.softplus(-gen_logits).mean()

            self.gen_optim.zero_grad()
            generator_loss.backward()
            self.gen_optim.step()

            """
            Discriminator Training
            """
            fake_labels = torch.zeros(state.shape[0], 1, device=self.device)
            real_labels = torch.ones(state.shape[0], 1, device=self.device)

            real_loss = self.adversarial_loss(self.discriminator(state, action), real_labels)
            fake_loss = self.adversarial_loss(self.discriminator(state, new_action.detach()), fake_labels)
            discriminator_loss = real_loss + fake_loss

            self.disc_optim.zero_grad()
            discriminator_loss.backward()
            self.disc_optim.step()

            # Soft (Polyak) update of the target networks
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1. - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1. - self.tau) * target_param.data)

        # Log the losses from the final iteration
        logger.record_tabular('Generator Loss', generator_loss.item())
        logger.record_tabular('Real Loss', real_loss.item())
        logger.record_tabular('Fake Loss', fake_loss.item())
        logger.record_tabular('Discriminator Loss', discriminator_loss.item())
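
The method relies on attributes built elsewhere in the class (actor, actor_target, critic, critic_target, discriminator, the two optimizers, adversarial_loss, tau, device) plus module-level imports of torch, torch.nn as nn, and the project's tabular logger. Below is a minimal sketch of that assumed setup; the class name, network architectures, and hyperparameters are illustrative placeholders, not the repository's actual definitions.

    import copy

    import torch
    import torch.nn as nn


    class StateActionDiscriminator(nn.Module):
        """Scores a (state, action) pair with a single unnormalized logit."""

        def __init__(self, state_dim, action_dim, hidden=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_dim + action_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, hidden), nn.ReLU(),
                nn.Linear(hidden, 1),
            )

        def forward(self, state, action):
            return self.net(torch.cat([state, action], dim=-1))


    class BCGANAgent:  # placeholder name for the agent class in bc_gan2.py
        def __init__(self, state_dim, action_dim, device, lr=3e-4, tau=0.005):
            self.device = device
            self.tau = tau

            # Actor (generator): maps states to actions in [-1, 1].
            self.actor = nn.Sequential(
                nn.Linear(state_dim, 256), nn.ReLU(),
                nn.Linear(256, action_dim), nn.Tanh(),
            ).to(device)
            self.actor_target = copy.deepcopy(self.actor)

            # Critic/target pair is only touched by the soft updates in train().
            self.critic = nn.Linear(state_dim + action_dim, 1).to(device)
            self.critic_target = copy.deepcopy(self.critic)

            self.discriminator = StateActionDiscriminator(state_dim, action_dim).to(device)

            self.gen_optim = torch.optim.Adam(self.actor.parameters(), lr=lr)
            self.disc_optim = torch.optim.Adam(self.discriminator.parameters(), lr=lr)

            # BCE on raw logits matches the 0/1 labels built in train().
            self.adversarial_loss = nn.BCEWithLogitsLoss()

With this setup, a training run reduces to constructing the agent and calling agent.train(replay_buffer, iterations, batch_size), where replay_buffer.sample(batch_size) is expected to return (state, action, next_state, reward, not_done) tensors on the same device.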