agents/bc_w.py

# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Distribution, Normal

from utils.logger import logger

EPS = 1e-8
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20


class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """
    def __init__(self, normal_mean, normal_std, device, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon
        self.device = device

    def sample_n(self, n, return_pre_tanh_value=False):
        z = self.normal.sample_n(n)
        if return_pre_tanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        :param value: some value, x
        :param pre_tanh_value: arctanh(x)
        :return:
        """
        if pre_tanh_value is None:
            pre_tanh_value = torch.log(
                (1 + value) / (1 - value + self.epsilon) + self.epsilon
            ) / 2
        # Change-of-variables correction for the tanh squashing.
        return self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value * value + self.epsilon
        )

    def sample(self, return_pretanh_value=False):
        """
        Sampling without reparameterization.
        """
        z = self.normal.sample().detach()
        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        z = (
            self.normal_mean +
            self.normal_std *
            Normal(
                torch.zeros(self.normal_mean.size(), device=self.device),
                torch.ones(self.normal_std.size(), device=self.device)
            ).sample()
        )
        z.requires_grad_()
        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)


# Implicit Policy
class Actor(nn.Module):
    """
    Gaussian Policy
    """
    def __init__(self, state_dim, action_dim, max_action, device,
                 hidden_sizes=[256, 256], layer_norm=False):
        super(Actor, self).__init__()
        self.layer_norm = layer_norm

        self.base_fc = []
        last_size = state_dim
        for next_size in hidden_sizes:
            self.base_fc += [
                nn.Linear(last_size, next_size),
                nn.LayerNorm(next_size) if layer_norm else nn.Identity(),
                nn.ReLU(inplace=True),
            ]
            last_size = next_size
        self.base_fc = nn.Sequential(*self.base_fc)

        last_hidden_size = hidden_sizes[-1]
        self.last_fc_mean = nn.Linear(last_hidden_size, action_dim)
        self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim)

        self.max_action = max_action
        self.device = device

    def forward(self, state):
        h = self.base_fc(state)
        mean = self.last_fc_mean(h)
        std = self.last_fc_log_std(h).clamp(LOG_SIG_MIN, LOG_SIG_MAX).exp()

        tanh_normal = TanhNormal(mean, std, self.device)
        action, pre_tanh_value = tanh_normal.rsample(return_pretanh_value=True)
        log_prob = tanh_normal.log_prob(action, pre_tanh_value=pre_tanh_value)
        log_prob = log_prob.sum(dim=1, keepdim=True)
        action = action * self.max_action
        return action, log_prob

    def log_prob(self, state, action):
        h = self.base_fc(state)
        mean = self.last_fc_mean(h)
        std = self.last_fc_log_std(h).clamp(LOG_SIG_MIN, LOG_SIG_MAX).exp()

        tanh_normal = TanhNormal(mean, std, self.device)
        log_prob = tanh_normal.log_prob(action)
        log_prob = log_prob.sum(dim=1, keepdim=True)
        return log_prob

    def sample(self, state, reparameterize=False, deterministic=False):
        h = self.base_fc(state)
        mean = self.last_fc_mean(h)
        std = self.last_fc_log_std(h).clamp(LOG_SIG_MIN, LOG_SIG_MAX).exp()

        if deterministic:
            action = torch.tanh(mean) * self.max_action
        else:
            tanh_normal = TanhNormal(mean, std, self.device)
            if reparameterize:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()
            action = action * self.max_action

        return action


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=300):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.model(x)


class BC_W(object):
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 lr=3e-4,
                 hidden_sizes=[256, 256],
                 w_gamma=5.0,
                 c_iter=3,
                 ):
        self.actor = Actor(state_dim, action_dim, max_action, device=device,
                           hidden_sizes=hidden_sizes, layer_norm=False).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-4)

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.device = device
        self.w_gamma = w_gamma
        self.c_iter = c_iter

    def sample_action(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
            action = self.actor.sample(state, deterministic=True)
        return action.cpu().data.numpy().flatten()

    def optimize_c(self, state, action_b):
        action_pi = self.actor.sample(state).detach()

        # Gradient penalty on actions interpolated between policy and dataset actions.
        batch_size = state.shape[0]
        alpha = torch.rand((batch_size, 1)).to(self.device)
        a_intpl = (action_pi + alpha * (action_b - action_pi)).requires_grad_(True)
        grads = torch.autograd.grad(outputs=self.critic(state, a_intpl).mean(),
                                    inputs=a_intpl,
                                    create_graph=True,
                                    only_inputs=True)[0]
        slope = (grads.square().sum(dim=-1) + EPS).sqrt()
        gradient_penalty = torch.max(slope - 1.0, torch.zeros_like(slope)).square().mean()

        # Critic is trained to separate policy actions from dataset actions.
        logits_p = self.critic(state, action_pi)
        logits_b = self.critic(state, action_b)
        logits_diff = logits_p - logits_b

        critic_loss = - logits_diff.mean() + gradient_penalty * self.w_gamma

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        return critic_loss.item()

    def optimize_p(self, state, action_b):
        action_pi = self.actor.sample(state)

        logits_p = self.critic(state, action_pi)
        logits_b = self.critic(state, action_b)
        logits_diff = logits_p - logits_b

        # Actor Training: minimize the critic's estimate of the policy/data gap.
        actor_loss = logits_diff.mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        return actor_loss.item()

    def train(self, replay_buffer, iterations, batch_size=100):
        for it in range(iterations):
            # Sample replay buffer / batch
            state, action, _, _, _ = replay_buffer.sample(batch_size)

            critic_loss = self.optimize_c(state, action)
            actor_loss = self.optimize_p(state, action)

            # Extra critic updates per actor update.
            for _ in range(self.c_iter - 1):
                state, action, _, _, _ = replay_buffer.sample(batch_size)
                critic_loss = self.optimize_c(state, action)

        logger.record_tabular('Actor Loss', actor_loss)
        logger.record_tabular('Critic Loss', critic_loss)

    def save_model(self, dir):
        torch.save(self.actor.state_dict(), f'{dir}/actor.pth')

    def load_model(self, dir):
        self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
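

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original agent).
# It assumes a toy replay buffer whose sample() returns five tensors
# (state, action, next_state, reward, not_done); BC_W.train() only consumes
# state and action. The buffer class, dimensions, and hyperparameters below
# are placeholders, not the repository's actual data pipeline.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    class _ToyReplayBuffer:
        """Hypothetical in-memory buffer matching the interface used above."""
        def __init__(self, states, actions, device):
            self.states = torch.as_tensor(states, dtype=torch.float32, device=device)
            self.actions = torch.as_tensor(actions, dtype=torch.float32, device=device)
            self.device = device

        def sample(self, batch_size):
            idx = torch.randint(0, self.states.shape[0], (batch_size,), device=self.device)
            state, action = self.states[idx], self.actions[idx]
            # next_state / reward / not_done are unused by BC_W.train().
            return (state, action, state,
                    torch.zeros(batch_size, 1, device=self.device),
                    torch.ones(batch_size, 1, device=self.device))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dim, action_dim, max_action = 17, 6, 1.0

    # Random data stands in for an offline dataset.
    buffer = _ToyReplayBuffer(np.random.randn(1000, state_dim),
                              np.random.uniform(-1, 1, (1000, action_dim)),
                              device)

    agent = BC_W(state_dim, action_dim, max_action, device,
                 discount=0.99, tau=0.005)
    agent.train(buffer, iterations=10, batch_size=64)
    print(agent.sample_action(np.zeros(state_dim)))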