agents/bc_kl.py (144 lines of code) (raw):

# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0
"""Behaviour cloning of a Gaussian actor against a VAE behaviour model.

A conditional VAE is fitted to the dataset actions; the actor is then trained
to maximise the log-likelihood of actions sampled from that VAE.
"""

import copy

import numpy as np
import torch
import torch.distributions as td
import torch.nn as nn
import torch.nn.functional as F

from utils.logger import logger

LOG_SIG_MAX = 2
LOG_SIG_MIN = -20


def atanh(x):
    """Return the inverse hyperbolic tangent of ``x``.

    Both ``1 + x`` and ``1 - x`` are clamped to at least ``1e-7`` so that
    inputs at (or slightly beyond) +/-1 yield a finite result.
    """
    one_plus_x = (1 + x).clamp(min=1e-7)
    one_minus_x = (1 - x).clamp(min=1e-7)
    return 0.5 * torch.log(one_plus_x / one_minus_x)


# Vanilla Variational Auto-Encoder
class VAE(nn.Module):
    """Conditional VAE that reconstructs an action given a state."""

    def __init__(self, state_dim, action_dim, latent_dim, max_action, device,
                 hidden_dim=256):
        super().__init__()
        # Encoder: (state, action) -> latent mean / log-std.
        self.e1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.e2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)

        # Decoder: (state, latent) -> action.
        self.d1 = nn.Linear(state_dim + latent_dim, hidden_dim)
        self.d2 = nn.Linear(hidden_dim, hidden_dim)
        self.d3 = nn.Linear(hidden_dim, action_dim)

        self.max_action = max_action
        self.latent_dim = latent_dim
        self.device = device

    def forward(self, state, action):
        """Encode ``(state, action)``, sample a latent and decode it.

        Returns:
            Tuple ``(reconstructed_action, latent_mean, latent_std)``.
        """
        z = F.relu(self.e1(torch.cat([state, action], 1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # Clamped for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        # Reparameterisation trick keeps the sample differentiable.
        z = mean + std * torch.randn_like(std)

        u = self.decode(state, z)
        return u, mean, std

    def decode(self, state, z=None):
        """Decode one action per state row.

        Args:
            state: Tensor whose first dimension is the batch.
            z: Optional latent of shape ``(batch, latent_dim)``. When omitted
                a standard-normal latent is drawn and clipped to
                ``[-0.5, 0.5]``.

        Returns:
            ``max_action * tanh(decoder_output)``.
        """
        # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5]
        if z is None:
            z = torch.randn((state.shape[0], self.latent_dim)).to(self.device).clamp(-0.5, 0.5)

        a = F.relu(self.d1(torch.cat([state, z], 1)))
        a = F.relu(self.d2(a))
        return self.max_action * torch.tanh(self.d3(a))

    def decode_multiple(self, state, z=None, num_decode=10):
        """Decode ``num_decode`` actions for every state row.

        ``BC_KL.train`` relies on this method; it mirrors
        ``RegularActor.sample_multiple`` by returning both the squashed and
        the pre-tanh actions.

        Args:
            state: Tensor of shape ``(batch, state_dim)``.
            z: Optional latents of shape ``(batch, num_decode, latent_dim)``.
                When omitted, standard-normal latents clipped to
                ``[-0.5, 0.5]`` are drawn (same clipping as ``decode``).
            num_decode: Number of actions to decode per state.

        Returns:
            Tuple ``(actions, raw_actions)``, each of shape
            ``(batch, num_decode, action_dim)``, where
            ``actions = max_action * tanh(raw_actions)``.
        """
        if z is None:
            z = torch.randn(
                (state.shape[0], num_decode, self.latent_dim),
                device=self.device,
            ).clamp(-0.5, 0.5)

        # Broadcast each state across its num_decode latents: B x N x state_dim.
        state_rep = state.unsqueeze(1).expand(-1, z.size(1), -1)
        a = F.relu(self.d1(torch.cat([state_rep, z], 2)))
        a = F.relu(self.d2(a))
        raw_action = self.d3(a)
        return self.max_action * torch.tanh(raw_action), raw_action

    def sample(self, state):
        """Draw one action per state from the prior (see ``decode``)."""
        return self.decode(state)


class RegularActor(nn.Module):
    """A probabilistic actor which does regular stochastic mapping of actions from states"""

    def __init__(self, state_dim, action_dim, max_action, device, hidden_dim=256):
        super().__init__()
        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)
        self.max_action = max_action
        self.device = device

    def _gaussian_params(self, state):
        """Return ``(mean, std)`` of the pre-tanh Gaussian for ``state``."""
        hidden = F.relu(self.l1(state))
        hidden = F.relu(self.l2(hidden))
        mean_a = self.mean(hidden)
        std_a = torch.exp(self.log_std(hidden))
        return mean_a, std_a

    def forward(self, state):
        """Sample one tanh-squashed action per state, scaled by ``max_action``."""
        mean_a, std_a = self._gaussian_params(state)
        z = mean_a + std_a * torch.randn_like(std_a)
        return self.max_action * torch.tanh(z)

    def sample_multiple(self, state, num_sample=10):
        """Sample ``num_sample`` actions per state.

        Returns:
            Tuple ``(actions, raw_actions)``, each of shape
            ``(batch, num_sample, action_dim)``, where
            ``actions = max_action * tanh(raw_actions)``.
        """
        mean_a, std_a = self._gaussian_params(state)
        # This trick stabilizes learning (clipping gaussian to a smaller range)
        noise = torch.FloatTensor(
            np.random.normal(0, 1, size=(std_a.size(0), num_sample, std_a.size(1)))
        ).to(self.device).clamp(-0.5, 0.5)
        z = mean_a.unsqueeze(1) + std_a.unsqueeze(1) * noise
        return self.max_action * torch.tanh(z), z

    def log_pis(self, state, action=None, raw_action=None):
        """Get log pis for the model.

        Args:
            state: Batch of states.
            action: Tanh-squashed actions; used (through ``atanh``) only when
                ``raw_action`` is not given.
            raw_action: Pre-tanh actions; takes precedence over ``action``.

        Returns:
            Log-density of the squashed action, summed over the last
            dimension.

        Raises:
            ValueError: If neither ``action`` nor ``raw_action`` is given.
        """
        if action is None and raw_action is None:
            raise ValueError("log_pis requires either `action` or `raw_action`")

        mean_a, std_a = self._gaussian_params(state)
        normal_dist = td.Normal(loc=mean_a, scale=std_a, validate_args=True)
        if raw_action is None:
            raw_action = atanh(action)
        else:
            action = torch.tanh(raw_action)
        log_normal = normal_dist.log_prob(raw_action)
        log_pis = log_normal.sum(-1)
        # Change-of-variables correction for the tanh squashing; the clamp
        # avoids log(0) when |action| is numerically 1.
        log_pis = log_pis - (1.0 - action ** 2).clamp(min=1e-6).log().sum(-1)
        return log_pis

    def sample(self, state):
        """Alias of ``forward``."""
        return self.forward(state)


class BC_KL(object):
    """Behaviour-cloning agent: VAE behaviour model plus a Gaussian actor."""

    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 lr=3e-4,
                 num_samples_match=10,
                 kl_type='backward'
                 ):
        latent_dim = action_dim * 2
        self.vae = VAE(state_dim, action_dim, latent_dim, max_action, device).to(device)
        self.vae_optimizer = torch.optim.Adam(self.vae.parameters(), lr=lr)

        self.actor = RegularActor(state_dim, action_dim, max_action, device).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.device = device

        self.num_samples_match = num_samples_match
        # Stored for callers; nothing in this module branches on it.
        self.kl_type = kl_type

    def sample_action(self, state):
        """Sample one action for a single state and return it as a flat array."""
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
            action = self.actor.sample(state)
        return action.cpu().data.numpy().flatten()

    def kl_loss(self, action, state):
        """Return the per-state negative log-likelihood of ``action``.

        Backward KL: KL(behavior_policy(a|s) | current_policy(a|s)). With
        actions drawn from the behaviour model, the mean of
        ``-log pi(a|s)`` equals that KL up to the behaviour-entropy term,
        which does not depend on the actor.

        Args:
            action: Pre-tanh actions in shape
                ``(batch_size, num_samples_match, action_dim)``.
            state: States in shape ``(batch_size, state_dim)``.

        Returns:
            Tensor of shape ``(batch_size,)``.
        """
        batch_size, num_samples = state.size(0), action.size(1)
        # Pair every sampled action with its originating state.
        state_rep = state.unsqueeze(1).repeat(1, num_samples, 1).reshape(-1, state.size(-1))
        action_reshape = action.reshape(-1, action.size(-1))
        action_log_pis = self.actor.log_pis(state=state_rep, raw_action=action_reshape)
        action_log_prob = action_log_pis.view(batch_size, num_samples)
        return (-action_log_prob).mean(1)

    def train(self, replay_buffer, iterations, batch_size=100):
        """Run ``iterations`` VAE and actor updates.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)``.
            iterations: Number of gradient steps; with ``0`` nothing is
                trained or logged.
            batch_size: Mini-batch size passed to ``replay_buffer.sample``.
        """
        vae_loss = kl_loss = None

        for _ in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            # Variational Auto-Encoder Training
            recon, mean, std = self.vae(state, action)
            recon_loss = F.mse_loss(recon, action)
            KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
            vae_loss = recon_loss + 0.5 * KL_loss

            self.vae_optimizer.zero_grad()
            vae_loss.backward()
            self.vae_optimizer.step()

            # Behaviour samples are fixed targets for the actor, so no
            # gradient needs to flow back into the VAE.
            with torch.no_grad():
                _, raw_sampled_actions = self.vae.decode_multiple(
                    state, num_decode=self.num_samples_match)  # B x N x d

            kl_loss = self.kl_loss(raw_sampled_actions, state).mean()

            self.actor_optimizer.zero_grad()
            kl_loss.backward()
            self.actor_optimizer.step()

        # Report the losses of the final iteration, if any iteration ran.
        if vae_loss is not None:
            logger.record_tabular('VAE Loss', vae_loss.detach().cpu().numpy())
            logger.record_tabular('KL Loss', kl_loss.detach().cpu().numpy())

    def save_model(self, dir):
        """Write the actor weights to ``{dir}/actor.pth`` (the VAE is not saved)."""
        torch.save(self.actor.state_dict(), f'{dir}/actor.pth')

    def load_model(self, dir):
        """Load actor weights from ``{dir}/actor.pth`` onto ``self.device``."""
        self.actor.load_state_dict(
            torch.load(f'{dir}/actor.pth', map_location=self.device))