in agents/ed_pcq.py [0:0]
def forward(self, *inputs, **kwargs):
    # concatenate the inputs (e.g. obs, action) along the feature axis
    flat_inputs = torch.cat(inputs, dim=-1)
    state_dim = inputs[0].shape[-1]  # per-state feature size (not used below)

    dim = len(flat_inputs.shape)
    # Repeat the input along a leading ensemble axis so all members run in parallel.
    # If dim == 3, the ensemble axis was likely already added upstream
    # (e.g. bootstrapping in training optimization).
    if dim < 3:
        flat_inputs = flat_inputs.unsqueeze(0)
        if dim == 1:
            flat_inputs = flat_inputs.unsqueeze(0)
        flat_inputs = flat_inputs.repeat(self.ensemble_size, 1, 1)

    # input normalization
    h = flat_inputs

    # standard feedforward network
    for fc in self.fcs:
        h = fc(h)
        h = self.hidden_activation(h)
        if hasattr(self, 'layer_norm') and (self.layer_norm is not None):
            h = self.layer_norm(h)
    preactivation = self.last_fc(h)
    output = self.output_activation(preactivation)

    # if the original input was 1D, squeeze out the singleton batch axis
    if dim == 1:
        output = output.squeeze(1)

    # output is (ensemble_size, batch_size, output_size),
    # or (ensemble_size, output_size) when the original input was 1D
    return output
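
The forward pass above assumes that self.fcs and self.last_fc are ensemble-aware layers that keep one weight matrix per member and consume the leading (ensemble_size, batch, features) axis. The sketch below is a minimal, hypothetical illustration of that assumption (EnsembleLinear and all shapes here are illustrative, not the repository's actual classes); it shows how a flat (batch, obs_dim + act_dim) input is broadcast across members and mapped to (ensemble_size, batch, 1) Q-values.

import torch
import torch.nn as nn


class EnsembleLinear(nn.Module):
    """One weight matrix per ensemble member, applied in parallel via bmm."""

    def __init__(self, ensemble_size, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(ensemble_size, in_features, out_features) * 0.01
        )
        self.bias = nn.Parameter(torch.zeros(ensemble_size, 1, out_features))

    def forward(self, x):
        # x: (ensemble_size, batch, in_features) -> (ensemble_size, batch, out_features)
        return torch.bmm(x, self.weight) + self.bias


# Hypothetical usage mirroring the shape handling in forward() above.
ensemble_size, batch, obs_dim, act_dim = 5, 32, 17, 6
fc = EnsembleLinear(ensemble_size, obs_dim + act_dim, 1)

obs = torch.randn(batch, obs_dim)
act = torch.randn(batch, act_dim)
flat = torch.cat([obs, act], dim=-1)                   # (batch, obs_dim + act_dim), dim == 2
flat = flat.unsqueeze(0).repeat(ensemble_size, 1, 1)   # (ensemble_size, batch, obs_dim + act_dim)
q_values = fc(flat)                                    # (ensemble_size, batch, 1)
print(q_values.shape)                                  # torch.Size([5, 32, 1])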