def __init__()

in projects/home/recap/model/mask_net.py [0:0]


  def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
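    """Builds the MaskNet layers from `mask_net_config`; `in_features` is the width of the input feature vector."""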
    super().__init__()
    self.mask_net_config = mask_net_config
    mask_blocks = []

    if mask_net_config.use_parallel:
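      # Parallel variant: every MaskBlock reads the original in_features-wide input,
      # and the total output width tracked here is the sum of the per-block output_size values.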
      total_output_mask_blocks = 0
      for mask_block_config in mask_net_config.mask_blocks:
        mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
        total_output_mask_blocks += mask_block_config.output_size
      self._mask_blocks = torch.nn.ModuleList(mask_blocks)
    else:
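      # Serial (stacked) variant: each MaskBlock consumes the previous block's output
      # while the mask input stays at the original in_features width, so the final
      # output width is the last block's output_size.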
      input_size = in_features
      for mask_block_config in mask_net_config.mask_blocks:
        mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
        input_size = mask_block_config.output_size

      self._mask_blocks = torch.nn.ModuleList(mask_blocks)
      total_output_mask_blocks = mask_net_config.mask_blocks[-1].output_size

    if mask_net_config.mlp:
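      # Optional MLP head on top of the mask blocks; its last layer size becomes the
      # module's out_features.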
      self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
      self.out_features = mask_net_config.mlp.layer_sizes[-1]
    else:
      self.out_features = total_output_mask_blocks
    self.shared_size = total_output_mask_blocks
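
For illustration, here is a minimal, self-contained sketch of the output-width bookkeeping this constructor performs. The stub classes and the out_features_of helper below are hypothetical stand-ins (not the repo's config.MaskNetConfig, MaskBlock, or mlp.Mlp); only the field names use_parallel, mask_blocks, output_size, and the MLP layer sizes mirror the code above, and the concrete sizes are made up.

  from dataclasses import dataclass
  from typing import List, Optional


  @dataclass
  class MaskBlockConfigStub:
    output_size: int


  @dataclass
  class MaskNetConfigStub:
    mask_blocks: List[MaskBlockConfigStub]
    use_parallel: bool
    mlp_layer_sizes: Optional[List[int]] = None  # stands in for mask_net_config.mlp.layer_sizes


  def out_features_of(cfg: MaskNetConfigStub) -> int:
    if cfg.use_parallel:
      # Parallel: per-block widths add up (every block reads the original input).
      total = sum(block.output_size for block in cfg.mask_blocks)
    else:
      # Serial: blocks are chained, so only the last block's width remains.
      total = cfg.mask_blocks[-1].output_size
    # An optional MLP head replaces that width with its last layer size.
    return cfg.mlp_layer_sizes[-1] if cfg.mlp_layer_sizes else total


  blocks = [MaskBlockConfigStub(output_size=64), MaskBlockConfigStub(output_size=64)]
  print(out_features_of(MaskNetConfigStub(blocks, use_parallel=True)))                             # 128
  print(out_features_of(MaskNetConfigStub(blocks, use_parallel=False)))                            # 64
  print(out_features_of(MaskNetConfigStub(blocks, use_parallel=True, mlp_layer_sizes=[256, 32])))  # 32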