# agents/bc_gan2.py
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0
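"""Behavioral cloning via a conditional GAN (BC-GAN).

The Generator acts as the policy: given a state and Gaussian noise it outputs an
action, and it is trained so that the Discriminator cannot tell its actions apart
from the dataset actions sampled from the replay buffer. A twin-Q Critic and
target networks are instantiated and soft-updated below, but no critic loss is
optimised in this file.
"""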
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.logger import logger
from torch.distributions import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import TanhTransform
LOG_SIG_MAX = 2.
LOG_SIG_MIN = -20.
class Generator(nn.Module):
    """Noise-conditioned policy: maps (state, z) to an action in [-max_action, max_action]."""

    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 z_dim=64,
                 w_dim=256,
                 num_layers=2):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.device = device
        self.max_action = max_action
        # The network consumes the state concatenated with a z_dim-dimensional
        # noise vector, so its first layer takes state_dim + z_dim features.
        self.network = nn.Sequential(nn.Linear(state_dim + z_dim, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, action_dim),
                                     nn.Tanh())

    def forward(self, x):
        # Sample fresh Gaussian noise for every state in the batch.
        z = torch.randn((x.shape[0], self.z_dim), device=self.device)
        w = torch.cat([x, z], dim=-1)
        a = self.network(w) * self.max_action
        return a

    def sample(self, x):
        return self.forward(x)
class Discriminator(nn.Module):
    """Scores (state, action) pairs; higher logits mean 'dataset action', lower mean 'generated'."""

    def __init__(self, state_dim, action_dim, w_dim=256, num_layers=2):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(nn.Linear(state_dim + action_dim, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, 256),
                                     nn.LeakyReLU(0.1),
                                     nn.Linear(256, 1))

    def forward(self, state, action):
        # Returns an unnormalised logit; the losses apply the sigmoid themselves.
        return self.network(torch.cat([state, action], dim=-1))
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()
        self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, 1))
        self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, 1))

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x), self.q2_model(x)

    def q1(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x)
class BC_GAN(object):
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 lr=3e-4):
        # The generator doubles as the policy ("actor"); its latent dimension is capped at 10.
        self.actor = self.generator = Generator(state_dim,
                                                action_dim,
                                                max_action,
                                                device,
                                                z_dim=min(action_dim, 10)).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.gen_optim = torch.optim.Adam(self.generator.parameters(), lr=2e-4)

        self.discriminator = Discriminator(state_dim, action_dim).to(device)
        self.disc_optim = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        self.adversarial_loss = torch.nn.BCEWithLogitsLoss()

        # Twin-Q critic and target networks are created (and soft-updated in train())
        # but no critic loss is optimised in this agent.
        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=2e-4)

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.device = device

    def sample_action(self, state):
        # Switch to eval mode for action selection at interaction time.
        if self.actor.training:
            self.actor.eval()
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        action = self.actor.sample(state)
        return action.cpu().data.numpy().flatten()
    def train(self, replay_buffer, iterations, batch_size=100):
        self.actor.train()
        for it in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            """
            Generator Training
            """
            # Non-saturating GAN objective: softplus(-D(s, a)) = -log sigmoid(D(s, a)),
            # i.e. the generator is pushed to make its actions look like dataset actions.
            new_action = self.actor(state)
            gen_logits = self.discriminator(state, new_action)
            generator_loss = nn.functional.softplus(-gen_logits).mean()
            self.gen_optim.zero_grad()
            generator_loss.backward()
            self.gen_optim.step()

            """
            Discriminator Training
            """
            # Dataset actions are labelled 1 ("real"), generated actions 0 ("fake");
            # generated actions are detached so only the discriminator is updated here.
            fake_labels = torch.zeros(state.shape[0], 1, device=self.device)
            real_labels = torch.ones(state.shape[0], 1, device=self.device)
            real_loss = self.adversarial_loss(self.discriminator(state, action), real_labels)
            fake_loss = self.adversarial_loss(self.discriminator(state, new_action.detach()), fake_labels)
            discriminator_loss = real_loss + fake_loss
            self.disc_optim.zero_grad()
            discriminator_loss.backward()
            self.disc_optim.step()

            # Update Target Networks (Polyak averaging with rate tau)
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1. - self.tau) * target_param.data)

        # Logging (losses from the last iteration of this training call)
        logger.record_tabular('Generator Loss', generator_loss.item())
        logger.record_tabular('Real Loss', real_loss.item())
        logger.record_tabular('Fake Loss', fake_loss.item())
        logger.record_tabular('Discriminator Loss', discriminator_loss.item())
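if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original module).
    # It assumes only the buffer interface that train() uses above: sample(batch_size)
    # returning (state, action, next_state, reward, not_done) tensors on `device`.
    # The dimensions and hyperparameters are placeholders, not repo defaults.
    class _RandomBuffer(object):
        def __init__(self, state_dim, action_dim, device):
            self.state_dim, self.action_dim, self.device = state_dim, action_dim, device

        def sample(self, batch_size):
            # Random transitions, just to exercise the training loop end to end.
            s = torch.randn(batch_size, self.state_dim, device=self.device)
            a = torch.rand(batch_size, self.action_dim, device=self.device) * 2. - 1.
            s2 = torch.randn(batch_size, self.state_dim, device=self.device)
            r = torch.randn(batch_size, 1, device=self.device)
            nd = torch.ones(batch_size, 1, device=self.device)
            return s, a, s2, r, nd

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent = BC_GAN(state_dim=17, action_dim=6, max_action=1.0,
                   device=device, discount=0.99, tau=0.005)
    agent.train(_RandomBuffer(17, 6, device), iterations=5, batch_size=32)
    print(agent.sample_action(np.zeros(17)))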