def _create_tf_dataset()

in projects/home/recap/data/dataset.py


  def _create_tf_dataset(self):
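    """Build the training `tf.data.Dataset`.

    Batches at per-replica size, optionally repeats, optionally offloads the
    pipeline to a tf.data service, applies vocab mapping on the training node,
    and prefetches. Must not be called more than once.
    """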
    if hasattr(self, "_tf_dataset"):
      raise ValueError("Do not call `_create_tf_dataset` more than once.")

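    # Each replica consumes an equal slice of the (multiplier-scaled) global batch.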
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    per_replica_bsz = (
      self._batch_size_multiplier * self._data_config.global_batch_size // world_size
    )

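    # Build the source dataset at the per-replica batch size.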
    dataset: tf.data.Dataset = self._create_base_tf_dataset(
      batch_size=per_replica_bsz,
    )

    if self._repeat:
      logging.info("Repeating dataset")
      dataset = dataset.repeat()

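    # With a tf.data service configured, run the input pipeline remotely:
    # register the dataset with the service, then consume it by id locally.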
    if self.dataset_service:
      if self._num_concurrent_iterators > 1:
        if not self.machines_config:
          raise ValueError(
            "Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
          )
        dataset = dataset_lib.with_auto_tune_budget(
          dataset,
          machine_config=self.machines_config.chief,
          num_concurrent_iterators=self._num_concurrent_iterators,
          on_chief=False,
        )

      self.dataset_id, self.job_name = register_dataset(
        dataset=dataset, dataset_service=self.dataset_service, compression=self.compression
      )
      dataset = distribute_from_dataset_id(
        dataset_id=self.dataset_id,  # type: ignore[arg-type]
        job_name=self.job_name,
        dataset_service=self.dataset_service,
        compression=self.compression,
      )

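    # Without a dataset service, the auto-tune budget is applied on the chief itself.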
    elif self._num_concurrent_iterators > 1:
      if not self.machines_config:
        raise ValueError(
          "Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
        )
      dataset = dataset_lib.with_auto_tune_budget(
        dataset,
        machine_config=self.machines_config.chief,
        num_concurrent_iterators=self._num_concurrent_iterators,
        on_chief=True,
      )

    # Vocabulary mapping happens on the training node, not in the dataset
    # service (DDS), because of the vocab mapper's size.
    if self._vocab_mapper:
      dataset = dataset.map(self._vocab_mapper)

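    # Keep up to 2 * world_size batches in flight to overlap input with training.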
    return dataset.prefetch(world_size * 2)
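
The `register_dataset` / `distribute_from_dataset_id` helpers used above wrap
TensorFlow's tf.data service. As a minimal, illustrative sketch of the same
register-then-consume pattern against the raw `tf.data.experimental.service`
API (the helpers' real internals are not shown here, and `service_address`,
`job`, and `distribute_via_service` are hypothetical names):

  import tensorflow as tf

  def distribute_via_service(dataset: tf.data.Dataset, service_address: str, job: str):
    # Register the dataset graph with the dispatcher; service workers run it.
    dataset_id = tf.data.experimental.service.register_dataset(
      service=service_address, dataset=dataset, compression="AUTO"
    )
    # Consume elements by id. Sharing `job_name` lets multiple consumers split
    # a single stream instead of each receiving a full copy of the data.
    return tf.data.experimental.service.from_dataset_id(
      processing_mode="parallel_epochs",
      service=service_address,
      dataset_id=dataset_id,
      element_spec=dataset.element_spec,
      job_name=job,
    )

As in the method above, the producer side registers once and the consumer (the
trainer) reads by id, so heavy input processing can run off the training host.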