in sourcecode/scoring/pandas_utils.py [0:0]
def safe_join(self) -> Callable:
  """Return a modified join function that checks type stability.

  The returned callable wraps ``self._origJoin``: it validates the two input
  frames, computes the dtype each output column is expected to have, performs
  the join, and reports every discrepancy through ``self._log_errors``.
  """

  def _safe_join(*args, **kwargs):
    """Wrapper around pd.DataFrame.join.

    Args:
      args: non-keyword arguments to pass through to join.  The first two must
        be the left and right DataFrames.
      kwargs: keyword arguments to pass through to join.  Only "lsuffix",
        "rsuffix" and "how" are accepted.

    Returns:
      Whatever ``self._origJoin(*args, **kwargs)`` returns.

    Raises:
      AssertionError: if either input is not exactly a pd.DataFrame, if an
        unexpected keyword argument is supplied, if a single-level index name
        is missing from the other frame's MultiIndex, or if an index level
        name collides with a non-index column of the opposite frame.
    """
    # `lines` accumulates the messages handed to self._log_errors at the end.
    # NOTE(review): `check` is assumed to record its message when the condition
    # is false -- confirm against _get_check.
    lines = []
    check = self._get_check(lines, kwargs)
    leftFrame = args[0]
    rightFrame = args[1]

    # Validate arguments are as expected
    assert type(leftFrame) is pd.DataFrame
    assert type(rightFrame) is pd.DataFrame
    lines.extend(self._validate_dataframe(leftFrame))
    lines.extend(self._validate_dataframe(rightFrame))
    assert len(set(kwargs) - {"lsuffix", "rsuffix", "how"}) == 0, f"unexpected kwargs: {kwargs}"

    # Validate the assumption that columns used as the join key in the index have the same type.
    # This is analogous to validating that onCols match and have the same types in _safe_merge.
    if len(leftFrame.index.names) == 1 and len(rightFrame.index.names) == 1:
      match = leftFrame.index.dtype == rightFrame.index.dtype
    elif len(leftFrame.index.names) == 1 and len(rightFrame.index.names) > 1:
      # Single-level left index must appear as a named level of the right MultiIndex.
      indexTypes = dict(rightFrame.index.dtypes)
      name = leftFrame.index.names[0]
      assert name in indexTypes, f"{name} not found in {indexTypes}"
      match = indexTypes[name] == leftFrame.index.dtype
    elif len(leftFrame.index.names) > 1 and len(rightFrame.index.names) == 1:
      # Single-level right index must appear as a named level of the left MultiIndex.
      indexTypes = dict(leftFrame.index.dtypes)
      name = rightFrame.index.names[0]
      assert name in indexTypes, f"{name} not found in {indexTypes}"
      match = indexTypes[name] == rightFrame.index.dtype
    else:
      assert (
        len(leftFrame.index.names) > 1
      ), f"unexpected left: {type(leftFrame.index)}, {leftFrame.index}"
      assert (
        len(rightFrame.index.names) > 1
      ), f"unexpected right: {type(rightFrame.index)}, {rightFrame.index}"
      leftIndexTypes = dict(leftFrame.index.dtypes)
      rightIndexTypes = dict(rightFrame.index.dtypes)
      # Only the levels present in both indices need to agree.
      match = all(
        leftIndexTypes[col] == rightIndexTypes[col]
        for col in set(leftIndexTypes) & set(rightIndexTypes)
      )
    check(
      list(set(leftFrame.index.names) | set(rightFrame.index.names)),
      match,
      "Join index mismatch:\nleft:\n{left}\nvs\nright:\n{right}".format(
        left=leftFrame.index.dtype if len(leftFrame.index.names) == 1 else leftFrame.index.dtypes,
        right=rightFrame.index.dtype
        if len(rightFrame.index.names) == 1
        else rightFrame.index.dtypes,
      ),
    )

    # Validate that input columns with the same name have the same types
    leftDtypes = dict(leftFrame.dtypes)
    rightDtypes = dict(rightFrame.dtypes)
    for col in set(leftDtypes) & set(rightDtypes):
      check(
        col,
        leftDtypes[col] == rightDtypes[col],
        f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
      )

    # Validate that none of the columns in an index have the same name as a non-index column
    # in the opposite dataframe
    assert (
      len(set(leftFrame.index.names) & set(rightFrame.columns)) == 0
    ), f"left index: {set(leftFrame.index.names)}; right columns {set(rightFrame.columns)}"
    assert (
      len(set(rightFrame.index.names) & set(leftFrame.columns)) == 0
    ), f"right index: {set(rightFrame.index.names)}; left columns {set(leftFrame.columns)}"

    # Compute expected types for output columns
    commonCols = set(leftFrame.columns) & set(rightFrame.columns)
    expectedColTypes = dict()
    leftSuffix = kwargs.get("lsuffix", "")
    rightSuffix = kwargs.get("rsuffix", "")
    for col in set(leftFrame.columns) | set(rightFrame.columns):
      if col in commonCols:
        # Columns present on both sides are emitted twice, once per suffix.
        expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col]
        expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col]
      elif col in leftDtypes:
        assert col not in rightDtypes
        expectedColTypes[col] = leftDtypes[col]
      else:
        expectedColTypes[col] = rightDtypes[col]

    # Compute expected types for index columns.  From here on leftDtypes / rightDtypes
    # describe the index levels rather than the data columns.
    leftIndexCols = set(leftFrame.index.names)
    rightIndexCols = set(rightFrame.index.names)
    if len(leftIndexCols) > 1:
      leftDtypes = dict(leftFrame.index.dtypes)
    else:
      # Use the left frame's own index dtype.  Reading it from the right frame would
      # pick up `object` whenever the right index is a MultiIndex and would mask real
      # left/right disagreement when both indices are single-level.
      leftDtypes = {leftFrame.index.name: leftFrame.index.dtype}
    if len(rightIndexCols) > 1:
      rightDtypes = dict(rightFrame.index.dtypes)
    else:
      rightDtypes = {rightFrame.index.name: rightFrame.index.dtype}
    for col in leftIndexCols & rightIndexCols:
      # For columns in both indices, type should not change if input types agree. If input types
      # disagree, then we have no expectation.
      if leftDtypes[col] == rightDtypes[col]:
        expectedColTypes[col] = leftDtypes[col]
      else:
        expectedColTypes[col] = None
    for col in (leftIndexCols | rightIndexCols) - (leftIndexCols & rightIndexCols):
      # For columns in exactly one index, the expected output type should match the input column type
      # and the column name should not change because we have validated that the column does not
      # appear in the other dataframe
      if col in leftDtypes:
        assert col not in rightDtypes, f"unexpected column: {col}"
        expectedColTypes[col] = leftDtypes[col]
      else:
        expectedColTypes[col] = rightDtypes[col]

    # Perform join and validate results.  The "on" argument is necessarily unset because
    # only lsuffix / rsuffix / how are accepted above, so the join is index-on-index.
    result = self._origJoin(*args, **kwargs)
    # Note that we must reset index to force any NaNs in the index to emerge as float types.
    # See example below.
    # left = pd.DataFrame({"idx0": [1, 2], "idx1": [11, 12], "val1": [4, 5]}).set_index(["idx0", "idx1"])
    # right = pd.DataFrame({"idx0": [1, 2, 3], "idx2": [21, 22, 23], "val2": [7, 8, 9]}).set_index(["idx0", "idx2"])
    # print(dict(left.join(right, how="outer").index.dtypes))
    # print(dict(left.join(right, how="outer").reset_index(drop=False).dtypes))
    # $> {'idx0': dtype('int64'), 'idx1': dtype('int64'), 'idx2': dtype('int64')}
    # $> {'idx0': dtype('int64'), 'idx1': dtype('float64'), 'idx2': dtype('int64'), 'val1': dtype('float64'), 'val2': dtype('int64')}
    resultDtypes = dict(result.reset_index(drop=False).dtypes)
    # Add default type for index: reset_index names an unnamed index "index".
    if "index" not in expectedColTypes:
      expectedColTypes["index"] = np.int64
    for col, dtype in resultDtypes.items():
      # A 2-element label whose second element is "" is collapsed to its first element
      # (presumably the label reset_index produces under 2-level columns -- confirm).
      if len(col) == 2 and col[1] == "":
        col = col[0]
      check(
        col,
        dtype == expectedColTypes[col],
        f"Output mismatch on {col}: result={dtype} expected={expectedColTypes[col]}",
      )
    lines.extend(self._validate_dataframe(result))
    self._log_errors("JOIN", self._get_callsite(), lines)
    return result

  return _safe_join