in agents/qgdp.py [0:0]
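# Note: this method relies on module-level imports and helpers defined elsewhere
# in agents/qgdp.py that are not shown in this excerpt. The exact import paths
# below are assumptions for illustration, e.g.:
#   import torch
#   import torch.nn.functional as F
#   from utils import logger                        # hypothetical path
#   from agents.helpers import expectile_reg_loss   # hypothetical path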
def train(self, replay_buffer, iterations, batch_size=100):
    for it in range(iterations):
        # Sample a batch of transitions from the replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        """ Train Diffusion BC Policy """
        # Behaviour-cloning loss of the diffusion actor on the sampled actions
        loss = self.actor.loss(action, state)
        self.actor_optimizer.zero_grad()
        loss.backward()
        self.actor_optimizer.step()

        """ Train Q Function """
        # Value function training: expectile regression of V(s) towards the
        # clipped double-Q estimate from the target critic
        with torch.no_grad():
            q1, q2 = self.critic_target(state, action)
            q = torch.min(q1, q2)  # Clipped Double Q-learning
        v = self.value_fun(state)
        value_loss = expectile_reg_loss(q - v, self.quantile).mean()
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

        # Critic training: regress both Q heads onto the one-step TD target
        # built from the (detached) value function at the next state
        current_q1, current_q2 = self.critic(state, action)
        target_q = (reward + not_done * self.discount * self.value_fun(next_state)).detach()
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Soft (Polyak) update of the target critic
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        # Logging
        logger.record_tabular('Diffusion BC Loss', loss.item())
        logger.record_tabular('Value Fun Loss', value_loss.item())
        logger.record_tabular('Critic Fun Loss', critic_loss.item())
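
# `expectile_reg_loss` is referenced above but not defined in this excerpt.
# A minimal sketch of an asymmetric-L2 expectile regression loss, as used in
# IQL-style value learning, might look like the following (an assumed
# implementation for illustration, not necessarily the one in this repo):
import torch  # may already be imported at module level


def expectile_reg_loss(diff, quantile=0.7):
    # Weight residuals asymmetrically: |quantile - 1(diff < 0)|, then square.
    # With diff = q - v, overestimates of V are penalised more than
    # underestimates when quantile > 0.5; quantile = 0.5 recovers plain L2.
    weight = torch.abs(quantile - (diff < 0).float())
    return weight * diff.pow(2)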