in projects/home/recap/data/dataset.py
def _create_tf_dataset(self):
  if hasattr(self, "_tf_dataset"):
    raise ValueError("Do not call `_create_tf_dataset` more than once.")
  world_size = dist.get_world_size() if dist.is_initialized() else 1
  per_replica_bsz = (
    self._batch_size_multiplier * self._data_config.global_batch_size // world_size
  )
  dataset: tf.data.Dataset = self._create_base_tf_dataset(
    batch_size=per_replica_bsz,
  )
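  # e.g. with batch_size_multiplier=1, global_batch_size=4096 and world_size=8
  # (illustrative values), each replica's pipeline is batched at 1 * 4096 // 8 = 512.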

  if self._repeat:
    logging.info("Repeating dataset")
    dataset = dataset.repeat()
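  # Note: repeat() is applied before the optional tf.data service hand-off below,
  # so remote workers (when enabled) serve an endless stream rather than one epoch.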

  if self.dataset_service:
    if self._num_concurrent_iterators > 1:
      if not self.machines_config:
        raise ValueError(
          "Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
        )
      dataset = dataset_lib.with_auto_tune_budget(
        dataset,
        machine_config=self.machines_config.chief,
        num_concurrent_iterators=self._num_concurrent_iterators,
        on_chief=False,
      )
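
    # Register the pipeline with the tf.data service so remote workers produce
    # the batches, then consume them on the trainer by dataset_id / job_name.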
    self.dataset_id, self.job_name = register_dataset(
      dataset=dataset, dataset_service=self.dataset_service, compression=self.compression
    )
    dataset = distribute_from_dataset_id(
      dataset_id=self.dataset_id,  # type: ignore[arg-type]
      job_name=self.job_name,
      dataset_service=self.dataset_service,
      compression=self.compression,
    )
  elif self._num_concurrent_iterators > 1:
    if not self.machines_config:
      raise ValueError(
        "Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
      )
    dataset = dataset_lib.with_auto_tune_budget(
      dataset,
      machine_config=self.machines_config.chief,
      num_concurrent_iterators=self._num_concurrent_iterators,
      on_chief=True,
    )
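
  # In both branches the autotune RAM budget is capped when several iterators
  # share a host; on_chief is presumably False above because those iterators run
  # on tf.data service workers rather than on the chief trainer node.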

  # Vocabulary mapping happens on the training node, not in dds because of size.
  if self._vocab_mapper:
    dataset = dataset.map(self._vocab_mapper)

  return dataset.prefetch(world_size * 2)
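
For context, below is a minimal sketch of the tf.data service hand-off that `register_dataset` and `distribute_from_dataset_id` appear to wrap, written directly against TensorFlow's public `tf.data.experimental.service` API. The function name, dispatcher address, processing mode, and job name are illustrative assumptions, not values taken from this project.

import tensorflow as tf

def distribute_via_tf_data_service_sketch(dataset: tf.data.Dataset, dispatcher_address: str):
  # Register the pipeline graph with the tf.data service dispatcher; remote
  # workers execute it and produce elements off the trainer host.
  dataset_id = tf.data.experimental.service.register_dataset(
    service=dispatcher_address,
    dataset=dataset,
    compression="AUTO",
  )
  # Consume the registered dataset on the trainer. A shared job_name lets
  # multiple consumers split one job instead of each reading the full stream.
  return tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service=dispatcher_address,
    dataset_id=dataset_id,
    element_spec=dataset.element_spec,
    job_name="recap_train",  # hypothetical job name
  )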