in core/custom_training_loop.py [0:0]
def train(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: str,
save_dir: str,
logging_interval: int,
train_steps: int,
checkpoint_frequency: int,
dataset: Iterable,
worker_batch_size: int,
num_workers: Optional[int] = 0,
enable_amp: bool = False,
initial_checkpoint_dir: Optional[str] = None,
gradient_accumulation: Optional[int] = None,
logger_initializer: Optional[Callable] = None,
scheduler: _LRScheduler = None,
metrics: Optional[tm.MetricCollection] = None,
parameters_to_log: Optional[Dict[str, Callable]] = None,
tables_to_log: Optional[List[str]] = None,