in agents/bc_gan2.py
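# Note: this excerpt assumes module-level imports of torch, torch.nn (as nn),
# and the project's logger object, which are not shown here.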
def train(self, replay_buffer, iterations, batch_size=100):
    self.actor.train()
    for it in range(iterations):
        # Sample a batch of transitions from the replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
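        # GAN-style behavioral cloning: the actor plays the role of the generator,
        # producing actions for dataset states, while the discriminator scores
        # (state, action) pairs as dataset ("real") vs. actor-generated ("fake").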
"""
Generator Training
"""
new_action = self.actor(state)
gen_logits = self.discriminator(state, new_action)
generator_loss = nn.functional.softplus(-gen_logits).mean()
self.gen_optim.zero_grad()
generator_loss.backward()
self.gen_optim.step()
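        # Note: softplus(-gen_logits) equals -log(sigmoid(gen_logits)), i.e. the
        # non-saturating generator objective (binary cross-entropy against the
        # "real" label), which trains the actor to produce actions the
        # discriminator scores as real.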
"""
Discriminator Training
"""
fake_labels = torch.zeros(state.shape[0], 1, device=self.device)
real_labels = torch.ones(state.shape[0], 1, device=self.device)
real_loss = self.adversarial_loss(self.discriminator(state, action), real_labels)
fake_loss = self.adversarial_loss(self.discriminator(state, new_action.detach()), fake_labels)
discriminator_loss = real_loss + fake_loss
self.disc_optim.zero_grad()
discriminator_loss.backward()
self.disc_optim.step()
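        # new_action is detached so the discriminator update does not backpropagate
        # into the actor. adversarial_loss is presumably nn.BCEWithLogitsLoss (or
        # similar), since raw discriminator logits are compared against 0/1 labels.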
        # Update Target Networks
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
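        # Soft (Polyak) target updates: target <- tau * online + (1 - tau) * target.
        # Note that the critic itself is not optimized in this method, so its
        # target simply tracks the current critic weights.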
    # Logging (values from the final iteration of this call)
    logger.record_tabular('Generator Loss', generator_loss.item())
    logger.record_tabular('Real Loss', real_loss.item())
    logger.record_tabular('Fake Loss', fake_loss.item())
    logger.record_tabular('Discriminator Loss', discriminator_loss.item())