in agents/ql_cvae.py [0:0]
def train(self, replay_buffer, iterations, batch_size=100):
    # Assumes module-level imports (not shown in this excerpt):
    # numpy as np, torch, and torch.nn.functional as F.
    for it in range(iterations):
        # Sample a training batch from the replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        """ Q Training """
        current_q1, current_q2 = self.critic(state, action)

        if not self.max_q_backup:
            # Clipped double-Q target: one next action sampled from the EMA actor,
            # evaluated by both target critics, then the element-wise minimum.
            next_action = self.ema_model.sample(next_state)
            target_q1, target_q2 = self.critic_target(next_state, next_action)
            target_q = torch.min(target_q1, target_q2)
        else:
            # Max-Q backup: sample 10 candidate next actions per state, take the
            # per-state max of each target critic, then the min across the two heads
            # (a standalone sketch of this computation follows the excerpt).
            next_state_rpt = torch.repeat_interleave(next_state, repeats=10, dim=0)
            next_action_rpt = self.ema_model.sample(next_state_rpt)
            target_q1, target_q2 = self.critic_target(next_state_rpt, next_action_rpt)
            target_q1 = target_q1.view(batch_size, 10).max(dim=1, keepdim=True)[0]
            target_q2 = target_q2.view(batch_size, 10).max(dim=1, keepdim=True)[0]
            target_q = torch.min(target_q1, target_q2)

        # One-step TD target; detach so no gradient flows into the target networks
        target_q = (reward + not_done * self.discount * target_q).detach()

        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        """ Variational Auto-Encoder (Actor) Training """
        # Conditional VAE loss: action reconstruction plus a 0.5-weighted KL term
        # (a numeric sketch of these terms follows the excerpt).
        recon, mean, std = self.actor(state, action)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        # Q-guided term: push newly sampled actions toward high Q-values. One critic
        # head is chosen at random; lmbda normalizes the scale by the other head's
        # mean absolute Q-value (detached, so it acts as a constant).
        new_action = self.actor.sample(state)
        q1_new_action, q2_new_action = self.critic(state, new_action)
        if np.random.uniform() > 0.5:
            lmbda = self.eta / q2_new_action.abs().mean().detach()
            q_loss = -lmbda * q1_new_action.mean()
        else:
            lmbda = self.eta / q1_new_action.abs().mean().detach()
            q_loss = -lmbda * q2_new_action.mean()

        self.actor_optimizer.zero_grad()
        vae_loss.backward()  # gradients of both losses accumulate before the step
        q_loss.backward()
        self.actor_optimizer.step()

        """ Target Network Updates """
        # Periodically refresh the EMA copy of the actor
        if self.step % self.update_ema_every == 0:
            self.step_ema()

        # Polyak averaging of the critic target network
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        self.step += 1
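
The max_q_backup branch above is the subtlest part of the critic update: each next state is repeated 10 times, a candidate action is sampled for every copy, and the reshape to (batch_size, 10) relies on repeat_interleave laying the copies out contiguously per state. The minimal, self-contained sketch below reproduces just that shape logic so it can be run in isolation; DummyCritic, max_q_backup_target, and the random stand-in for ema_model.sample are illustrative assumptions, not part of agents/ql_cvae.py.

# Standalone sketch of the max-Q-backup target computation; all names here are
# illustrative stand-ins, not the classes used in agents/ql_cvae.py.
import torch
import torch.nn as nn


class DummyCritic(nn.Module):
    """Stand-in twin critic returning two Q estimates per (state, action) pair."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.q1 = nn.Linear(state_dim + action_dim, 1)
        self.q2 = nn.Linear(state_dim + action_dim, 1)

    def forward(self, state, action):
        sa = torch.cat([state, action], dim=1)
        return self.q1(sa), self.q2(sa)


def max_q_backup_target(critic_target, sample_action, next_state, batch_size, repeats=10):
    # Repeat each next state `repeats` times, sample an action for every copy,
    # take the per-state max of each critic head, then the min across the heads.
    next_state_rpt = torch.repeat_interleave(next_state, repeats=repeats, dim=0)
    next_action_rpt = sample_action(next_state_rpt)
    q1, q2 = critic_target(next_state_rpt, next_action_rpt)
    q1 = q1.view(batch_size, repeats).max(dim=1, keepdim=True)[0]
    q2 = q2.view(batch_size, repeats).max(dim=1, keepdim=True)[0]
    return torch.min(q1, q2)  # shape: (batch_size, 1)


if __name__ == "__main__":
    state_dim, action_dim, batch_size = 17, 6, 100
    critic_target = DummyCritic(state_dim, action_dim)
    next_state = torch.randn(batch_size, state_dim)
    # Stand-in for self.ema_model.sample: random actions of the right shape.
    sample_action = lambda s: torch.randn(s.shape[0], action_dim)
    target_q = max_q_backup_target(critic_target, sample_action, next_state, batch_size)
    print(target_q.shape)  # torch.Size([100, 1])

The per-state max over the 10 candidates approximates max_a Q(s', a) before the two heads are clipped with the min, which is what distinguishes this branch from the single-sample target in the other branch.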
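
The actor update combines the CVAE reconstruction and KL terms with the Q-guided term. The short sketch below evaluates just the reconstruction and KL pieces on dummy tensors, using the same 0.5 weight on the KL term as in the method above; the tensor shapes and the latent_dim choice are illustrative assumptions, not values taken from this file.

# Numeric sketch of the CVAE loss terms used in the actor update above.
import torch
import torch.nn.functional as F

batch_size, action_dim = 100, 6
latent_dim = 2 * action_dim  # assumed latent size; not specified in this excerpt

action = torch.rand(batch_size, action_dim) * 2 - 1         # actions in [-1, 1]
recon = torch.tanh(torch.randn(batch_size, action_dim))     # stand-in decoder output
mean = torch.randn(batch_size, latent_dim)                  # stand-in posterior mean
std = torch.exp(0.5 * torch.randn(batch_size, latent_dim))  # stand-in posterior std (> 0)

recon_loss = F.mse_loss(recon, action)
# Per-dimension KL between N(mean, std^2) and N(0, 1), averaged over batch and latent dims
KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
vae_loss = recon_loss + 0.5 * KL_loss
print(recon_loss.item(), KL_loss.item(), vae_loss.item())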