in agents/bc_gan.py [0:0]
def sample_multiple_actions(self, state, num_action=10, std=-1., noise_clip=0.05):
    """Tile (and optionally perturb) a batch of states, then query the policy.

    Args:
        state: (B, state_dim) batch of states; a ``np.ndarray`` is converted
            to a float32 tensor.
        num_action: number of actions to sample from the policy per state.
        std: std-dev scale of the Gaussian state perturbation. If
            ``std <= 0`` the states are tiled with no noise.
        noise_clip: per-element clamp applied to the scaled noise
            (was hard-coded to 0.05; default keeps the old behavior).

    Returns:
        Tuple ``(state, actions)`` where ``actions = self.forward(state)``.
        Row counts differ by branch:
          * ``std <= 0``: ``B * num_action`` rows (pure tiling);
          * ``std > 0, num_action == 1``: ``B`` rows (one noisy copy each);
          * ``std > 0, num_action > 1``: ``B * (num_action + 1)`` rows,
            because the clean state is appended after the noisy copies.
    """
    if isinstance(state, np.ndarray):
        state = torch.FloatTensor(state)
    batch_size = state.shape[0]
    if std <= 0:
        # e.g., num_action = 3: [s1;s2] -> [s1;s1;s1;s2;s2;s2]
        state = state.unsqueeze(1).repeat(1, num_action, 1).view(-1, state.size(-1)).to(self.device)
    elif num_action == 1:
        # One noisy copy per state: B x state_dim.
        noises = torch.normal(torch.zeros_like(state), torch.ones_like(state))
        state = (state + (std * noises).clamp(-noise_clip, noise_clip)).to(self.device)
    else:  # std > 0 and num_action > 1
        # num_action noisy copies plus the clean state appended last,
        # flattened to B * (num_action + 1) rows.
        state_noise = state.unsqueeze(1).repeat(1, num_action, 1)  # B x num_action x state_dim
        noises = torch.normal(torch.zeros_like(state_noise), torch.ones_like(state_noise))
        state_noise = state_noise + (std * noises).clamp(-noise_clip, noise_clip)
        state = torch.cat((state_noise, state.unsqueeze(1)), dim=1).view(batch_size * (num_action + 1), -1).to(self.device)
    # Actions are row-aligned with the tiled states:
    # [a11;a12;a13;a21;a22;a23] for [s1;s1;s1;s2;s2;s2]
    return state, self.forward(state)