def forward()

in projects/home/recap/model/mask_net.py


  def forward(self, inputs: torch.Tensor):
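    """Apply the MaskBlocks to `inputs`, in parallel or stacked depending on config.

    Returns a dict with the final "output" and a "shared_layer" holding the
    MaskBlock outputs before the optional MLP head.
    """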
    if self.mask_net_config.use_parallel:
      mask_outputs = []
      for mask_layer in self._mask_blocks:
        mask_outputs.append(mask_layer(mask_input=inputs, net=inputs))
      # Share the outputs of the MaskBlocks.
      all_mask_outputs = torch.cat(mask_outputs, dim=1)
      output = (
        all_mask_outputs
        if self.mask_net_config.mlp is None
        else self._dense_layers(all_mask_outputs)["output"]
      )
      return {"output": output, "shared_layer": all_mask_outputs}
    else:
      net = inputs
      for mask_layer in self._mask_blocks:
        net = mask_layer(net=net, mask_input=inputs)
      # Share the output of the stacked MaskBlocks.
      # Call the MLP head (the original `self._dense_layers[net]` indexed the module
      # with a tensor, which raises a TypeError; the parallel branch above calls it).
      output = net if self.mask_net_config.mlp is None else self._dense_layers(net)["output"]
      return {"output": output, "shared_layer": net}
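For reference, here is a minimal, self-contained sketch of the two dataflows above. `ToyMaskBlock` and its dimensions are hypothetical stand-ins, not the repo's `MaskBlock` or MLP config; they exist only to trace how tensor shapes differ between the parallel path (concatenated block outputs) and the stacked path (each block refining the previous one's output).

  import torch


  class ToyMaskBlock(torch.nn.Module):
    """Hypothetical stand-in: gates `net` with a mask computed from `mask_input`."""

    def __init__(self, dim: int):
      super().__init__()
      self._mask = torch.nn.Linear(dim, dim)
      self._proj = torch.nn.Linear(dim, dim)

    def forward(self, net: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
      return self._proj(net * torch.sigmoid(self._mask(mask_input)))


  dim, n_blocks = 8, 3
  blocks = torch.nn.ModuleList(ToyMaskBlock(dim) for _ in range(n_blocks))
  inputs = torch.randn(4, dim)  # batch of 4 examples

  # Parallel mode: every block sees the raw inputs; outputs are concatenated,
  # so the shared layer grows to dim * n_blocks.
  parallel = torch.cat([b(net=inputs, mask_input=inputs) for b in blocks], dim=1)
  assert parallel.shape == (4, dim * n_blocks)

  # Stacked mode: each block refines the previous block's output, always
  # masked by the raw inputs; the shared layer keeps the original width.
  net = inputs
  for b in blocks:
    net = b(net=net, mask_input=inputs)
  assert net.shape == (4, dim)

The width difference is why the parallel branch feeds `all_mask_outputs` (the concatenation) into the MLP head, while the stacked branch feeds the single running `net`.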