agents/bc_gan.py:

# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.logger import logger
from torch.distributions import Distribution, Normal

LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
NEGATIVE_SLOPE = 1. / 10.


class NormalNoise(object):
    def __init__(self, device, mean=0., std=1.):
        self.mean = mean
        self.std = std
        self.device = device

    def sample_noise(self, shape, dtype=None, requires_grad=False):
        return torch.randn(size=shape, dtype=dtype, device=self.device,
                           requires_grad=requires_grad) * self.std + self.mean


class ImplicitPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, max_action, noise, noise_dim, device):
        # noise_dim : dimension of noise for "concat" method
        super(ImplicitPolicy, self).__init__()
        self.hidden_size = (400, 300)

        self.l1 = nn.Linear(state_dim + int(noise_dim), self.hidden_size[0])
        self.l2 = nn.Linear(self.hidden_size[0], self.hidden_size[1])
        self.l3 = nn.Linear(self.hidden_size[1], action_dim)

        self.max_action = max_action
        self.noise = noise
        self.noise_dim = int(noise_dim)
        self.device = device

    def forward(self, state):
        if isinstance(state, np.ndarray):
            state = torch.FloatTensor(state).to(self.device)
        # state.shape = (batch_size, state_dim)
        epsilon = self.noise.sample_noise(shape=(state.shape[0], self.noise_dim)).clamp(-3, 3)
        state = torch.cat([state, epsilon], 1)  # dim = (state.shape[0], state_dim + noise_dim)

        a = F.leaky_relu(self.l1(state), negative_slope=NEGATIVE_SLOPE)
        a = F.leaky_relu(self.l2(a), negative_slope=NEGATIVE_SLOPE)
        return self.l3(a)

    def sample_multiple_actions(self, state, num_action=10, std=-1.):
        # num_action : number of actions to sample from policy for each state
        if isinstance(state, np.ndarray):
            state = torch.FloatTensor(state)
        batch_size = state.shape[0]
        # e.g., num_action = 3, [s1;s2] -> [s1;s1;s1;s2;s2;s2]
        if std <= 0:
            state = state.unsqueeze(1).repeat(1, num_action, 1).view(-1, state.size(-1)).to(self.device)
        else:  # std > 0
            if num_action == 1:
                noises = torch.normal(torch.zeros_like(state), torch.ones_like(state))  # B x state_dim
                state = (state + (std * noises).clamp(-0.05, 0.05)).to(self.device)  # B x state_dim
            else:  # num_action > 1
                state_noise = state.unsqueeze(1).repeat(1, num_action, 1)  # B x num_action x state_dim
                noises = torch.normal(torch.zeros_like(state_noise), torch.ones_like(state_noise))  # B x num_action x state_dim
                state_noise = state_noise + (std * noises).clamp(-0.05, 0.05)  # B x num_action x state_dim
                state = torch.cat((state_noise, state.unsqueeze(1)), dim=1).view((batch_size * (num_action + 1)), -1).to(self.device)  # (B * (num_action + 1)) x state_dim
        # return [a11;a12;a13;a21;a22;a23] for [s1;s1;s1;s2;s2;s2]
        return state, self.forward(state)

    def sample(self, state):
        return self.forward(state)


class Discriminator(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Discriminator, self).__init__()
        self.hidden_size = (400, 300)

        self.model = nn.Sequential(
            nn.Linear(state_dim + action_dim, self.hidden_size[0]),
            nn.LeakyReLU(negative_slope=NEGATIVE_SLOPE),
            nn.Linear(self.hidden_size[0], self.hidden_size[1]),
            nn.LeakyReLU(negative_slope=NEGATIVE_SLOPE),
            nn.Linear(self.hidden_size[1], 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        validity = self.model(x)
        return validity


class BC_GAN(object):
    def __init__(self, state_dim, action_dim, max_action, device, discount, tau, lr=3e-4):
        noise_dim = min(action_dim, 10)
        self.noise = NormalNoise(device=device, mean=0.0, std=1.0)

        self.actor = ImplicitPolicy(state_dim, action_dim, max_action, self.noise, noise_dim, device).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr, betas=(0.4, 0.999))

        self.discriminator = Discriminator(state_dim=state_dim, action_dim=action_dim).to(device)
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.4, 0.999))
        self.adversarial_loss = torch.nn.BCELoss()

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.device = device
        self.g_iter = 2

    def sample_action(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
            action = self.actor.sample(state)
        return action.cpu().data.numpy().flatten()

    def train(self, replay_buffer, iterations, batch_size=100):
        for it in range(iterations):
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            # Discriminator update: real (state, action) pairs from the buffer vs.
            # fake pairs from the implicit policy; real labels are smoothed into [0.8, 1.0).
            state_repeat, action_samples = self.actor.sample_multiple_actions(state, num_action=5, std=3e-4)
            true_samples = torch.cat([state, action], 1)
            fake_samples = torch.cat([state_repeat, action_samples], 1)

            fake_labels = torch.zeros(fake_samples.size(0), 1, device=self.device)
            real_labels = torch.rand(size=(true_samples.size(0), 1), device=self.device) * (1.0 - 0.80) + 0.80

            real_loss = self.adversarial_loss(self.discriminator(true_samples), real_labels)
            fake_loss = self.adversarial_loss(self.discriminator(fake_samples.detach()), fake_labels)
            discriminator_loss = (real_loss + fake_loss) / 2

            self.discriminator_optimizer.zero_grad()
            discriminator_loss.backward()
            self.discriminator_optimizer.step()

            # Generator (policy) update every self.g_iter discriminator updates.
            if it % self.g_iter == 0:
                generator_loss = self.adversarial_loss(self.discriminator(fake_samples),
                                                       torch.ones(fake_samples.size(0), 1, device=self.device))
                self.actor_optimizer.zero_grad()
                generator_loss.backward()
                self.actor_optimizer.step()

        logger.record_tabular('Generator Loss', generator_loss.cpu().data.numpy())
        logger.record_tabular('Discriminator Loss', discriminator_loss.cpu().data.numpy())

    def save_model(self, dir):
        torch.save(self.actor.state_dict(), f'{dir}/actor.pth')

    def load_model(self, dir):
        self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
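

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical example
# of driving BC_GAN on an offline dataset. BC_GAN.train() only assumes that
# replay_buffer.sample(batch_size) returns (state, action, next_state, reward,
# not_done) tensors on the agent's device; ArrayReplayBuffer and all dimensions
# and hyperparameters below are illustrative assumptions, not the project API.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class ArrayReplayBuffer(object):
        # Hypothetical in-memory buffer over a fixed (state, action) dataset.
        def __init__(self, states, actions, device):
            self.states = torch.FloatTensor(states).to(device)
            self.actions = torch.FloatTensor(actions).to(device)
            self.device = device

        def sample(self, batch_size):
            idx = torch.randint(0, self.states.shape[0], (batch_size,), device=self.device)
            state, action = self.states[idx], self.actions[idx]
            # next_state / reward / not_done are unused by BC_GAN.train but are
            # returned to match the expected 5-tuple interface.
            next_state = torch.zeros_like(state)
            reward = torch.zeros(batch_size, 1, device=self.device)
            not_done = torch.ones(batch_size, 1, device=self.device)
            return state, action, next_state, reward, not_done

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dim, action_dim, max_action = 17, 6, 1.0  # illustrative task dimensions
    dataset_size = 10_000

    buffer = ArrayReplayBuffer(
        states=np.random.randn(dataset_size, state_dim).astype(np.float32),
        actions=np.random.uniform(-max_action, max_action, (dataset_size, action_dim)).astype(np.float32),
        device=device,
    )

    agent = BC_GAN(state_dim, action_dim, max_action, device, discount=0.99, tau=0.005)
    agent.train(buffer, iterations=1_000, batch_size=100)

    # The learned implicit policy maps a single state (plus fresh noise) to an action.
    print(agent.sample_action(np.zeros(state_dim, dtype=np.float32)))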