in agents/bc_cvae.py [0:0]
def train(self, replay_buffer, iterations, batch_size=100):
    # Assumes module-level imports: torch, torch.nn.functional as F,
    # and a project logger.
    for it in range(iterations):
        # Sample a mini-batch of transitions from the replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        # Variational auto-encoder training: reconstruct the action
        # conditioned on the state
        recon, mean, std = self.vae(state, action)
        recon_loss = F.mse_loss(recon, action)
        # KL divergence between the diagonal Gaussian posterior
        # N(mean, std^2) and the standard normal prior
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()

    # Log the VAE loss from the final iteration
    logger.record_tabular('VAE Loss', vae_loss.cpu().data.numpy())
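
For context, here is a minimal sketch of the conditional VAE this loop assumes: the only contract taken from the code above is that self.vae(state, action) returns (recon, mean, std). The sketch follows the BCQ-style state-conditioned VAE; the class name VAE, the 750-unit hidden layers, the log-std clamp range, and max_action are illustrative assumptions, not taken from this file.

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """Hypothetical conditional VAE: reconstructs an action given a state.

    Layer sizes, latent_dim, and max_action are assumptions; only the
    (recon, mean, std) return signature matches the training loop above.
    """
    def __init__(self, state_dim, action_dim, latent_dim, max_action):
        super().__init__()
        # Encoder: (state, action) -> Gaussian parameters over the latent
        self.e1 = nn.Linear(state_dim + action_dim, 750)
        self.e2 = nn.Linear(750, 750)
        self.mean = nn.Linear(750, latent_dim)
        self.log_std = nn.Linear(750, latent_dim)
        # Decoder: (state, latent) -> reconstructed action
        self.d1 = nn.Linear(state_dim + latent_dim, 750)
        self.d2 = nn.Linear(750, 750)
        self.d3 = nn.Linear(750, action_dim)
        self.max_action = max_action
        self.latent_dim = latent_dim

    def forward(self, state, action):
        h = F.relu(self.e1(torch.cat([state, action], dim=1)))
        h = F.relu(self.e2(h))
        mean = self.mean(h)
        # Clamp log-std for numerical stability before exponentiating
        std = torch.exp(self.log_std(h).clamp(-4, 15))
        # Reparameterization trick: z = mean + std * eps, eps ~ N(0, I)
        z = mean + std * torch.randn_like(std)
        return self.decode(state, z), mean, std

    def decode(self, state, z):
        h = F.relu(self.d1(torch.cat([state, z], dim=1)))
        h = F.relu(self.d2(h))
        # Squash to the valid action range
        return self.max_action * torch.tanh(self.d3(h))

Under these assumptions, setting self.vae = VAE(state_dim, action_dim, latent_dim=2 * action_dim, max_action=max_action) and self.vae_optimizer = torch.optim.Adam(self.vae.parameters()) would be enough for the train loop above to run as written.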