def train()

in agents/bc_gan.py


    def train(self, replay_buffer, iterations, batch_size=100):

        for it in range(iterations):
            # Sample a batch of (s, a, s', r, not_done) transitions; only the
            # state-action pairs are used by the adversarial objective.
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            # Draw several noise-perturbed actions per state from the actor
            # (the generator); states are tiled to align with the samples.
            state_repeat, action_samples = self.actor.sample_multiple_actions(state, num_action=5, std=3e-4)
            true_samples = torch.cat([state, action], 1)
            fake_samples = torch.cat([state_repeat, action_samples], 1)

            # One-sided label smoothing: fake targets are hard zeros, while
            # real targets are drawn uniformly from [0.8, 1.0) to keep the
            # discriminator from becoming overconfident on the expert data.
            fake_labels = torch.zeros(fake_samples.size(0), 1, device=self.device)
            real_labels = torch.rand(size=(true_samples.size(0), 1), device=self.device) * (1.0 - 0.80) + 0.80

            # Discriminator update: fake samples are detached so this step
            # does not backpropagate into the actor.
            real_loss = self.adversarial_loss(self.discriminator(true_samples), real_labels)
            fake_loss = self.adversarial_loss(self.discriminator(fake_samples.detach()), fake_labels)
            discriminator_loss = (real_loss + fake_loss) / 2

            self.discriminator_optimizer.zero_grad()
            discriminator_loss.backward()
            self.discriminator_optimizer.step()

            # Generator (actor) update every g_iter iterations: push the
            # freshly updated discriminator's score on fake samples toward 1.
            if it % self.g_iter == 0:
                generator_loss = self.adversarial_loss(self.discriminator(fake_samples),
                                                       torch.ones(fake_samples.size(0), 1, device=self.device))
                self.actor_optimizer.zero_grad()
                generator_loss.backward()
                self.actor_optimizer.step()

        # Record the losses from the final iteration of this call.
        logger.record_tabular('Generator Loss', generator_loss.item())
        logger.record_tabular('Discriminator Loss', discriminator_loss.item())
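
The method relies on several attributes defined elsewhere in agents/bc_gan.py: self.actor.sample_multiple_actions, self.discriminator, self.adversarial_loss, the two optimizers, and self.g_iter. Below is a minimal sketch of how those pieces could fit together, inferred purely from the call sites above; the network architectures, hidden sizes, and hyperparameters are assumptions, and the actual classes in the repository may differ.

    import torch
    import torch.nn as nn

    class Actor(nn.Module):
        """Hypothetical generator; only the sample_multiple_actions
        contract is inferred from the train() excerpt above."""

        def __init__(self, state_dim, action_dim, max_action=1.0):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_dim, 256), nn.ReLU(),
                nn.Linear(256, 256), nn.ReLU(),
                nn.Linear(256, action_dim), nn.Tanh(),
            )
            self.max_action = max_action

        def sample_multiple_actions(self, state, num_action=5, std=3e-4):
            # Tile each state num_action times: (B, s) -> (B * num_action, s).
            state_repeat = state.repeat_interleave(num_action, dim=0)
            action = self.max_action * self.net(state_repeat)
            # A small Gaussian perturbation yields num_action distinct
            # action samples per state.
            return state_repeat, action + std * torch.randn_like(action)

    class Discriminator(nn.Module):
        """Hypothetical critic over concatenated (state, action) pairs,
        ending in a sigmoid so outputs read as probabilities."""

        def __init__(self, state_dim, action_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
                nn.Linear(256, 256), nn.ReLU(),
                nn.Linear(256, 1), nn.Sigmoid(),
            )

        def forward(self, state_action):
            return self.net(state_action)

With a sigmoid output head, self.adversarial_loss would be nn.BCELoss(), which accepts the soft real labels in [0.8, 1.0); this one-sided label smoothing is a standard GAN stabilization trick that keeps the discriminator from saturating on the expert data while fake samples keep hard zero targets.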