in sourcecode/scoring/pandas_utils.py [0:0]
def safe_merge(self) -> Callable:
    """Return a modified merge function that checks type stability.

    The returned callable accepts the same arguments as ``pd.DataFrame.merge`` (left frame
    first, right frame second).  It compares the dtypes of the inputs, of the merge keys and
    of the merged result, reporting every discrepancy through the checker obtained from
    ``self._get_check`` and finally handing the collected lines to ``self._log_errors``.

    Returns:
      A callable wrapping ``self._origMerge`` that returns the merged DataFrame unchanged.
    """

    def _as_labels(value):
        """Normalize a merge-key argument (one label or a list of labels) to a list."""
        return list(value) if isinstance(value, list) else [value]

    def _safe_merge(*args, **kwargs):
        """Wrapper around pd.DataFrame.merge.

        Args:
          args: non-keyword arguments to pass through to merge.  The first two must be the
            left and right DataFrames.
          kwargs: keyword arguments to pass through to merge.

        Returns:
          The DataFrame produced by the original merge.

        Raises:
          AssertionError: if either frame is not exactly a pd.DataFrame, or if the
            on / left_on / right_on arguments are combined in an unsupported way.
        """
        lines = []
        # The checker must be built before the merge: it is handed the same kwargs dict that
        # is later forwarded to the original merge (presumably so it can read its own options).
        check = self._get_check(lines, kwargs)
        leftFrame = args[0]
        rightFrame = args[1]
        # Validate that argument types are as expected
        assert type(leftFrame) is pd.DataFrame
        assert type(rightFrame) is pd.DataFrame
        lines.extend(self._validate_dataframe(leftFrame))
        lines.extend(self._validate_dataframe(rightFrame))
        # Store dtypes and validate that any common columns have the same type.  Because of
        # reset_index(drop=False) these dicts describe index levels as well as columns.
        leftDtypes = dict(leftFrame.reset_index(drop=False).dtypes)
        rightDtypes = dict(rightFrame.reset_index(drop=False).dtypes)
        for col in set(leftDtypes) & set(rightDtypes):
            check(
                col,
                leftDtypes[col] == rightDtypes[col],
                f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
            )
        # An explicit None is the pandas default and is treated the same as "not passed".
        onArg = kwargs.get("on")
        leftOn = kwargs.get("left_on")
        rightOn = kwargs.get("right_on")
        leftCols = set(leftFrame.columns)
        rightCols = set(rightFrame.columns)
        commonCols = leftCols & rightCols
        # Identify the columns we are merging on, if left_on and right_on are unset
        if isinstance(onArg, str):
            onCols = {onArg}
        elif isinstance(onArg, list):
            onCols = set(onArg)
        elif leftOn is not None:
            assert onArg is None, "not expecting both on and left_on"
            assert rightOn is not None, "expecting both left_on and right_on to be set"
            onCols = set()
        else:
            assert onArg is None, f"unexpected type for on: {type(onArg)}"
            onCols = set(commonCols)
        # Validate that merge columns have matching types
        sharedKeys = set()
        if leftOn is not None:
            assert rightOn is not None
            # left_on / right_on may each be a single label or a list of labels; compare them
            # pairwise.  A length mismatch is left for pandas itself to report.
            for leftKey, rightKey in zip(_as_labels(leftOn), _as_labels(rightOn)):
                check(
                    [leftKey, rightKey],
                    leftDtypes[leftKey] == rightDtypes[rightKey],
                    f"Merge key mismatch on type({leftKey})={leftDtypes[leftKey]} vs type({rightKey})={rightDtypes[rightKey]}",
                )
                if leftKey == rightKey:
                    # pandas emits a single, unsuffixed column when both sides use the same
                    # key name, exactly as it does for `on`.
                    sharedKeys.add(leftKey)
        else:
            assert len(onCols), "expected onCols to be defined since left_on was not"
            assert rightOn is None, "did not expect onCols and right_on"
            for col in onCols:
                check(
                    col,
                    leftDtypes[col] == rightDtypes[col],
                    f"Merge key mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
                )
        # Compute expected column types
        leftSuffix, rightSuffix = kwargs.get("suffixes", ("_x", "_y"))
        mergedKeys = onCols | sharedKeys
        expectedColTypes = {}
        for col in leftCols | rightCols:
            if col in mergedKeys:
                # Note that we check above whether leftDtypes[col] == rightDtypes[col] and either
                # raise an error or log as appropriate if there is a mismatch.
                if leftDtypes[col] == rightDtypes[col]:
                    expectedColTypes[col] = leftDtypes[col]
                else:
                    # Set expectation to None since we don't know what will happen, but do want
                    # to log an error later
                    expectedColTypes[col] = None
            elif col in commonCols:
                expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col]
                expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col]
            elif col in leftCols:
                # Decide by column membership rather than by the dtype dicts, which also hold
                # index levels and could therefore contain this name on the other side.
                expectedColTypes[col] = leftDtypes[col]
            else:
                expectedColTypes[col] = rightDtypes[col]
        # Perform merge and validate results
        result = self._origMerge(*args, **kwargs)
        for col, actualType in dict(result.dtypes).items():
            # Columns without an expectation (unknown key type, or columns that pandas adds on
            # its own) are reported through `check` instead of failing with a KeyError.
            expectedType = expectedColTypes.get(col)
            check(
                col,
                expectedType is not None and actualType == expectedType,
                f"Output mismatch on {col}: result={actualType} expected={expectedType}",
            )
        lines.extend(self._validate_dataframe(result))
        self._log_errors("MERGE", self._get_callsite(), lines)
        return result

    return _safe_merge