def run()

in projects/home/recap/main.py


def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
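  # Print a separator banner so the start of the run is easy to spot in the logs.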
  print("#" * 100)

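  # Load the RecapConfig from the YAML file pointed to by --config_path and log it.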
  config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
  logging.info("Config: %s", config.pretty_print())

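  # Select the torch device to run on (e.g. GPU when available).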
  device = setup_and_get_device()

  # Always enable TensorFloat-32 (TF32) on supported devices.
  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True

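  # Build a multi-task BCE-with-logits loss over the configured tasks, using each
  # task's pos_weight to reweight positive examples.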
  loss_fn = losses.build_multi_task_loss(
    loss_type=LossType.BCE_WITH_LOGITS,
    tasks=list(config.model.tasks.keys()),
    pos_weights=[task.pos_weight for task in config.model.tasks.values()],
  )

  # Since the prod model doesn't use large embeddings, for now we won't support them.
  assert config.model.large_embeddings is None

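  # Construct the training dataset; batches can optionally be read from a dataset
  # service (data_service_dispatcher), and the dataset repeats indefinitely.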
  train_dataset = ds.RecapDataset(
    data_config=config.train_data,
    dataset_service=data_service_dispatcher,
    mode=recap_config_mod.JobMode.TRAIN,
    compression=config.train_data.dataset_service_compression,
    vocab_mapper=None,
    repeat=True,
  )

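  # Wrap the dataset in a dataloader and iterate over it for training batches.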
  train_iterator = iter(train_dataset.to_dataloader())

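  # The element spec describes the structure of each batch yielded by the dataset.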
  torch_element_spec = train_dataset.torch_element_spec

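  # Build the multi-task ranking model from the batch spec, the config, and the loss.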
  model = model_mod.create_ranking_model(
    data_spec=torch_element_spec[0],
    config=config,
    loss_fn=loss_fn,
    device=device,
  )

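  # Build the optimizer and learning-rate scheduler from the optimizer config.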
  optimizer, scheduler = optimizer_mod.build_optimizer(model, config.optimizer, None)

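  # Shard/wrap the model for distributed training if a distributed setup is detected.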
  model = maybe_shard_model(model, device)

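  # Print a timestamp marking the start of training.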
  datetime_str = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
  print(datetime_str)

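  # Optionally swap in the slow debug training loop; otherwise use the standard
  # custom training loop (ctl).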
  if FLAGS.debug_loop:
    logging.warning("Running debug mode, slow!")
    train_mod = debug_training_loop
  else:
    train_mod = ctl

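  # Run the training loop: train for num_train_steps, logging and checkpointing at
  # the configured intervals, with AMP disabled.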
  train_mod.train(
    model=model,
    optimizer=optimizer,
    device=device,
    save_dir=config.training.save_dir,
    logging_interval=config.training.train_log_every_n,
    train_steps=config.training.num_train_steps,
    checkpoint_frequency=config.training.checkpoint_every_n,
    dataset=train_iterator,
    worker_batch_size=config.train_data.global_batch_size,
    enable_amp=False,
    initial_checkpoint_dir=config.training.initial_checkpoint_dir,
    gradient_accumulation=config.training.gradient_accumulation,
    scheduler=scheduler,
  )
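
For reference, here is a minimal sketch of how an entry point like this is typically wired up with absl-py. The flag definitions and the __main__ guard below are illustrative assumptions (main.py defines its own FLAGS), not copied from the repo:

from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string("config_path", None, "Path to the YAML RecapConfig.")  # assumed flag definition
flags.DEFINE_bool("debug_loop", False, "Use the slow debug training loop.")  # assumed flag definition

if __name__ == "__main__":
  # app.run parses the command line, populates FLAGS, and then calls run(argv).
  app.run(run)

Because run ignores its argv argument, all configuration flows through the flags and the YAML config.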