in agents/bc_w.py [0:0]
def sample(self, state, reparameterize=False, deterministic=False):
    """Produce an action for ``state`` from the policy heads.

    The shared trunk ``self.base_fc`` feeds two heads. ``self.last_fc_mean``
    gives the mean. ``self.last_fc_log_std`` gives a log-std, which is clamped
    to ``[LOG_SIG_MIN, LOG_SIG_MAX]`` and then exponentiated.

    Args:
        state: Input passed to ``self.base_fc`` (presumably a batch of
            observation tensors -- TODO confirm shape against callers).
        reparameterize: When sampling stochastically, draw with
            ``TanhNormal.rsample()`` instead of ``TanhNormal.sample()``
            (presumably so gradients flow through the draw, following the
            torch.distributions convention -- confirm for TanhNormal).
        deterministic: If true, skip sampling and return
            ``tanh(mean) * self.max_action``.

    Returns:
        The chosen action multiplied by ``self.max_action``.
    """
    hidden = self.base_fc(state)
    mean = self.last_fc_mean(hidden)
    # Clamp the log-std before exponentiating so std stays within fixed bounds.
    # It is computed even on the deterministic path, as in the original.
    log_std = self.last_fc_log_std(hidden).clamp(LOG_SIG_MIN, LOG_SIG_MAX)
    std = log_std.exp()

    if deterministic:
        # No noise: squash the mean with tanh and scale it.
        return torch.tanh(mean) * self.max_action

    dist = TanhNormal(mean, std, self.device)
    # Choose the draw method once, then call it.
    draw = dist.rsample if reparameterize else dist.sample
    return draw() * self.max_action