in agents/bc_kl.py
def train(self, replay_buffer, iterations, batch_size=100):
    for it in range(iterations):
        # Sample a batch of transitions from the replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        # Variational auto-encoder training: reconstruct actions and regularize
        # the latent code with a standard-Gaussian KL term
        recon, mean, std = self.vae(state, action)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()

        # Actor training: draw multiple actions per state from the VAE decoder
        # and minimize the KL-based objective against those behavior samples
        num_samples = self.num_samples_match
        sampled_actions, raw_sampled_actions = self.vae.decode_multiple(state, num_decode=num_samples)  # B x N x d
        # actor_actions, raw_actor_actions = self.actor.sample_multiple(state, num_sample=num_samples)
        kl_loss = self.kl_loss(raw_sampled_actions, state).mean()

        self.actor_optimizer.zero_grad()
        kl_loss.backward()
        self.actor_optimizer.step()

        logger.record_tabular('VAE Loss', vae_loss.cpu().data.numpy())
        logger.record_tabular('KL Loss', kl_loss.cpu().data.numpy())
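As a sanity check (not part of the repository), the closed-form expression on the KL_loss line equals KL(N(mean, std^2) || N(0, 1)) averaged over all elements. A small standalone snippet verifying this against torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(100, 4)
std = torch.rand(100, 4) + 0.1

# Closed form used in the VAE loss above, averaged over all elements
closed_form = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
# Reference value: KL between N(mean, std^2) and the standard normal prior
reference = kl_divergence(Normal(mean, std),
                          Normal(torch.zeros_like(mean), torch.ones_like(std))).mean()

assert torch.allclose(closed_form, reference, atol=1e-5)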