agents/bc_gan2.py

# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.logger import logger
from torch.distributions import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import TanhTransform

LOG_SIG_MAX = 2.
LOG_SIG_MIN = -20.


class Generator(nn.Module):
    """Conditional generator: maps (state, Gaussian noise) to a bounded action."""

    def __init__(self, state_dim, action_dim, max_action, device, z_dim=64, w_dim=256, num_layers=2):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.device = device
        self.max_action = max_action

        # The input is the state concatenated with z_dim noise dimensions
        # (see forward()), so the first layer takes state_dim + z_dim features.
        self.network = nn.Sequential(nn.Linear(state_dim + z_dim, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, action_dim),
                                     nn.Tanh())

    def forward(self, x):
        z = torch.randn((x.shape[0], self.z_dim), device=self.device)
        w = torch.cat([x, z], dim=-1)
        a = self.network(w) * self.max_action
        return a

    def sample(self, x):
        return self.forward(x)


class Discriminator(nn.Module):
    """Scores (state, action) pairs; outputs unnormalised logits."""

    def __init__(self, state_dim, action_dim, w_dim=256, num_layers=2):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(nn.Linear(state_dim + action_dim, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, 1))

    def forward(self, state, action):
        return self.network(torch.cat([state, action], dim=-1))


class Critic(nn.Module):
    """Twin Q-networks. BC_GAN.train() below never computes a critic loss;
    only the target copy is Polyak-averaged."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()
        self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, 1))
        self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, 1))

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x), self.q2_model(x)

    def q1(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x)


class BC_GAN(object):
    """Behavioural cloning with a GAN objective: the generator acts as the
    policy, the discriminator separates dataset actions from generated ones."""

    def __init__(self, state_dim, action_dim, max_action, device, discount, tau, lr=3e-4):
        self.actor = self.generator = Generator(state_dim, action_dim, max_action, device,
                                                z_dim=min(action_dim, 10)).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.gen_optim = torch.optim.Adam(self.generator.parameters(), lr=2e-4)

        self.discriminator = Discriminator(state_dim, action_dim).to(device)
        self.disc_optim = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        self.adversarial_loss = torch.nn.BCEWithLogitsLoss()

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=2e-4)

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.device = device

    def sample_action(self, state):
        if self.actor.training:
            self.actor.eval()
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        action = self.actor.sample(state)
        return action.cpu().data.numpy().flatten()

    def train(self, replay_buffer, iterations, batch_size=100):
        self.actor.train()
        for it in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            """ Generator Training """
            # Non-saturating GAN loss: push the discriminator's logits on
            # generated actions towards "real".
            new_action = self.actor(state)
            gen_logits = self.discriminator(state, new_action)
            generator_loss = nn.functional.softplus(-gen_logits).mean()

            self.gen_optim.zero_grad()
            generator_loss.backward()
            self.gen_optim.step()

            """ Discriminator Training """
            # Dataset actions are labelled real, generated actions fake.
            fake_labels = torch.zeros(state.shape[0], 1, device=self.device)
            real_labels = torch.ones(state.shape[0], 1, device=self.device)

            real_loss = self.adversarial_loss(self.discriminator(state, action), real_labels)
            fake_loss = self.adversarial_loss(self.discriminator(state, new_action.detach()), fake_labels)
            discriminator_loss = real_loss + fake_loss

            self.disc_optim.zero_grad()
            discriminator_loss.backward()
            self.disc_optim.step()

            # Update Target Networks (Polyak averaging)
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1. - self.tau) * target_param.data)

        # Logging
        logger.record_tabular('Generator Loss', generator_loss.item())
        logger.record_tabular('Real Loss', real_loss.item())
        logger.record_tabular('Fake Loss', fake_loss.item())
        logger.record_tabular('Discriminator Loss', discriminator_loss.item())