in sourcecode/scoring/pandas_utils.py [0:0]
def safe_concat(self) -> Callable:
    """Return a modified concat function that checks type stability.

    The returned callable forwards to ``self._origConcat`` and compares the
    dtypes of the inputs with the dtypes of the output.  Each comparison is
    reported through the checker obtained from ``self._get_check``; messages
    accumulated in the shared ``lines`` list (including whatever
    ``self._validate_dataframe`` / ``self._validate_series`` return) are passed
    to ``self._log_errors`` under the "CONCAT" tag before the result is returned.
    """

    def _safe_concat(*args, **kwargs):
        """Wrapper around pd.concat.

        Args:
          args: non-keyword arguments to pass through to concat.
          kwargs: keyword arguments to pass through to concat.

        Returns:
          Whatever the wrapped concat returns for the given arguments.

        Raises:
          AssertionError: if the objects to concatenate are not a list, are not
            all exactly pd.Series or all exactly pd.DataFrame, or if a
            column-wise Series concat yields an unexpected number of columns.
        """
        lines = []
        # NOTE: the checker is built before anything else touches kwargs, as in
        # the original ordering (it receives kwargs and may inspect them).
        check = self._get_check(lines, kwargs)

        # The objects may be supplied positionally or through the `objs` keyword.
        objs = args[0] if args else kwargs.get("objs")
        assert type(objs) is list, f"expected first argument to be a list: type={type(objs)}"
        if not objs:
            # Nothing to type-check.  Delegate so the wrapped concat deals with
            # (and rejects) the empty list itself rather than failing on objs[0].
            return self._origConcat(*args, **kwargs)

        # Validate that all objects being concatenated are either Series or DataFrames
        assert all(type(obj) is pd.Series for obj in objs) or all(
            type(obj) is pd.DataFrame for obj in objs
        ), f"Expected concat args to be either pd.Series or pd.DataFrame: {[type(obj) for obj in objs]}"

        if type(objs[0]) is pd.Series:
            if kwargs.get("axis") in (1, "columns"):
                # The Series become columns of a DataFrame: the sequence of Series
                # dtypes must match the sequence of column dtypes in the output.
                result = self._origConcat(*args, **kwargs)
                objDtypes = [obj.dtype for obj in objs]
                assert len(objDtypes) == len(
                    result.dtypes
                ), f"dtype length mismatch: {len(objDtypes)} vs {len(result.dtypes)}"
                for col, seriesType, colType in zip(result.columns, objDtypes, result.dtypes):
                    check(
                        col,
                        seriesType == colType,
                        f"Series concat on {col}: {seriesType} vs {colType}",
                    )
            else:
                # Series stacked end to end: every input must share one dtype.
                seriesTypes = {obj.dtype for obj in objs}
                check(None, len(seriesTypes) == 1, f"More than 1 unique Series type: {seriesTypes}")
                result = self._origConcat(*args, **kwargs)
        else:
            # All inputs are DataFrames here (guaranteed by the assertion above
            # together with the non-empty guard).  Validate every input first.
            for dfArg in objs:
                lines.extend(self._validate_dataframe(dfArg))

            # Record, per column name, the dtype seen in each input.  The index is
            # reset into columns so that index dtypes are tracked as well.
            colTypes: Dict[str, List[type]] = {}
            for df in objs:
                for col, dtype in df.reset_index(drop=False).dtypes.items():
                    colTypes.setdefault(col, []).append(dtype)

            # Perform concatenation and validate that there weren't any type changes
            result = self._origConcat(*args, **kwargs)
            for col, outputType in result.reset_index(drop=False).dtypes.items():
                inputTypes = colTypes.get(col)
                if inputTypes is None:
                    # Column created by the concat itself (for example the extra
                    # index levels produced by `keys=`): there is no input dtype
                    # to compare against, so there is nothing to check.
                    continue
                check(
                    col,
                    all(inputType == outputType for inputType in inputTypes),
                    f"DataFrame concat on {col}: output={outputType} inputs={inputTypes}",
                )

        # Validate the output itself, whichever shape it has.
        if isinstance(result, pd.DataFrame):
            lines.extend(self._validate_dataframe(result))
        elif isinstance(result, pd.Series):
            lines.extend(self._validate_series(result))
        self._log_errors("CONCAT", self._get_callsite(), lines)
        return result

    return _safe_concat