def checkpoints_iterator()

in common/checkpointing/snapshot.py [0:0]


def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: Optional[int] = 1800):
  """Simplified equivalent of tf.train.checkpoints_iterator.

  Repeatedly asks `get_checkpoint(save_dir, missing_ok=True)` for the latest
  checkpoint and yields each path that differs from the one yielded before it.
  Iteration ends when no new checkpoint shows up before the timeout expires.

  Args:
    save_dir: directory whose latest checkpoint is polled via `get_checkpoint`.
    seconds_to_sleep: time between polling calls.
    timeout: how long to wait for each new checkpoint, in seconds. The
      deadline restarts after every yielded checkpoint. `None` waits
      indefinitely.

  Yields:
    Each newly found checkpoint path, as returned by `get_checkpoint`.
  """

  def _poll(last_checkpoint: Optional[str] = None) -> Optional[str]:
    """Block until a checkpoint other than `last_checkpoint` appears; None on timeout."""
    # Monotonic clock: wall-clock jumps (NTP, manual changes) must not shorten
    # or extend the wait. `None` timeout means there is no deadline at all.
    stop_time = None if timeout is None else time.monotonic() + timeout
    while True:
      _checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
      # A falsy result means nothing is available yet; an unchanged path means
      # no new checkpoint has been written since the last one we yielded.
      if _checkpoint_path and _checkpoint_path != last_checkpoint:
        logging.info(f"Found latest checkpoint {_checkpoint_path}.")
        return _checkpoint_path
      # Give up instead of sleeping past the deadline.
      if stop_time is not None and time.monotonic() + seconds_to_sleep > stop_time:
        logging.info(
          f"Timed out waiting for next available checkpoint from {save_dir} for {timeout}s."
        )
        return None
      logging.info(f"Waiting for next available checkpoint from {save_dir}.")
      time.sleep(seconds_to_sleep)

  checkpoint_path = None
  while True:
    new_checkpoint = _poll(checkpoint_path)
    if not new_checkpoint:
      # Timed out: end the iteration.
      return
    checkpoint_path = new_checkpoint
    yield checkpoint_path