def _init_tensor_spec()

in projects/home/recap/data/dataset.py [0:0]


  def _init_tensor_spec(self):
    """Build a torch-shape view of the wrapped tf.data dataset's element spec.

    Sets ``self.torch_element_spec`` to a structure with the same nesting as
    ``self._tf_dataset.element_spec``. Each leaf spec is replaced by a
    ``torch.Size`` holding its dimensions. Dimensions that are not statically
    known (``None``) are written as ``-1``. A spec with no shape, or with a
    shape of unknown rank, maps to ``None``.
    """

    def _tensor_spec_to_torch_shape(spec):
      shape = spec.shape
      # An unknown-rank tf.TensorShape is a TensorShape(None) object rather
      # than None itself, and iterating it raises ValueError, so check the
      # rank explicitly in addition to the plain None case.
      if shape is None or shape.rank is None:
        return None
      # -1 stands in for dimensions that are only known at runtime.
      return torch.Size([-1 if dim is None else dim for dim in shape])

    self.torch_element_spec = tf.nest.map_structure(
      _tensor_spec_to_torch_shape, self._tf_dataset.element_spec
    )