# agents/bc_mle.py
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.logger import logger
from torch.distributions import Distribution, Normal
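
# Clamp range for the log standard deviation head; exponentiating values in
# [LOG_SIG_MIN, LOG_SIG_MAX] keeps the policy's std in a numerically stable range.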
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20


class GaussianPolicy(nn.Module):
    """
    Gaussian policy: a diagonal Gaussian over actions whose mean and log-std
    are produced by a shared MLP trunk with two separate output heads.
    """
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 hidden_sizes=[256, 256],
                 layer_norm=False):
        super(GaussianPolicy, self).__init__()
        self.layer_norm = layer_norm

        # Shared trunk: Linear -> (optional LayerNorm) -> ReLU for each hidden layer.
        self.base_fc = []
        last_size = state_dim
        for next_size in hidden_sizes:
            self.base_fc += [
                nn.Linear(last_size, next_size),
                nn.LayerNorm(next_size) if layer_norm else nn.Identity(),
                nn.ReLU(inplace=True),
            ]
            last_size = next_size
        self.base_fc = nn.Sequential(*self.base_fc)

        # Separate linear heads for the action mean and log standard deviation.
        last_hidden_size = hidden_sizes[-1]
        self.last_fc_mean = nn.Linear(last_hidden_size, action_dim)
        self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim)
        self.device = device

    def forward(self, state):
        h = self.base_fc(state)
        mean = self.last_fc_mean(h)
        # Clamp the log-std before exponentiating to keep std in a stable range.
        std = self.last_fc_log_std(h).clamp(LOG_SIG_MIN, LOG_SIG_MAX).exp()
        # Normal takes (loc, scale); the distribution lives on whatever device
        # mean/std are on, so no device argument is passed.
        a_normal = Normal(mean, std)
        action = a_normal.rsample()
        log_prob = a_normal.log_prob(action)
        log_prob = log_prob.sum(dim=1, keepdim=True)
        return action, log_prob

    def log_prob(self, state, action):
        # Log-likelihood of the given actions under the current policy,
        # summed over action dimensions.
        h = self.base_fc(state)
        mean = self.last_fc_mean(h)
        std = self.last_fc_log_std(h).clamp(LOG_SIG_MIN, LOG_SIG_MAX).exp()
        a_normal = Normal(mean, std)
        log_prob = a_normal.log_prob(action)
        log_prob = log_prob.sum(dim=1, keepdim=True)
        return log_prob

    def sample(self,
               state,
               reparameterize=False,
               deterministic=False):
        h = self.base_fc(state)
        mean = self.last_fc_mean(h)
        std = self.last_fc_log_std(h).clamp(LOG_SIG_MIN, LOG_SIG_MAX).exp()
        if deterministic:
            # Greedy action: just the mean (used for evaluation).
            action = mean
        else:
            a_normal = Normal(mean, std)
            if reparameterize:
                # rsample() keeps gradients flowing through mean and std.
                action = a_normal.rsample()
            else:
                action = a_normal.sample()
        return action
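
# Illustrative usage of GaussianPolicy (shapes assumed, not from the original
# file): for a batch of B states,
#     policy = GaussianPolicy(state_dim, action_dim, max_action, device)
#     action, log_pi = policy(states)              # (B, action_dim), (B, 1)
#     log_pi = policy.log_prob(states, actions)    # (B, 1)
#     greedy = policy.sample(states, deterministic=True)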


class BC_MLE(object):
    """
    Behavior cloning via maximum-likelihood estimation: fit the Gaussian
    policy to dataset actions by maximizing their log-likelihood.
    """
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 lr=3e-4,
                 ):
        self.actor = GaussianPolicy(state_dim, action_dim, max_action, device).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.device = device

    def sample_action(self, state):
        # Deterministic (mean) action for a single state, returned as a flat NumPy array.
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
            action = self.actor.sample(state, deterministic=True)
        return action.cpu().data.numpy().flatten()

    def train(self, replay_buffer, iterations, batch_size=100):
        for it in range(iterations):
            # Sample a batch of transitions from the replay buffer.
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            # Actor training: maximize the log-likelihood of the dataset actions,
            # i.e. minimize the negative log-likelihood.
            log_pi = self.actor.log_prob(state, action)
            actor_loss = -log_pi.mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

        logger.record_tabular('Actor Loss', actor_loss.cpu().data.numpy())

    def save_model(self, dir):
        torch.save(self.actor.state_dict(), f'{dir}/actor.pth')

    def load_model(self, dir):
        self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
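

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original pipeline: it fits
    # BC_MLE to random transitions just to exercise the API. _DummyReplayBuffer
    # is a hypothetical stand-in for the project's real replay buffer (only the
    # sample() interface used by train() is assumed), and the run assumes
    # utils.logger.record_tabular works without extra configuration.
    class _DummyReplayBuffer(object):
        def __init__(self, state_dim, action_dim, size, device):
            self.size = size
            self.state = torch.randn(size, state_dim, device=device)
            self.action = torch.rand(size, action_dim, device=device) * 2 - 1
            self.next_state = torch.randn(size, state_dim, device=device)
            self.reward = torch.randn(size, 1, device=device)
            self.not_done = torch.ones(size, 1, device=device)

        def sample(self, batch_size):
            idx = torch.randint(0, self.size, (batch_size,))
            return (self.state[idx], self.action[idx], self.next_state[idx],
                    self.reward[idx], self.not_done[idx])

    device = 'cpu'
    state_dim, action_dim = 17, 6
    agent = BC_MLE(state_dim, action_dim, max_action=1.0, device=device,
                   discount=0.99, tau=0.005)
    buffer = _DummyReplayBuffer(state_dim, action_dim, size=1000, device=device)
    agent.train(buffer, iterations=10, batch_size=64)
    print('sampled action:', agent.sample_action(np.zeros(state_dim)))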