in agents/bc_mmd.py [0:0]
def decode_multiple(self, state, z=None, num_decode=10):
    """Decode several action samples for each state in a batch.

    Args:
        state: 2-D tensor ``(batch, state_dim)``.
        z: optional latent tensor ``(batch, n, latent_dim)``. When given,
            the number of samples per state is taken from ``z.size(1)``, so
            callers do not have to repeat it through ``num_decode``.
        num_decode: number of latent samples drawn per state when ``z`` is
            None (default 10).

    Returns:
        Tuple ``(actions, raw)``. ``raw`` is the output of ``self.d3`` with
        shape ``(batch, n, d3_out)``. ``actions`` is
        ``self.max_action * tanh(raw)``.
    """
    if z is None:
        # Latents come from NumPy's global RNG (so np.random seeding still
        # controls them) and are clipped to [-0.5, 0.5].
        z = torch.as_tensor(
            np.random.normal(0, 1, size=(state.size(0), num_decode, self.latent_dim)),
            dtype=torch.float32,
            device=self.device,
        ).clamp(-0.5, 0.5)
    else:
        # Tile the state to match the supplied latents; otherwise any z whose
        # second dim differs from num_decode would fail in torch.cat.
        num_decode = z.size(1)

    # (batch, state_dim) -> (batch, num_decode, state_dim) so each state is
    # paired with every one of its latent samples.
    tiled_state = state.unsqueeze(1).expand(-1, num_decode, -1)
    a = F.relu(self.d1(torch.cat([tiled_state, z], 2)))
    a = F.relu(self.d2(a))
    # Run the output layer once so both returned tensors come from the same
    # forward pass (the original evaluated self.d3 twice).
    raw = self.d3(a)
    return self.max_action * torch.tanh(raw), raw