def get_checkpoints()

in common/checkpointing/snapshot.py [0:0]


def get_checkpoints(save_dir: str) -> List[str]:
  """Gets all checkpoints that have been fully written.

  A checkpoint is an entry directly under ``save_dir`` that contains the
  torchsnapshot metadata file (``SNAPSHOT_METADATA_FNAME``). The presence of
  that file is treated as the "fully written" marker, so entries without it
  are skipped.

  Args:
    save_dir: Directory to scan; may live on any filesystem ``infer_fs``
      resolves, including GCS.

  Returns:
    Paths of the fully written checkpoints, sorted by the integer value of
    their final path component. GCS paths carry ``GCS_PREFIX``. Returns an
    empty list if ``save_dir`` does not exist.

  Raises:
    ValueError: If a fully written checkpoint's final path component is not
      an integer.
  """
  fs = infer_fs(save_dir)
  if not fs.exists(save_dir):
    return []

  # Re-add the GCS prefix to listed entries (presumably `ls` drops the
  # protocol on GCS) so the returned paths remain addressable.
  prefix = GCS_PREFIX if is_gcs_fs(fs) else ""
  # Strip any trailing separator so the metadata path below has no "//" and
  # the basename used for sorting is not empty.
  candidates = [f"{prefix}{elem.rstrip('/')}" for elem in fs.ls(save_dir, detail=False)]

  # Only take checkpoints that were fully written.
  metadata_fname = torchsnapshot.snapshot.SNAPSHOT_METADATA_FNAME
  checkpoints = [path for path in candidates if fs.exists(f"{path}/{metadata_fname}")]

  # Directory names are step numbers; sort numerically, not lexicographically.
  return sorted(checkpoints, key=lambda path: int(os.path.basename(path)))