def sample_multiple_actions()

in agents/bc_gan.py [0:0]


    def sample_multiple_actions(self, state, num_action=10, std=-1.):
        """Expand every state into several policy queries and evaluate them.

        Args:
            state: Batch of states as an ``np.ndarray`` or a tensor; the code
                indexes it as ``(batch, state_dim)``.
            num_action: How many copies of each state to create.
            std: Scale of the Gaussian noise added to the copies.  Any value
                ``<= 0`` turns the noise off.

        Returns:
            A tuple ``(states, actions)``: ``states`` is the expanded batch on
            ``self.device`` and ``actions`` is ``self.forward(states)``.  Rows
            that come from the same input state are adjacent.  Rows produced
            per input state:

            * ``std <= 0``: ``num_action`` exact copies.
            * ``std > 0`` and ``num_action == 1``: a single noisy copy.
            * ``std > 0`` otherwise: ``num_action`` noisy copies followed by
              the clean state, i.e. ``num_action + 1`` rows.
        """
        if isinstance(state, np.ndarray):
            state = torch.FloatTensor(state)
        n_states = state.shape[0]

        def bounded_noise(reference):
            # Standard-normal draw shaped like `reference`, scaled by `std`
            # and clipped so no coordinate moves by more than 0.05.
            unit = torch.normal(torch.zeros_like(reference), torch.ones_like(reference))
            return torch.clamp(std * unit, -0.05, 0.05)

        if std <= 0:
            # Noise-free: with num_action = 3, [s1; s2] -> [s1; s1; s1; s2; s2; s2].
            tiled = state.unsqueeze(1).repeat(1, num_action, 1)
            queries = tiled.view(-1, state.size(-1))
        elif num_action == 1:
            # One perturbed query per state: B x state_dim.
            queries = state + bounded_noise(state)
        else:
            # B x num_action x state_dim block of perturbed copies ...
            tiled = state.unsqueeze(1).repeat(1, num_action, 1)
            perturbed = tiled + bounded_noise(tiled)
            # ... with the clean state appended as the last entry of each
            # group, giving (B * (num_action + 1)) x state_dim after flattening.
            grouped = torch.cat((perturbed, state.unsqueeze(1)), dim=1)
            queries = grouped.view((n_states * (num_action + 1)), -1)

        queries = queries.to(self.device)
        # Actions are row-aligned with the expanded states.
        return queries, self.forward(queries)