def apply_optimizers()

in projects/twhin/models/models.py [0:0]


def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig) -> TwhinModel:
  """Attach each embedding table's configured optimizer to that table's parameters.

  For every table in ``model_config.embeddings.tables``, resolves the optimizer
  class and its keyword arguments from ``table.optimizer`` and passes them, along
  with the parameters of ``model.large_embeddings.ebc`` that belong to that table,
  to ``apply_optimizer_in_backward``.

  Args:
    model: Model whose ``large_embeddings.ebc`` parameters receive optimizers.
    model_config: Config listing the embedding tables and each table's optimizer.

  Returns:
    The same ``model`` object that was passed in (no copy is made).
  """
  # Collect the EBC parameters once instead of re-walking the module per table.
  named_params = list(model.large_embeddings.ebc.named_parameters())

  for table in model_config.embeddings.tables:
    optimizer_class = get_optimizer_class(table.optimizer)
    # NOTE(review): `.dict()` suggests a pydantic config object -- confirm.
    optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()

    # Match the whole path component, not a bare string prefix. Without the
    # trailing "." a table named "user" would also capture the parameters of a
    # table named "user_features", registering a second optimizer on them.
    # Assumes EBC parameter names look like "embedding_bags.<table>.<param>"
    # (torchrec EmbeddingBagCollection layout) -- TODO confirm.
    table_prefix = f"embedding_bags.{table.name}."
    params = [param for name, param in named_params if name.startswith(table_prefix)]

    apply_optimizer_in_backward(
      optimizer_class=optimizer_class,
      params=params,
      optimizer_kwargs=optimizer_kwargs,
    )

  return model