in agents/bc_kl.py [0:0]
def log_pis(self, state, action=None, raw_action=None):
    """Return log-probabilities of an action under the tanh-squashed Gaussian policy.

    The network maps ``state`` through two ReLU layers (``l1``, ``l2``) to the
    mean and log-std of a diagonal Gaussian over the *pre-squash* action ``u``.
    The density of the squashed action ``tanh(u)`` follows from the
    change-of-variables formula::

        log pi(a|s) = sum_i [log N(u_i; mu_i, sigma_i) - log(1 - tanh(u_i)^2)]

    Args:
        state: Input fed to ``self.l1``.
        action: Squashed action, expected in (-1, 1); mapped back to the
            pre-squash space with ``atanh``. Ignored when ``raw_action`` is given.
        raw_action: Pre-squash action; takes precedence over ``action``.

    Returns:
        Log-probabilities summed over the last (action) dimension.

    Raises:
        ValueError: If neither ``action`` nor ``raw_action`` is provided.
    """
    import math

    if raw_action is None:
        if action is None:
            raise ValueError("log_pis requires either `action` or `raw_action`.")
        # `atanh` is a helper defined elsewhere in this module. NOTE(review):
        # presumably it clips |action| below 1 so the result stays finite — verify.
        raw_action = atanh(action)

    hidden = F.relu(self.l1(state))
    hidden = F.relu(self.l2(hidden))
    mean_a = self.mean(hidden)
    # NOTE(review): log_std is not clamped here. A very negative value underflows
    # std to 0, which validate_args=True rejects — confirm the head is bounded.
    std_a = torch.exp(self.log_std(hidden))
    normal_dist = td.Normal(loc=mean_a, scale=std_a, validate_args=True)
    log_normal = normal_dist.log_prob(raw_action).sum(-1)

    # Tanh Jacobian via the exact identity
    #   log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)),
    # evaluated on the pre-squash value. This avoids the cancellation in
    # `1 - action**2` as |action| -> 1. The log(1e-6) floor preserves the
    # previous clamp, so saturated or infinite inputs give a bounded
    # correction instead of -inf / NaN.
    log_det = 2.0 * (math.log(2.0) - raw_action - F.softplus(-2.0 * raw_action))
    log_det = log_det.clamp(min=math.log(1e-6))
    return log_normal - log_det.sum(-1)