in agents/bc_gan.py [0:0]
def train(self, replay_buffer, iterations, batch_size=100):
    for it in range(iterations):
        # Sample a minibatch of transitions from the offline replay buffer.
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        # Generator side: draw several slightly perturbed actions per state from the actor.
        state_repeat, action_samples = self.actor.sample_multiple_actions(state, num_action=5, std=3e-4)

        # Real pairs come from the dataset; fake pairs come from the actor.
        true_samples = torch.cat([state, action], 1)
        fake_samples = torch.cat([state_repeat, action_samples], 1)

        # Fake labels are hard zeros; real labels use one-sided label smoothing in [0.8, 1.0].
        fake_labels = torch.zeros(fake_samples.size(0), 1, device=self.device)
        real_labels = torch.rand(size=(true_samples.size(0), 1), device=self.device) * (1.0 - 0.80) + 0.80

        # Discriminator update: dataset pairs are real, detached actor pairs are fake.
        real_loss = self.adversarial_loss(self.discriminator(true_samples), real_labels)
        fake_loss = self.adversarial_loss(self.discriminator(fake_samples.detach()), fake_labels)
        discriminator_loss = (real_loss + fake_loss) / 2

        self.discriminator_optimizer.zero_grad()
        discriminator_loss.backward()
        self.discriminator_optimizer.step()

        # Generator (actor) update every self.g_iter discriminator steps:
        # push the discriminator towards labelling actor-generated pairs as real.
        if it % self.g_iter == 0:
            generator_loss = self.adversarial_loss(
                self.discriminator(fake_samples),
                torch.ones(fake_samples.size(0), 1, device=self.device))

            self.actor_optimizer.zero_grad()
            generator_loss.backward()
            self.actor_optimizer.step()
    # Log the most recent losses from this training call.
    logger.record_tabular('Generator Loss', generator_loss.detach().cpu().numpy())
    logger.record_tabular('Discriminator Loss', discriminator_loss.detach().cpu().numpy())
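
The actor's sample_multiple_actions helper is not shown in this excerpt. The sketch below is a minimal, hypothetical illustration of the contract the training loop above assumes (repeat each state num_action times and return the actor's actions perturbed by small Gaussian noise); it is not the repository's actual implementation, and the network shape and max_action clipping are assumptions.

# Hypothetical sketch of the assumed sample_multiple_actions contract (not from the original file).
import torch
import torch.nn as nn


class GaussianSamplingActor(nn.Module):
    """Illustrative actor exposing the sample_multiple_actions API used above."""

    def __init__(self, state_dim, action_dim, max_action=1.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, action_dim), nn.Tanh(),
        )
        self.max_action = max_action

    def forward(self, state):
        return self.max_action * self.net(state)

    def sample_multiple_actions(self, state, num_action=5, std=3e-4):
        # Repeat each state num_action times: (B, S) -> (B * num_action, S).
        state_repeat = state.repeat_interleave(num_action, dim=0)
        # Deterministic action plus small Gaussian noise, clipped to the valid range.
        action = self.forward(state_repeat)
        noise = torch.randn_like(action) * std
        return state_repeat, (action + noise).clamp(-self.max_action, self.max_action)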