agents/ed_pcq.py
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.logger import logger
from agents.diffusion import Diffusion
from agents.model import MLP
from agents.helpers import EMA
from utils import pytorch_util as ptu
def identity(x):
return x
class ParallelizedLayerMLP(nn.Module):
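    """
    A linear layer replicated across an ensemble.

    W has shape (ensemble_size, input_dim, output_dim) and b has shape
    (ensemble_size, 1, output_dim), so a single batched matmul evaluates
    every ensemble member at once.
    """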
def __init__(
self,
ensemble_size,
input_dim,
output_dim,
w_std_value=1.0,
b_init_value=0.0
):
super().__init__()
# approximation to truncated normal of 2 stds
w_init = torch.randn((ensemble_size, input_dim, output_dim))
w_init = torch.fmod(w_init, 2) * w_std_value
self.W = nn.Parameter(w_init, requires_grad=True)
# constant initialization
b_init = torch.zeros((ensemble_size, 1, output_dim)).float()
b_init += b_init_value
self.b = nn.Parameter(b_init, requires_grad=True)
def forward(self, x):
# assumes x is 3D: (ensemble_size, batch_size, dimension)
return x @ self.W + self.b
class ParallelizedEnsembleFlattenMLP(nn.Module):
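    """
    An ensemble of MLPs evaluated in parallel on flattened (state, action) inputs.

    forward() returns per-member predictions of shape
    (ensemble_size, batch_size, output_size); sample() reduces them with an
    element-wise minimum, i.e. a pessimistic (clipped) value estimate.
    """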
def __init__(
self,
ensemble_size,
hidden_sizes,
input_size,
output_size,
init_w=3e-3,
hidden_init=ptu.fanin_init,
w_scale=1,
b_init_value=0.1,
layer_norm=None,
batch_norm=False,
final_init_scale=None,
):
super().__init__()
self.ensemble_size = ensemble_size
self.input_size = input_size
self.output_size = output_size
self.elites = [i for i in range(self.ensemble_size)]
self.sampler = np.random.default_rng()
self.hidden_activation = F.relu
self.output_activation = identity
self.layer_norm = layer_norm
self.fcs = []
if batch_norm:
raise NotImplementedError
in_size = input_size
for i, next_size in enumerate(hidden_sizes):
fc = ParallelizedLayerMLP(
ensemble_size=ensemble_size,
input_dim=in_size,
output_dim=next_size,
)
for j in self.elites:
hidden_init(fc.W[j], w_scale)
fc.b[j].data.fill_(b_init_value)
self.__setattr__('fc%d' % i, fc)
self.fcs.append(fc)
in_size = next_size
self.last_fc = ParallelizedLayerMLP(
ensemble_size=ensemble_size,
input_dim=in_size,
output_dim=output_size,
)
if final_init_scale is None:
self.last_fc.W.data.uniform_(-init_w, init_w)
self.last_fc.b.data.uniform_(-init_w, init_w)
else:
for j in self.elites:
ptu.orthogonal_init(self.last_fc.W[j], final_init_scale)
self.last_fc.b[j].data.fill_(0)
def forward(self, *inputs, **kwargs):
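        """
        Concatenate the inputs along the last dimension and broadcast them to
        (ensemble_size, batch_size, dim) before running the batched layers.
        """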
flat_inputs = torch.cat(inputs, dim=-1)
state_dim = inputs[0].shape[-1]
dim = len(flat_inputs.shape)
# repeat h to make amenable to parallelization
# if dim = 3, then we probably already did this somewhere else
# (e.g. bootstrapping in training optimization)
if dim < 3:
flat_inputs = flat_inputs.unsqueeze(0)
if dim == 1:
flat_inputs = flat_inputs.unsqueeze(0)
flat_inputs = flat_inputs.repeat(self.ensemble_size, 1, 1)
# input normalization
h = flat_inputs
# standard feedforward network
        for fc in self.fcs:
            h = fc(h)
            h = self.hidden_activation(h)
            if self.layer_norm is not None:
                h = self.layer_norm(h)
preactivation = self.last_fc(h)
output = self.output_activation(preactivation)
# if original dim was 1D, squeeze the extra created layer
if dim == 1:
output = output.squeeze(1)
# output is (ensemble_size, batch_size, output_size)
return output
def sample(self, *inputs):
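        # Pessimistic value estimate: element-wise minimum over the ensemble members.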
preds = self.forward(*inputs)
return torch.min(preds, dim=0)[0]
def fit_input_stats(self, data, mask=None):
raise NotImplementedError
class ED_PCQ(object):
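    """
    Offline RL agent that pairs a diffusion-model policy with an ensemble critic.

    The actor is trained with a diffusion behavior-cloning loss plus a
    Q-maximization term; the critic ensemble is trained with clipped
    (min-over-ensemble) TD targets and, when q_eta > 0, a gradient-diversity
    penalty on the members' action gradients (similar in spirit to EDAC).
    """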
def __init__(self,
state_dim,
action_dim,
max_action,
device,
discount,
tau,
max_q_backup=False,
eta=0.1,
model_type='MLP',
beta_schedule='linear',
n_timesteps=100,
ema_decay=0.995,
step_start_ema=1000,
update_ema_every=5,
lr=3e-4,
num_qs=50,
num_q_layers=3,
q_eta=1.0,
):
self.model = MLP(state_dim=state_dim, action_dim=action_dim, device=device)
self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, model=self.model, max_action=max_action,
beta_schedule=beta_schedule, n_timesteps=n_timesteps,
).to(device)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
self.step = 0
self.step_start_ema = step_start_ema
self.ema = EMA(ema_decay)
self.ema_model = copy.deepcopy(self.actor)
self.update_ema_every = update_ema_every
self.num_qs = num_qs
self.q_eta = q_eta
self.critic = ParallelizedEnsembleFlattenMLP(ensemble_size=num_qs,
hidden_sizes=[256] * num_q_layers,
input_size=state_dim + action_dim,
output_size=1,
layer_norm=None,
).to(device)
self.critic_target = copy.deepcopy(self.critic)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)
self.state_dim = state_dim
self.max_action = max_action
self.action_dim = action_dim
self.discount = discount
self.tau = tau
self.eta = eta # q_learning weight
self.device = device
self.max_q_backup = max_q_backup
def step_ema(self):
if self.step < self.step_start_ema:
return
        # ema_model is a copy of the actor, so average against the actor's parameters.
        self.ema.update_model_average(self.ema_model, self.actor)
def train(self, replay_buffer, iterations, batch_size=100):
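        """
        Run `iterations` gradient steps on batches drawn from `replay_buffer`.

        Each step updates the ensemble critic (TD targets from the EMA actor and
        target critic, plus the optional gradient-diversity penalty), then the
        diffusion actor (BC loss plus a normalized Q-maximization term), and
        finally the EMA and target networks.
        """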
for step in range(iterations):
# Sample replay buffer / batch
state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
""" Q Training """
current_qs = self.critic(state, action)
if not self.max_q_backup:
next_action = self.ema_model(next_state)
target_q = self.critic_target.sample(next_state, next_action)
else:
next_state_rpt = torch.repeat_interleave(next_state, repeats=10, dim=0)
next_action_rpt = self.ema_model(next_state_rpt)
target_q = self.critic_target.sample(next_state_rpt, next_action_rpt)
target_q = target_q.view(batch_size, 10).max(dim=1, keepdim=True)[0]
target_q = (reward + not_done * self.discount * target_q).detach().unsqueeze(0)
critic_loss = F.mse_loss(current_qs, target_q, reduction='none')
critic_loss = critic_loss.mean(dim=(1, 2)).sum()
if self.q_eta > 0:
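                # Diversity penalty on the ensemble: normalize each member's gradient of Q
                # w.r.t. the action, form pairwise cosine similarities across members, and
                # penalize the off-diagonal terms so the Q functions disagree locally.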
state_tile = state.unsqueeze(0).repeat(self.num_qs, 1, 1)
action_tile = action.unsqueeze(0).repeat(self.num_qs, 1, 1).requires_grad_(True)
qs_preds_tile = self.critic(state_tile, action_tile)
qs_pred_grads, = torch.autograd.grad(qs_preds_tile.sum(), action_tile, retain_graph=True,
create_graph=True)
qs_pred_grads = qs_pred_grads / (torch.norm(qs_pred_grads, p=2, dim=2).unsqueeze(-1) + 1e-10)
qs_pred_grads = qs_pred_grads.transpose(0, 1)
qs_pred_grads = torch.einsum('bik,bjk->bij', qs_pred_grads, qs_pred_grads)
masks = torch.eye(self.num_qs, device=self.device).unsqueeze(dim=0).repeat(qs_pred_grads.size(0), 1, 1)
qs_pred_grads = (1 - masks) * qs_pred_grads
grad_loss = torch.mean(torch.sum(qs_pred_grads, dim=(1, 2))) / (self.num_qs - 1)
critic_loss += self.q_eta * grad_loss
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
""" Policy Training """
bc_loss = self.actor.loss(action, state)
new_action = self.actor(state)
q_new_action = self.critic.sample(state, new_action)
lmbda = self.eta / q_new_action.abs().mean().detach()
q_loss = - lmbda * q_new_action.mean()
self.actor_optimizer.zero_grad()
bc_loss.backward()
q_loss.backward()
self.actor_optimizer.step()
if self.step % self.update_ema_every == 0:
self.step_ema()
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
self.step += 1
        # Logging
        logger.record_tabular('BC Loss', bc_loss.item())
        logger.record_tabular('QL Loss', q_loss.item())
        logger.record_tabular('Critic Loss', critic_loss.item())
        if self.q_eta > 0:
            logger.record_tabular('ED Loss', grad_loss.item())
        logger.record_tabular('Target_Q Mean', target_q.mean().item())
def sample_action(self, state):
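        """
        Sample 50 candidate actions from the diffusion policy for a single state,
        score them with the target critic, and return one drawn from a softmax
        over the Q-values.
        """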
state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
state_rpt = torch.repeat_interleave(state, repeats=50, dim=0)
with torch.no_grad():
action = self.actor.sample(state_rpt)
q_value = self.critic_target.sample(state_rpt, action).flatten()
            idx = torch.multinomial(F.softmax(q_value, dim=-1), 1)
return action[idx].cpu().data.numpy().flatten()
def save_model(self, dir):
torch.save(self.actor.state_dict(), f'{dir}/actor.pth')
torch.save(self.critic.state_dict(), f'{dir}/critic.pth')
def load_model(self, dir):
self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
self.critic.load_state_dict(torch.load(f'{dir}/critic.pth'))