def _get_reg_loss()

in sourcecode/scoring/matrix_factorization/matrix_factorization.py [0:0]


  def _get_reg_loss(self):
    """Compute the regularization penalty over the model's parameters.

    Six weighted penalty terms are accumulated, in this order, into a 0-d
    float32 tensor created on ``self.mf_model.device``:

      1. squared user factors, weighted by ``self._userFactorLambda``;
      2. squared user intercepts, weighted by ``self._userInterceptLambda``;
      3. squared note factors, weighted by ``self._noteFactorLambda``;
      4. squared note intercepts, weighted by ``self._noteInterceptLambda``;
      5. ``|note_factors * note_intercepts|``, weighted by ``self._diamondLambda``;
      6. squared global intercept, weighted by ``self._globalInterceptLambda``.

    Terms 1-2 are mean-reduced when ``self._ratingPerUserLossRatio`` is None;
    otherwise they are sum-reduced and divided by
    ``len(self.trainModelData.rating_labels) / self._ratingPerUserLossRatio``.
    Terms 3-5 follow the same rule using ``self._ratingPerNoteLossRatio``.
    Term 6 is always mean-reduced. ``self.trainModelData`` is only read when
    the corresponding ratio is not None.

    Returns:
      The accumulated penalty as a 0-d tensor.
    """
    model = self.mf_model

    def simulatedCount(ratio):
      # A ratio of None selects plain mean reduction, so no divisor is needed.
      if ratio is None:
        return None
      return len(self.trainModelData.rating_labels) / ratio

    def weightedPenalty(weight, values, divisor):
      # Without a divisor, mean-reduce; otherwise sum-reduce, scale, then divide.
      if divisor is None:
        return weight * values.mean()
      return weight * values.sum() / divisor

    raterDivisor = simulatedCount(self._ratingPerUserLossRatio)
    noteDivisor = simulatedCount(self._ratingPerNoteLossRatio)

    noteFactors = model.note_factors.weight
    noteIntercepts = model.note_intercepts.weight

    # Accumulate in place so the running total keeps the accumulator's float32 dtype.
    total = torch.zeros((), dtype=torch.float32, device=model.device)
    total += weightedPenalty(self._userFactorLambda, model.user_factors.weight**2, raterDivisor)
    total += weightedPenalty(
      self._userInterceptLambda, model.user_intercepts.weight**2, raterDivisor
    )
    total += weightedPenalty(self._noteFactorLambda, noteFactors**2, noteDivisor)
    total += weightedPenalty(self._noteInterceptLambda, noteIntercepts**2, noteDivisor)
    total += weightedPenalty(
      self._diamondLambda, (noteFactors * noteIntercepts).abs(), noteDivisor
    )
    total += self._globalInterceptLambda * (model.global_intercept**2).mean()

    return total