sourcecode/scoring/matrix_factorization/model.py (52 lines of code) (raw):

from dataclasses import dataclass
import logging
from typing import Optional

import torch

logger = logging.getLogger("birdwatch.model")
logger.setLevel(logging.INFO)


@dataclass
class ModelData:
    """Bundle of tensors passed to ``BiasedMatrixFactorization.forward``.

    ``forward`` reads ``user_indexes`` and ``note_indexes`` only;
    ``rating_labels`` is carried along but never read in this module.
    """

    # Not read in this module; presumably the training targets -- confirm with caller.
    rating_labels: Optional[torch.FloatTensor]
    # Lookup indexes into the user embeddings (factors and intercepts).
    user_indexes: Optional[torch.IntTensor]
    # Lookup indexes into the note embeddings (factors and intercepts).
    note_indexes: Optional[torch.IntTensor]


class BiasedMatrixFactorization(torch.nn.Module):
    """Matrix factorization algorithm class.

    The prediction for a (user, note) pair is::

        user_intercept + note_intercept + <user_factors, note_factors> [+ global_intercept]
    """

    def __init__(
        self,
        n_users: int,
        n_notes: int,
        n_factors: int = 1,
        use_global_intercept: bool = True,
        log: bool = True,
    ) -> None:
        """Initialize matrix factorization model using xavier_uniform for factors
        and zeros for intercepts.

        Args:
            n_users (int): number of raters
            n_notes (int): number of notes
            n_factors (int, optional): number of dimensions. Defaults to 1.
                Only 1 is supported.
            use_global_intercept (bool, optional): whether ``forward`` adds the
                global intercept to every prediction. Defaults to True.
            log (bool, optional): whether freezing a parameter emits an INFO
                log record. Defaults to True.
        """
        super().__init__()

        self._log = log

        # NOTE: the creation order of the embeddings is kept as-is on purpose:
        # each constructor draws from the global RNG, so reordering would change
        # the initial weights obtained for a given seed.
        self.user_factors = torch.nn.Embedding(n_users, n_factors, sparse=False, dtype=torch.float32)
        self.note_factors = torch.nn.Embedding(n_notes, n_factors, sparse=False, dtype=torch.float32)

        self.user_intercepts = torch.nn.Embedding(n_users, 1, sparse=False, dtype=torch.float32)
        self.note_intercepts = torch.nn.Embedding(n_notes, 1, sparse=False, dtype=torch.float32)

        self.use_global_intercept = use_global_intercept
        # Always registered as a parameter, even when it is not used in forward().
        self.global_intercept = torch.nn.parameter.Parameter(torch.zeros(1, 1, dtype=torch.float32))

        torch.nn.init.xavier_uniform_(self.user_factors.weight)
        torch.nn.init.xavier_uniform_(self.note_factors.weight)
        self.user_intercepts.weight.data.fill_(0.0)
        self.note_intercepts.weight.data.fill_(0.0)

        # Only records the preferred device; nothing is moved to it here.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def forward(self, data: ModelData) -> torch.Tensor:
        """Forward pass: get predicted rating for user of note.

        Args:
            data: ModelData object with the following attributes:
                user_indexes: user id, shape (batch_size,)
                note_indexes: note id, shape (batch_size,)

        Returns:
            torch.Tensor: predicted rating, shape (batch_size,). The batch
            dimension is kept even when batch_size is 1.
        """
        pred = self.user_intercepts(data.user_indexes) + self.note_intercepts(data.note_indexes)
        pred += (self.user_factors(data.user_indexes) * self.note_factors(data.note_indexes)).sum(
            1, keepdim=True
        )
        if self.use_global_intercept:
            pred += self.global_intercept

        # Drop only the trailing size-1 dimension. A bare squeeze() would also
        # collapse the batch dimension for a single-element batch and return a
        # 0-dim tensor instead of the documented (batch_size,) shape.
        return pred.squeeze(-1)

    def freeze_rater_and_global_parameters(self):
        """Freeze rater and global parameters."""
        self._freeze_parameters({"user", "global"})

    def freeze_factors(self):
        """Freeze factors."""
        self._freeze_parameters({"factor"})

    def _freeze_parameters(self, words_to_freeze: set):
        """Disable gradients for every parameter whose name contains a given word.

        Args:
            words_to_freeze (set): substrings matched against the names yielded
                by ``named_parameters()``; a parameter is frozen when at least
                one of them occurs in its name.
        """
        for name, param in self.named_parameters():
            # Freeze (and log) each parameter once, however many words match.
            if any(word in name for word in words_to_freeze):
                if self._log:
                    logger.info("Freezing parameter: %s", name)
                param.requires_grad_(False)