def train()

in agents/bc_kl.py
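
train() interleaves two updates per sampled minibatch: it fits a conditional variational auto-encoder to the buffer's (state, action) pairs as a generative behavior-cloning model, then updates the actor to reduce a KL-based loss against actions decoded from that VAE. Module-level imports of torch, torch.nn.functional as F, and the tabular logger are assumed but not shown in this excerpt.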


    def train(self, replay_buffer, iterations, batch_size=100):

        for it in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            # Variational auto-encoder update: behavior-clone the buffer's
            # action distribution with a reconstruction loss plus the Gaussian
            # KL(q(z|s,a) || N(0, I)) penalty, weighted by 0.5.
            recon, mean, std = self.vae(state, action)
            recon_loss = F.mse_loss(recon, action)
            KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
            vae_loss = recon_loss + 0.5 * KL_loss

            self.vae_optimizer.zero_grad()
            vae_loss.backward()
            self.vae_optimizer.step()

            # Decode several candidate actions per state from the behavior VAE;
            # the raw (pre-squash) samples feed the KL term below.
            num_samples = self.num_samples_match
            sampled_actions, raw_sampled_actions = self.vae.decode_multiple(state, num_decode=num_samples)  # B x N x d
            # actor_actions, raw_actor_actions = self.actor.sample_multiple(state, num_sample=num_samples)

            # Actor update: minimize the KL-based loss between the actor's
            # policy and the sampled behavior actions.
            kl_loss = self.kl_loss(raw_sampled_actions, state).mean()
            self.actor_optimizer.zero_grad()
            kl_loss.backward()
            self.actor_optimizer.step()

        # Only the final iteration's losses are recorded; 'KL Loss' is the
        # actor objective, not the VAE's KL_loss regularizer.
        logger.record_tabular('VAE Loss', vae_loss.cpu().data.numpy())
        logger.record_tabular('KL Loss', kl_loss.cpu().data.numpy())
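
The kl_loss and decode_multiple helpers are not shown in this file excerpt. The two sketches below illustrate one plausible shape for each, for readers tracing the data flow; every name in them (latent_dim, decoder, log_pis) is an assumption, not a confirmed part of this agent's API.

A minimal sketch of decode_multiple, assuming a tanh-squashed decoder that returns both the squashed actions and their raw pre-tanh values:

    def decode_multiple(self, state, num_decode=10):
        B = state.size(0)
        # One clipped latent per (state, sample) pair, as is common for VAE policies.
        z = torch.randn(B, num_decode, self.latent_dim, device=state.device).clamp(-0.5, 0.5)
        state_rep = state.unsqueeze(1).repeat(1, num_decode, 1)
        raw_actions = self.decoder(torch.cat([state_rep, z], dim=-1))  # B x N x d
        return torch.tanh(raw_actions), raw_actions

A minimal sketch of kl_loss under the same assumptions: scoring the behavior samples under the actor's log-density is a Monte-Carlo estimate of the forward KL between the behavior policy and the actor, up to an additive constant that does not depend on the actor's parameters, which is why minimizing it pulls the actor toward the data:

    def kl_loss(self, raw_sampled_actions, state):
        B, N, d = raw_sampled_actions.shape
        # Pair each state with each of its N behavior samples.
        state_rep = state.unsqueeze(1).repeat(1, N, 1).view(B * N, -1)
        raw_flat = raw_sampled_actions.reshape(B * N, d)
        # Log-density of the behavior samples under the current actor
        # (log_pis is an assumed actor method taking pre-tanh actions).
        log_pis = self.actor.log_pis(state=state_rep, raw_action=raw_flat).view(B, N)
        # Negative log-likelihood per state; the .mean() in train() averages over the batch.
        return -log_pis.mean(1)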