def __init__()

in projects/home/recap/model/mlp.py


  def __init__(self, in_features: int, mlp_config: MlpConfig):
    super().__init__()
    self._mlp_config = mlp_config
    input_size = in_features
    layer_sizes = mlp_config.layer_sizes
    modules = []
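    # Hidden blocks: Linear -> optional BatchNorm1d -> ReLU -> optional Dropout.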
    for layer_size in layer_sizes[:-1]:
      modules.append(torch.nn.Linear(input_size, layer_size, bias=True))

      if mlp_config.batch_norm:
        modules.append(
          torch.nn.BatchNorm1d(
            layer_size,
            affine=mlp_config.batch_norm.affine,
            momentum=mlp_config.batch_norm.momentum,
          )
        )

      modules.append(torch.nn.ReLU())

      if mlp_config.dropout:
        modules.append(torch.nn.Dropout(mlp_config.dropout.rate))

      input_size = layer_size
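    # Final projection to the last configured layer size, with an optional ReLU.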
    modules.append(torch.nn.Linear(input_size, layer_sizes[-1], bias=True))
    if mlp_config.final_layer_activation:
      modules.append(torch.nn.ReLU())
    self.layers = torch.nn.ModuleList(modules)
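    # _init_weights is defined elsewhere in mlp.py; it initializes the parameters of each submodule.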
    self.layers.apply(_init_weights)
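
A minimal usage sketch, assuming the enclosing class is named Mlp and is importable from the module above. The dataclasses below are stand-ins carrying only the config fields this constructor reads (layer_sizes, batch_norm.affine, batch_norm.momentum, dropout.rate, final_layer_activation); the project's real MlpConfig is defined elsewhere and may differ.

from dataclasses import dataclass
from typing import List, Optional


# Stand-in config objects; names and defaults are illustrative assumptions.
@dataclass
class _BatchNormConfig:
  affine: bool = True
  momentum: float = 0.1


@dataclass
class _DropoutConfig:
  rate: float = 0.1


@dataclass
class _MlpConfig:
  layer_sizes: List[int]
  batch_norm: Optional[_BatchNormConfig] = None
  dropout: Optional[_DropoutConfig] = None
  final_layer_activation: bool = False


config = _MlpConfig(
  layer_sizes=[512, 256, 1],
  batch_norm=_BatchNormConfig(affine=True, momentum=0.1),
  dropout=_DropoutConfig(rate=0.1),
  final_layer_activation=False,
)
# Mlp stands for the enclosing class in mlp.py (name assumed). The constructor
# only reads attributes off mlp_config, so a duck-typed stand-in config works here.
mlp = Mlp(in_features=1024, mlp_config=config)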