def __init__()

in projects/twhin/models/models.py [0:0]


  def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
    """Set up the embedding tables and per-relation parameters for TwHIN.

    Args:
      model_config: Read for ``embeddings`` (passed whole to
        ``LargeEmbeddings``; its ``tables`` supply names, count and width)
        and ``relations`` (its length sizes ``all_trans_embs``).
      data_config: Read for ``per_replica_batch_size``,
        ``in_batch_negatives`` and ``global_negatives``, which are stored
        unchanged on the instance.

    Raises:
      IndexError: if ``model_config.embeddings.tables`` is empty, from the
        first-table lookup (unless ``LargeEmbeddings`` rejects the config
        first).
    """
    super().__init__()

    # Plain settings copied verbatim from the data config.
    self.batch_size = data_config.per_replica_batch_size
    self.in_batch_negatives = data_config.in_batch_negatives
    self.global_negatives = data_config.global_negatives

    tables = model_config.embeddings.tables
    self.table_names = [tbl.name for tbl in tables]
    self.num_tables = len(tables)
    self.num_relations = len(model_config.relations)

    # Kept ahead of the tables[0] lookup and the Parameter below: assignment
    # order fixes registration order on the module, and any initialization
    # LargeEmbeddings performs must draw from the RNG before uniform_ does.
    self.large_embeddings = LargeEmbeddings(model_config.embeddings)

    # Width is taken from the first table only.
    # NOTE(review): presumes every table shares one embedding_dim; confirm.
    self.embedding_dim = tables[0].embedding_dim

    # One learnable row of width embedding_dim per relation. uniform_ fills
    # it in place with its defaults (a=0.0, b=1.0). The original comment
    # called this a per-relation bias; the name suggests a translation
    # vector, so confirm against forward().
    relation_init = torch.empty(self.num_relations, self.embedding_dim)
    torch.nn.init.uniform_(relation_init)
    self.all_trans_embs = torch.nn.Parameter(relation_init)