agents/bc_diffusion.py (53 lines of code) (raw):

# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.logger import logger

from agents.diffusion import Diffusion
from agents.model import MLP_Unet, MLP, Tanh_MLP


class BC(object):
    """Behavior cloning with a diffusion-model actor: the policy is trained to
    reproduce the dataset's actions conditioned on the observed states."""

    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 model_type='MLP',
                 beta_schedule='linear',
                 n_timesteps=100,
                 lr=2e-4,
                 ):

        # Noise-prediction network used inside the diffusion model.
        if model_type == 'MLP':
            self.model = MLP(state_dim=state_dim, action_dim=action_dim, device=device)
        elif model_type == 'MLP_Unet':
            self.model = MLP_Unet(state_dim=state_dim, action_dim=action_dim, device=device)
        elif model_type == 'Tanh_MLP':
            self.model = Tanh_MLP(state_dim=state_dim, action_dim=action_dim,
                                  max_action=max_action, device=device)
        else:
            raise ValueError(f'Unknown model_type: {model_type}')

        # Diffusion policy: samples actions by iteratively denoising Gaussian noise,
        # conditioned on the state.
        self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, model=self.model,
                               max_action=max_action, beta_schedule=beta_schedule,
                               n_timesteps=n_timesteps,
                               ).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount  # not used by the BC objective itself
        self.tau = tau            # not used by the BC objective itself
        self.device = device

    def train(self, replay_buffer, iterations, batch_size=100):
        for it in range(iterations):
            # Sample a batch from the replay buffer.
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            # Diffusion behavior-cloning loss: reconstruct dataset actions given states.
            loss = self.actor.loss(action, state)

            self.actor_optimizer.zero_grad()
            loss.backward()
            self.actor_optimizer.step()

            # Logging
            logger.record_tabular('Diffusion BC Loss', loss.item())

    def sample_action(self, state):
        # Denoise a single action for one state (batch of size 1).
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        with torch.no_grad():
            action = self.actor.sample(state)
        return action.cpu().data.numpy().flatten()

    def save_model(self, dir):
        torch.save(self.actor.state_dict(), f'{dir}/actor.pth')

    def load_model(self, dir):
        self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
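
Usage sketch (hypothetical, not part of the file above): the snippet below shows one way the BC agent might be constructed and trained, assuming the repo's agents and utils packages are importable. The ToyReplayBuffer class and the random data are invented for illustration; they only need to match the sample(batch_size) interface that BC.train() unpacks.

# --- Hypothetical usage sketch; not part of agents/bc_diffusion.py ---
import numpy as np
import torch

from agents.bc_diffusion import BC


class ToyReplayBuffer:
    """Minimal in-memory buffer with the interface BC.train() expects (hypothetical)."""

    def __init__(self, states, actions, device):
        self.states = torch.as_tensor(states, dtype=torch.float32, device=device)
        self.actions = torch.as_tensor(actions, dtype=torch.float32, device=device)
        self.device = device

    def sample(self, batch_size):
        idx = torch.randint(0, self.states.shape[0], (batch_size,), device=self.device)
        zeros = torch.zeros(batch_size, 1, device=self.device)
        # next_state, reward, not_done are unused by BC but returned for interface parity.
        return self.states[idx], self.actions[idx], self.states[idx], zeros, zeros


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    state_dim, action_dim, max_action = 17, 6, 1.0  # made-up dimensions

    agent = BC(state_dim=state_dim, action_dim=action_dim, max_action=max_action,
               device=device, discount=0.99, tau=0.005,
               model_type='MLP', beta_schedule='linear', n_timesteps=100, lr=2e-4)

    # Random offline "dataset" standing in for real (state, action) pairs.
    data_states = np.random.randn(1000, state_dim).astype(np.float32)
    data_actions = np.random.uniform(-max_action, max_action,
                                     size=(1000, action_dim)).astype(np.float32)
    buffer = ToyReplayBuffer(data_states, data_actions, device)

    agent.train(buffer, iterations=100, batch_size=100)

    action = agent.sample_action(np.zeros(state_dim, dtype=np.float32))
    print(action.shape)  # -> (action_dim,)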