def train()

in core/custom_training_loop.py [0:0]


def train(
  model: torch.nn.Module,
  optimizer: torch.optim.Optimizer,
  device: str,
  save_dir: str,
  logging_interval: int,
  train_steps: int,
  checkpoint_frequency: int,
  dataset: Iterable,
  worker_batch_size: int,
  num_workers: Optional[int] = 0,
  enable_amp: bool = False,
  initial_checkpoint_dir: Optional[str] = None,
  gradient_accumulation: Optional[int] = None,
  logger_initializer: Optional[Callable] = None,
  scheduler: _LRScheduler = None,
  metrics: Optional[tm.MetricCollection] = None,
  parameters_to_log: Optional[Dict[str, Callable]] = None,
  tables_to_log: Optional[List[str]] = None,