def _create_base_tf_dataset()

in projects/home/recap/data/dataset.py


  def _create_base_tf_dataset(self, batch_size: int):
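    """Build the base tf.data pipeline: resolve the input files, shuffle and
    interleave TFRecord shards, then shuffle, batch, and parse the examples."""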
    if self._data_config.inputs:
      glob = self._data_config.inputs
      filenames = sorted(tf.io.gfile.glob(glob))
    elif self._data_config.explicit_datetime_inputs:
      num_missing_hours_tol = self._data_config.explicit_datetime_inputs.num_missing_hours_tol
      filenames, num_hours_missing = get_explicit_datetime_inputs_files(
        self._data_config.explicit_datetime_inputs,
        increment="hourly",
      )
      if num_hours_missing > num_missing_hours_tol:
        raise ValueError(
          f"We are missing {num_hours_missing} hours of data"
          f"more than tolerance {num_missing_hours_tol}."
        )
    elif self._data_config.explicit_date_inputs:
      num_missing_days_tol = self._data_config.explicit_date_inputs.num_missing_days_tol
      filenames, num_days_missing = get_explicit_datetime_inputs_files(
        self._data_config.explicit_date_inputs,
        increment="daily",
      )
      if num_days_missing > num_missing_days_tol:
        raise ValueError(
          f"We are missing {num_days_missing} days of data"
          f"more than tolerance {num_missing_days_tol}."
        )
    else:
      raise ValueError(
        "Must specifiy either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
      )

    num_files = len(filenames)
    logging.info(f"Found {num_files} data files")
    if num_files < 1:
      raise ValueError("No data files found")

    if self._data_config.num_files_to_keep is not None:
      filenames = filenames[: self._data_config.num_files_to_keep]
      logging.info(f"Retaining only {len(filenames)} files.")

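    # Shuffle at the file level so shards are visited in a different order on each pass.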
    filenames_ds = (
      tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
      # Because of drop_remainder, if our dataset does not fill
      # up a batch, it will emit nothing without this repeat.
      .repeat(-1)
    )

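    # Optionally batch filenames so a single interleave element reads several files at once
    # (TFRecordDataset accepts a tensor of filenames).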
    if self._data_config.file_batch_size:
      filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)

    def per_shard_dataset(filename):
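      # Each shard is a GZIP-compressed TFRecord file; prefetch a few records per shard
      # so reading overlaps with downstream work.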
      ds = tf.data.TFRecordDataset([filename], compression_type="GZIP")
      return ds.prefetch(4)

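    # Read several shards concurrently; deterministic=False lets whichever shard is ready
    # first be consumed, trading ordering guarantees for throughput.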
    ds = filenames_ds.interleave(
      per_shard_dataset,
      block_length=4,
      deterministic=False,
      num_parallel_calls=self._data_config.interleave_num_parallel_calls
      or tf.data.experimental.AUTOTUNE,
    )

    # Combine functions into one map call to reduce overhead.
    map_fn = functools.partial(
      _chain,
      f1=self._parse_fn,
      f2=self._output_map_fn,
    )

    # Shuffle -> Batch -> Parse is the intended ordering:
    # shuffling must happen before batching, otherwise shuffling only reorders whole batches
    #     and examples within a batch stay in file order;
    # batching happens before parsing because tf.Example parsing is vectorized
    #     and runs much faster on batches of serialized data.
    ds = (
      # DANGER DANGER: there is a default shuffle size here.
      ds.shuffle(self._data_config.examples_shuffle_buffer_size)
      .batch(batch_size=batch_size, drop_remainder=True)
      .map(
        map_fn,
        num_parallel_calls=self._data_config.map_num_parallel_calls
        or tf.data.experimental.AUTOTUNE,
      )
    )

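    # Cache parsed batches in memory so subsequent epochs skip file reads and parsing.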
    if self._data_config.cache:
      ds = ds.cache()

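    # Skip elements that raise errors (e.g. corrupt records) instead of failing the pipeline.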
    if self._data_config.ignore_data_errors:
      ds = ds.apply(tf.data.experimental.ignore_errors())

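    # Allow non-deterministic element ordering throughout the pipeline for extra throughput.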
    options = tf.data.Options()
    options.experimental_deterministic = False
    ds = ds.with_options(options)

    return ds
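
The map_fn above is built with functools.partial(_chain, f1=self._parse_fn, f2=self._output_map_fn). The _chain helper is not shown in this excerpt; a minimal sketch, assuming it simply composes its two arguments so that parsing and output mapping run in a single Dataset.map stage, would be:

def _chain(x, f1, f2):
  # Apply f1 first (e.g. the batched tf.Example parse function), then f2
  # (the output mapping), so only one map call is scheduled per batch.
  return f2(f1(x))

Under that assumption, each batch of serialized records is parsed and then reshaped into the model's expected output structure in one pass, avoiding the overhead of two separate map stages.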