agents/ql_diffusion.py
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from agents.diffusion import Diffusion
from agents.model import MLP
from agents.helpers import EMA
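
# Critic: twin Q-networks (as in clipped double Q-learning). Each head is a
# 3-hidden-layer MLP with Mish activations over the concatenated (state, action)
# vector; q_min() returns the elementwise minimum of the two Q estimates.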
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()
        self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, 1))
        self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, 1))

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x), self.q2_model(x)

    def q1(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x)

    def q_min(self, state, action):
        q1, q2 = self.forward(state, action)
        return torch.min(q1, q2)
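
# PCQ: an offline RL agent that couples a diffusion-model actor with the twin
# critic above. The actor is trained on a behavior-cloning (diffusion denoising)
# loss plus an eta-weighted Q-maximization term; an exponential-moving-average
# (EMA) copy of the actor generates the actions used in the TD target.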
class PCQ(object):
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 max_q_backup=False,
                 eta=0.1,
                 model_type='MLP',
                 beta_schedule='linear',
                 n_timesteps=100,
                 ema_decay=0.995,
                 step_start_ema=1000,
                 update_ema_every=5,
                 lr=3e-4,
                 lr_decay=False,
                 lr_maxt=int(1e6),
                 mode='whole_grad',
                 ):
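        # Diffusion actor: an MLP noise-prediction network wrapped by the
        # Diffusion module (agents.diffusion), which provides sampling and the
        # denoising (behavior-cloning) loss; max_action bounds the sampled actions.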
        self.model = MLP(state_dim=state_dim, action_dim=action_dim, device=device)
        self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, model=self.model, max_action=max_action,
                               beta_schedule=beta_schedule, n_timesteps=n_timesteps,
                               ).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.lr_decay = lr_decay
        if lr_decay:
            self.actor_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.actor_optimizer, T_max=lr_maxt)

        self.step = 0
        self.step_start_ema = step_start_ema
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.actor)
        self.update_ema_every = update_ema_every

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.state_dim = state_dim
        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.eta = eta  # q_learning weight
        self.device = device
        self.max_q_backup = max_q_backup
        self.mode = mode
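
    # Update the EMA copy of the actor; skipped during the first
    # `step_start_ema` gradient steps (warm-up).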
    def step_ema(self):
        if self.step < self.step_start_ema:
            return
        self.ema.update_model_average(self.ema_model, self.actor)
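
    # One call to train() runs `iterations` gradient steps, each consisting of a
    # critic update toward a TD target followed by an actor update on the
    # diffusion BC loss plus the Q-guidance term.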
    def train(self, replay_buffer, iterations, batch_size=100):
        for step in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            """ Q Training """
            current_q1, current_q2 = self.critic(state, action)

            if not self.max_q_backup:
                next_action = self.ema_model(next_state)
                target_q1, target_q2 = self.critic_target(next_state, next_action)
                target_q = torch.min(target_q1, target_q2)
            else:
                next_state_rpt = torch.repeat_interleave(next_state, repeats=10, dim=0)
                next_action_rpt = self.ema_model(next_state_rpt)
                target_q1, target_q2 = self.critic_target(next_state_rpt, next_action_rpt)
                target_q1 = target_q1.view(batch_size, 10).max(dim=1, keepdim=True)[0]
                target_q2 = target_q2.view(batch_size, 10).max(dim=1, keepdim=True)[0]
                target_q = torch.min(target_q1, target_q2)

            target_q = (reward + not_done * self.discount * target_q).detach()
            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            """ Policy Training """
            bc_loss = self.actor.loss(action, state)
            if self.mode == 'whole_grad':
                new_action = self.actor(state)
            elif self.mode == 'last_few':
                new_action = self.actor.sample_last_few(state)

            q1_new_action, q2_new_action = self.critic(state, new_action)
            if np.random.uniform() > 0.5:
                lmbda = self.eta / q2_new_action.abs().mean().detach()
                q_loss = - lmbda * q1_new_action.mean()
            else:
                lmbda = self.eta / q1_new_action.abs().mean().detach()
                q_loss = - lmbda * q2_new_action.mean()
            # q_new_action = self.critic.q_min(state, new_action)
            # lmbda = self.eta / q_new_action.abs().mean().detach()
            # q_loss = - lmbda * q_new_action.mean()

            self.actor_optimizer.zero_grad()
            bc_loss.backward()
            q_loss.backward()
            self.actor_optimizer.step()
            self.actor.step_frozen()
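
            # Bookkeeping: periodically refresh the EMA actor and softly update the
            # target critic via Polyak averaging with coefficient tau.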
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.step += 1

        if self.lr_decay: self.actor_lr_scheduler.step()
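
    # Action selection at evaluation time: draw 50 candidate actions from the
    # diffusion actor for the (single) input state and sample one of them with
    # probability proportional to softmax of the target critic's min-Q values.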
    def sample_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        state_rpt = torch.repeat_interleave(state, repeats=50, dim=0)
        with torch.no_grad():
            action = self.actor.sample(state_rpt)
            q_value = self.critic_target.q_min(state_rpt, action).flatten()
            idx = torch.multinomial(F.softmax(q_value, dim=-1), 1)
        return action[idx].cpu().data.numpy().flatten()

    def save_model(self, dir):
        torch.save(self.actor.state_dict(), f'{dir}/actor.pth')
        torch.save(self.critic.state_dict(), f'{dir}/critic.pth')

    def load_model(self, dir):
        self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
        self.critic.load_state_dict(torch.load(f'{dir}/critic.pth'))
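
# Usage sketch: the replay-buffer object, environment dimensions, and loop bounds
# below are hypothetical; the only interface assumed is the
# replay_buffer.sample(batch_size) call already used by PCQ.train() above.
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     agent = PCQ(state_dim=17, action_dim=6, max_action=1.0, device=device,
#                 discount=0.99, tau=0.005, eta=0.1, n_timesteps=100)
#     for epoch in range(num_epochs):
#         agent.train(replay_buffer, iterations=1000, batch_size=256)
#         action = agent.sample_action(eval_state)  # 1-D numpy state -> 1-D numpy action
#     agent.save_model('checkpoints')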