in reader/utils.py [0:0]
# NOTE: the return annotation is a string forward reference; static type
# checkers need `Callable` (from `typing` or `collections.abc`) imported at
# module level to resolve it.
def create_default_pa_to_batch(schema) -> "Callable[[pa.RecordBatch], DataclassBatch]":
    """Build a converter from ``pa.RecordBatch`` to a ``DataclassBatch`` for ``schema``.

    A ``DataclassBatch`` subclass named ``"DefaultBatch"`` is generated once from
    ``schema`` via ``DataclassBatch.from_schema``. The returned callable does the
    following for each column of an incoming record batch:

    * replaces nulls with a type-specific default (zero for ``float64`` and
      ``int64``, the empty string for ``string``);
    * converts the column with ``pa_to_torch``;
    * passes all converted columns to the generated batch class as keyword
      arguments named after the columns.

    Args:
      schema: Schema handed to ``DataclassBatch.from_schema``. This is presumably
        the pyarrow schema of the record batches that will be converted; confirm
        against callers.

    Returns:
      A function ``(record_batch: pa.RecordBatch) -> DataclassBatch``.

    Raises (from the returned function, not from this factory):
      ValueError: if any column's type is not ``float64``, ``int64`` or
        ``string``. The type is checked for every column, even one that contains
        no nulls.
    """
    _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)

    # Fill value per supported arrow type. This is built once per factory call
    # instead of once per column of every converted batch.
    imputation_values = {
        pa.float64(): pa.scalar(0, type=pa.float64()),
        pa.int64(): pa.scalar(0, type=pa.int64()),
        pa.string(): pa.scalar("", type=pa.string()),
    }

    def get_imputation_value(pa_type):
        """Return the fill scalar for ``pa_type``; raise ``ValueError`` if unsupported."""
        try:
            return imputation_values[pa_type]
        except KeyError:
            # ValueError subclasses Exception, so existing `except Exception`
            # handlers keep working; `from None` hides the internal KeyError.
            raise ValueError(f"Imputation for type {pa_type} not supported.") from None

    def _impute(array: pa.Array) -> pa.Array:
        """Return ``array`` with nulls replaced by the fill value for its type."""
        return array.fill_null(get_imputation_value(array.type))

    def _column_to_tensor(record_batch: pa.RecordBatch):
        """Impute and convert every column of ``record_batch`` into a ``_CustomBatch``."""
        tensors = {
            col_name: pa_to_torch(_impute(record_batch.column(col_name)))
            for col_name in record_batch.schema.names
        }
        return _CustomBatch(**tensors)

    return _column_to_tensor