def build_optimizer()

in projects/twhin/optimizer.py [0:0]


def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
  """Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.

  The model's `fused_optimizer` is combined with a keyed optimizer built from
  `config.translation_optimizer`. That keyed optimizer covers the parameters returned by
  `in_backward_optimizer_filter(model.named_parameters())`. Presumably these are the parameters
  not already handled by an in-backward (fused) optimizer, i.e. the per-relation translations.
  TODO: confirm against torchrec's `in_backward_optimizer_filter`.

  Args:
    model: TwhinModel to build the optimizer for. Its `fused_optimizer` and
      `named_parameters()` are used.
    config: TwhinModelConfig for the model. `translation_optimizer` configures the
      translation optimizer. `embeddings.tables` (each with `name` and `optimizer`) is
      used to assemble the per-key learning-rate dict.

  Returns:
    A tuple `(optimizer, scheduler)`:
      optimizer: `keyed.CombinedOptimizer` whose sub-optimizers are keyed by
        `FUSED_OPT_KEY` (the model's fused optimizer) and `TRANSLATION_OPT_KEY`
        (the translation optimizer).
      scheduler: currently always `None`, because LR scheduling is disabled.
  """
  # Bind the configured optimizer class to its algorithm kwargs. The resulting factory is
  # handed to KeyedOptimizerWrapper (presumably invoked there with the parameters).
  translation_optimizer_fn = functools.partial(
    get_optimizer_class(config.translation_optimizer),
    **get_optimizer_algorithm_config(config.translation_optimizer).dict(),
  )

  translation_optimizer = keyed.KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(model.named_parameters())),
    optim_factory=translation_optimizer_fn,
  )

  # Per-key learning rates: one entry per embedding table, plus one for the translations.
  # NOTE: this dict is only logged for now. The scheduler that would consume it is disabled below.
  lr_dict = {
    table.name: _lr_from_config(table.optimizer) for table in config.embeddings.tables
  }
  lr_dict[TRANSLATION_OPT_KEY] = _lr_from_config(config.translation_optimizer)

  logging.info(f"***** LR dict: {lr_dict} *****")

  logging.info(
    f"***** Combining fused optimizer {model.fused_optimizer} with operator optimizer: {translation_optimizer} *****"
  )
  optimizer = keyed.CombinedOptimizer(
    [
      (FUSED_OPT_KEY, model.fused_optimizer),
      (TRANSLATION_OPT_KEY, translation_optimizer),
    ]
  )

  # TODO: LR scheduling is disabled. Re-enabling it would construct
  # `LRShim(optimizer, lr_dict)` here; until then callers always receive `None`.
  scheduler = None

  logging.info(f"***** Combined optimizer after init: {optimizer} *****")

  return optimizer, scheduler