agents/ql_diffusion.py:

# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from agents.diffusion import Diffusion
from agents.model import MLP
from agents.helpers import EMA


class Critic(nn.Module):
    """Twin Q-networks used for clipped double Q-learning."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()
        self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, 1))

        self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, 1))

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x), self.q2_model(x)

    def q1(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x)

    def q_min(self, state, action):
        q1, q2 = self.forward(state, action)
        return torch.min(q1, q2)


class PCQ(object):
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 max_q_backup=False,
                 eta=0.1,
                 model_type='MLP',
                 beta_schedule='linear',
                 n_timesteps=100,
                 ema_decay=0.995,
                 step_start_ema=1000,
                 update_ema_every=5,
                 lr=3e-4,
                 lr_decay=False,
                 lr_maxt=int(1e6),
                 mode='whole_grad',
                 ):

        self.model = MLP(state_dim=state_dim, action_dim=action_dim, device=device)

        self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, model=self.model, max_action=max_action,
                               beta_schedule=beta_schedule, n_timesteps=n_timesteps,
                               ).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.lr_decay = lr_decay
        if lr_decay:
            self.actor_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.actor_optimizer, T_max=lr_maxt)

        # EMA copy of the diffusion actor, used to compute Q-learning targets.
        self.step = 0
        self.step_start_ema = step_start_ema
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.actor)
        self.update_ema_every = update_ema_every

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.state_dim = state_dim
        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.eta = eta  # q_learning weight
        self.device = device
        self.max_q_backup = max_q_backup
        self.mode = mode

    def step_ema(self):
        if self.step < self.step_start_ema:
            return
        self.ema.update_model_average(self.ema_model, self.actor)

    def train(self, replay_buffer, iterations, batch_size=100):
        for step in range(iterations):
            # Sample replay buffer / batch
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            """ Q Training """
            current_q1, current_q2 = self.critic(state, action)

            if not self.max_q_backup:
                next_action = self.ema_model(next_state)
                target_q1, target_q2 = self.critic_target(next_state, next_action)
                target_q = torch.min(target_q1, target_q2)
            else:
                # Max Q backup: sample 10 next actions per next state, take the per-state max,
                # then the minimum over the two target heads.
                next_state_rpt = torch.repeat_interleave(next_state, repeats=10, dim=0)
                next_action_rpt = self.ema_model(next_state_rpt)
                target_q1, target_q2 = self.critic_target(next_state_rpt, next_action_rpt)
                target_q1 = target_q1.view(batch_size, 10).max(dim=1, keepdim=True)[0]
                target_q2 = target_q2.view(batch_size, 10).max(dim=1, keepdim=True)[0]
                target_q = torch.min(target_q1, target_q2)

            target_q = (reward + not_done * self.discount * target_q).detach()

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            """ Policy Training """
            # Diffusion behavior-cloning loss plus a Q-guided improvement term; gradients
            # from both backward calls accumulate before the single optimizer step below.
            bc_loss = self.actor.loss(action, state)
            if self.mode == 'whole_grad':
                new_action = self.actor(state)
            elif self.mode == 'last_few':
                new_action = self.actor.sample_last_few(state)

            q1_new_action, q2_new_action = self.critic(state, new_action)
            # Randomly pick which Q-head supplies the loss and which normalizes its scale.
            if np.random.uniform() > 0.5:
                lmbda = self.eta / q2_new_action.abs().mean().detach()
                q_loss = - lmbda * q1_new_action.mean()
            else:
                lmbda = self.eta / q1_new_action.abs().mean().detach()
                q_loss = - lmbda * q2_new_action.mean()
            # q_new_action = self.critic.q_min(state, new_action)
            # lmbda = self.eta / q_new_action.abs().mean().detach()
            # q_loss = - lmbda * q_new_action.mean()

            self.actor_optimizer.zero_grad()
            bc_loss.backward()
            q_loss.backward()
            self.actor_optimizer.step()
            self.actor.step_frozen()

            if self.step % self.update_ema_every == 0:
                self.step_ema()

            # Soft update of the target critic.
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.step += 1

        if self.lr_decay:
            self.actor_lr_scheduler.step()

    def sample_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        state_rpt = torch.repeat_interleave(state, repeats=50, dim=0)
        with torch.no_grad():
            # Draw 50 candidate actions and pick one with probability softmax(Q).
            action = self.actor.sample(state_rpt)
            q_value = self.critic_target.q_min(state_rpt, action).flatten()
            idx = torch.multinomial(F.softmax(q_value, dim=-1), 1)
        return action[idx].cpu().data.numpy().flatten()

    def save_model(self, dir):
        torch.save(self.actor.state_dict(), f'{dir}/actor.pth')
        torch.save(self.critic.state_dict(), f'{dir}/critic.pth')

    def load_model(self, dir):
        self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
        self.critic.load_state_dict(torch.load(f'{dir}/critic.pth'))
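
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original file): shows how PCQ can
# be wired up for a quick smoke test. The `_ToyBuffer` class below is an
# assumption made purely for illustration; the real project supplies its own
# replay buffer exposing the same .sample(batch_size) interface that returns
# (state, action, next_state, reward, not_done) tensors on the agent's device.
# Dimensions (state_dim=17, action_dim=6) are arbitrary placeholders.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class _ToyBuffer:
        """Returns random transitions shaped like the tuples train() expects."""

        def __init__(self, state_dim, action_dim, device):
            self.state_dim, self.action_dim, self.device = state_dim, action_dim, device

        def sample(self, batch_size):
            s = torch.randn(batch_size, self.state_dim, device=self.device)
            a = torch.rand(batch_size, self.action_dim, device=self.device) * 2.0 - 1.0
            ns = torch.randn(batch_size, self.state_dim, device=self.device)
            r = torch.randn(batch_size, 1, device=self.device)
            nd = torch.ones(batch_size, 1, device=self.device)
            return s, a, ns, r, nd

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    agent = PCQ(state_dim=17, action_dim=6, max_action=1.0, device=device,
                discount=0.99, tau=0.005, eta=1.0, n_timesteps=100)
    agent.train(_ToyBuffer(17, 6, device), iterations=10, batch_size=32)
    print(agent.sample_action(np.zeros(17, dtype=np.float32)))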