agents/ql_cvae.py

# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from agents.helpers import EMA
from torch.distributions import Distribution, Normal

LOG_SIG_MAX = 2
LOG_SIG_MIN = -20


# Vanilla Variational Auto-Encoder
class actor(nn.Module):
    def __init__(self, state_dim, action_dim, latent_dim, max_action, device, hidden_dim=256):
        super(actor, self).__init__()
        # Encoder: (state, action) -> latent Gaussian parameters
        self.e1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.e2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)

        # Decoder: (state, latent) -> action
        self.d1 = nn.Linear(state_dim + latent_dim, hidden_dim)
        self.d2 = nn.Linear(hidden_dim, hidden_dim)
        self.d3 = nn.Linear(hidden_dim, action_dim)

        self.max_action = max_action
        self.latent_dim = latent_dim
        self.device = device

    def forward(self, state, action):
        z = F.relu(self.e1(torch.cat([state, action], 1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # Clamped for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        # Reparameterization trick
        z = mean + std * torch.randn_like(std)

        u = self.decode(state, z)

        return u, mean, std

    def decode(self, state, z=None):
        # When sampling from the actor, the latent vector is clipped to [-0.5, 0.5]
        if z is None:
            z = torch.randn((state.shape[0], self.latent_dim)).to(self.device).clamp(-0.5, 0.5)

        a = F.relu(self.d1(torch.cat([state, z], 1)))
        a = F.relu(self.d2(a))
        return self.max_action * torch.tanh(self.d3(a))

    def sample(self, state):
        return self.decode(state)


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()
        self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, 1))

        self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, 1))

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x), self.q2_model(x)

    def q1(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x)

    def q_min(self, state, action):
        q1, q2 = self.forward(state, action)
        return torch.min(q1, q2)


class QL_CVAE(object):
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 max_q_backup=False,
                 eta=0.1,
                 ema_decay=0.995,
                 step_start_ema=1000,
                 update_ema_every=5,
                 lr=3e-4,
                 ):

        latent_dim = action_dim * 2
        self.actor = actor(state_dim, action_dim, latent_dim, max_action, device, hidden_dim=500).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        # Exponential moving average of the actor, used as the target policy
        self.step = 0
        self.step_start_ema = step_start_ema
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.actor)
        self.update_ema_every = update_ema_every

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.eta = eta  # q_learning weight
        self.device = device
        self.max_q_backup = max_q_backup

    def sample_action(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
            action = self.actor.sample(state)
        return action.cpu().data.numpy().flatten()

    def step_ema(self):
        if self.step < self.step_start_ema:
            return
        self.ema.update_model_average(self.ema_model, self.actor)

    def train(self, replay_buffer, iterations, batch_size=100):
        for it in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            """ Q Training """
            current_q1, current_q2 = self.critic(state, action)

            if not self.max_q_backup:
                next_action = self.ema_model.sample(next_state)
                target_q1, target_q2 = self.critic_target(next_state, next_action)
                target_q = torch.min(target_q1, target_q2)
            else:
                # Max-Q backup: sample 10 actions per next state, take the max over
                # samples for each Q head, then the min over the two heads
                next_state_rpt = torch.repeat_interleave(next_state, repeats=10, dim=0)
                next_action_rpt = self.ema_model.sample(next_state_rpt)
                target_q1, target_q2 = self.critic_target(next_state_rpt, next_action_rpt)
                target_q1 = target_q1.view(batch_size, 10).max(dim=1, keepdim=True)[0]
                target_q2 = target_q2.view(batch_size, 10).max(dim=1, keepdim=True)[0]
                target_q = torch.min(target_q1, target_q2)

            target_q = (reward + not_done * self.discount * target_q).detach()

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            """ Policy Training """
            # Variational Auto-Encoder Training
            recon, mean, std = self.actor(state, action)
            recon_loss = F.mse_loss(recon, action)
            # KL divergence between the encoder posterior N(mean, std) and N(0, I)
            KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
            vae_loss = recon_loss + 0.5 * KL_loss

            # Q-learning term: maximize one Q head, scaled by the other head's
            # mean |Q| so that eta controls the relative weight of the Q loss
            new_action = self.actor.sample(state)

            q1_new_action, q2_new_action = self.critic(state, new_action)
            if np.random.uniform() > 0.5:
                lmbda = self.eta / q2_new_action.abs().mean().detach()
                q_loss = - lmbda * q1_new_action.mean()
            else:
                lmbda = self.eta / q1_new_action.abs().mean().detach()
                q_loss = - lmbda * q2_new_action.mean()

            self.actor_optimizer.zero_grad()
            vae_loss.backward()
            q_loss.backward()
            self.actor_optimizer.step()

            """ Step Target networks """
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            # Soft update of the critic target
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.step += 1

    def save_model(self, dir):
        torch.save(self.actor.state_dict(), f'{dir}/actor.pth')

    def load_model(self, dir):
        self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
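

# ----------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hedged example of
# how QL_CVAE might be driven offline. The dimensions, the toy replay buffer,
# and the random data below are illustrative assumptions, not the repo's
# actual data pipeline; the real code builds its buffer elsewhere. The only
# interface train() relies on is sample(batch_size) returning
# (state, action, next_state, reward, not_done) tensors on the agent's device.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    class _ToyReplayBuffer:
        """Hypothetical stand-in matching the interface train() expects."""
        def __init__(self, state_dim, action_dim, size, device):
            self.state = torch.randn(size, state_dim, device=device)
            self.action = torch.rand(size, action_dim, device=device) * 2 - 1
            self.next_state = torch.randn(size, state_dim, device=device)
            self.reward = torch.randn(size, 1, device=device)
            self.not_done = torch.ones(size, 1, device=device)

        def sample(self, batch_size):
            idx = torch.randint(0, self.state.shape[0], (batch_size,))
            return (self.state[idx], self.action[idx], self.next_state[idx],
                    self.reward[idx], self.not_done[idx])

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Assumed environment sizes, chosen only for illustration
    state_dim, action_dim, max_action = 17, 6, 1.0
    agent = QL_CVAE(state_dim, action_dim, max_action, device,
                    discount=0.99, tau=0.005, eta=0.1)
    buffer = _ToyReplayBuffer(state_dim, action_dim, size=10_000, device=device)
    agent.train(buffer, iterations=100, batch_size=100)
    print(agent.sample_action(np.zeros(state_dim)))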