# agents/ed_pcq.py
def __init__(self,
             state_dim,
             action_dim,
             max_action,
             device,
             discount,
             tau,
             max_q_backup=False,
             eta=0.1,
             model_type='MLP',
             beta_schedule='linear',
             n_timesteps=100,
             ema_decay=0.995,
             step_start_ema=1000,
             update_ema_every=5,
             lr=3e-4,
             num_qs=50,
             num_q_layers=3,
             q_eta=1.0,
             ):
    # Noise-prediction MLP and the diffusion policy (actor) built on top of it.
    self.model = MLP(state_dim=state_dim, action_dim=action_dim, device=device)
    self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, model=self.model,
                           max_action=max_action, beta_schedule=beta_schedule,
                           n_timesteps=n_timesteps).to(device)
    self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

    # Exponential-moving-average copy of the actor and its update schedule.
    self.step = 0
    self.step_start_ema = step_start_ema
    self.ema = EMA(ema_decay)
    self.ema_model = copy.deepcopy(self.actor)
    self.update_ema_every = update_ema_every

    # Critic: ensemble of `num_qs` Q-networks, each an MLP with `num_q_layers`
    # hidden layers of width 256, plus a target copy for soft updates.
    self.num_qs = num_qs
    self.q_eta = q_eta
    self.critic = ParallelizedEnsembleFlattenMLP(ensemble_size=num_qs,
                                                 hidden_sizes=[256] * num_q_layers,
                                                 input_size=state_dim + action_dim,
                                                 output_size=1,
                                                 layer_norm=None).to(device)
    self.critic_target = copy.deepcopy(self.critic)
    self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

    self.state_dim = state_dim
    self.max_action = max_action
    self.action_dim = action_dim
    self.discount = discount
    self.tau = tau
    self.eta = eta  # weight of the Q-learning term in the actor loss
    self.device = device
    self.max_q_backup = max_q_backup
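
# Usage sketch (assumptions: the name of the enclosing class is not visible in
# this excerpt, so `EDPCQ` below is a hypothetical stand-in, and the state/action
# dimensions are illustrative only). Only the environment dimensions, device,
# and the discount/tau hyperparameters are required; everything else defaults.
#
#   import torch
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   agent = EDPCQ(state_dim=17, action_dim=6, max_action=1.0,
#                 device=device, discount=0.99, tau=0.005,
#                 eta=0.1, num_qs=50, q_eta=1.0)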