def create_default_pa_to_batch()

in reader/utils.py [0:0]


def create_default_pa_to_batch(schema) -> DataclassBatch:
  """Build a converter from a pyarrow RecordBatch to a schema-derived batch.

  A batch class named "DefaultBatch" is created once from ``schema`` via
  ``DataclassBatch.from_schema``. The returned callable handles each column of
  an incoming ``pa.RecordBatch`` in three steps:
    1. Replace nulls with a per-type default: 0 for float64/int64, "" for
       string.
    2. Convert the column with ``pa_to_torch``.
    3. Pass all converted columns, keyed by column name, as keyword arguments
       to that batch class.

  Args:
    schema: Schema handed to ``DataclassBatch.from_schema``. Presumably it
      describes the same columns the record batches carry — verify at call
      sites.

  Returns:
    A callable taking a ``pa.RecordBatch`` and returning an instance of the
    generated batch class. NOTE(review): the return annotation says
    ``DataclassBatch``, but the factory itself returns this callable.

  The returned callable raises:
    TypeError: if a column's Arrow type has no imputation default. TypeError
      subclasses Exception, so existing ``except Exception`` handlers still
      catch it.
  """
  _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)

  # Built once per factory call rather than once per column per batch; the
  # scalars are immutable, so sharing them across calls is safe.
  imputation_values = {
    pa.float64(): pa.scalar(0, type=pa.float64()),
    pa.int64(): pa.scalar(0, type=pa.int64()),
    pa.string(): pa.scalar("", type=pa.string()),
  }

  def get_imputation_value(pa_type):
    """Return the fill scalar for ``pa_type``, or raise TypeError if unsupported."""
    try:
      return imputation_values[pa_type]
    except KeyError:
      raise TypeError(f"Imputation for type {pa_type} not supported.") from None

  def _impute(array: pa.Array) -> pa.Array:
    """Return a copy of ``array`` with nulls replaced by its type's default."""
    # The type is validated even when the column has no nulls, so unsupported
    # types are always reported.
    return array.fill_null(get_imputation_value(array.type))

  def _column_to_tensor(record_batch: pa.RecordBatch):
    """Impute and convert every column, then build the batch object."""
    tensors = {
      col_name: pa_to_torch(_impute(record_batch.column(col_name)))
      for col_name in record_batch.schema.names
    }
    return _CustomBatch(**tensors)

  return _column_to_tensor