sourcecode/scoring/pandas_utils.py
"""This module patches Pandas to alert or fail on unexpected dtype conversions.
The module currently supports the merge, join and concat operations, as these functions
can generate derived dataframes with type conversions. The patch can be configured to
either log to stderr or assert False when an unexpected type conversion is detected.
This module should support type-related work in the scorer, including:
* Setting all input datatypes to the appropriate np (non-nullable) or pd (nullable) datatype
for the associated input. For example, noteIds should be np.int64, timestamp of first
status should be pd.Int64Dtype, etc.
* Enforcing type expectations on outputs. For example, validating that the participantId
is an int64 and has not been converted to float.
* Fixing unexpected type conversion errors by specifying default values for rows that are
lacking columns during a merge, join or concat. For example, if we generate numRatings
and then join with noteStatusHistory, we should be able to pass fillna={"numRatings": 0}
to "merge" so that the resulting column should still have type np.int64 where missing
values have been filled with 0 (as opposed to cast to a float with missing values set to
np.NaN).
* Add an "allow_unsafe" keyword argument to merge, join and concat that overrides "fail"
and instead logs to stderr. This will allow us to default all current and new code to
enforced safe behavior except for callsites that haven't been fixed yet.
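As currently implemented, the per-column opt-out is spelled "unsafeAllowed". A minimal usage
sketch (the dataframe and column names below are illustrative, not part of the API):
notes.merge(ratings, on="noteId", unsafeAllowed="numRatings")
pd.concat([notes, moreNotes], unsafeAllowed={"summary", "createdAtMillis"})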
"""
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from hashlib import sha256
import re
import sys
from threading import Lock
import traceback
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from . import constants as c
import numpy as np
import pandas as pd
def get_df_fingerprint(df, cols):
"""Fingerprint the order of select column values within a dataframe."""
try:
strs = [
sha256(b"".join(map(lambda v: int(v).to_bytes(8, "big"), df[col]))).hexdigest()
for col in cols
]
return sha256(",".join(strs).encode("utf-8")).hexdigest()
except ValueError:
strs = [sha256(",".join(map(str, df[col])).encode("utf-8")).hexdigest() for col in cols]
return sha256(",".join(strs).encode("utf-8")).hexdigest()
def keep_columns(df: pd.DataFrame, cols: List[str]):
cols = [col for col in cols if col in df]
return df[cols]
def get_df_info(
df: pd.DataFrame, name: Optional[str] = None, deep: bool = False, counter: bool = False
) -> str:
"""Log dtype and RAM usage stats for each input DataFrame."""
stats = (
df.dtypes.to_frame().reset_index(drop=False).rename(columns={"index": "column", 0: "dtype"})
).merge(
# deep=True shows memory usage for the entire contained object (e.g. if the type
# of a column is "object", then deep=True shows the size of the objects instead
# of the size of the pointers).
df.memory_usage(index=True, deep=deep)
.to_frame()
.reset_index(drop=False)
.rename(columns={"index": "column", 0: "RAM"})
)
ramBytes = stats["RAM"].sum()
if name is not None:
lines = [f"""{name} total RAM: {ramBytes} bytes ({ramBytes * 1e-9:.3f} GB)"""]
else:
lines = [f"""total RAM: {ramBytes} bytes ({ramBytes * 1e-9:.3f} GB)"""]
lines.extend(str(stats).split("\n"))
if counter:
for col, dtype in zip(stats["column"], stats["dtype"]):
if dtype != object:
continue
lines.append(f"{col}: {Counter(type(obj) for obj in df[col])}")
return "\n".join(lines)
class TypeErrorCounter(object):
def __init__(self):
self._callCounts: Dict[Tuple[str, str], int] = dict()
self._typeErrors: Dict[Tuple[str, str], Counter[str]] = dict()
self._lock = Lock()
def log_errors(self, method: str, callsite: str, errors: List[str]) -> None:
key = (method, callsite)
with self._lock:
if key not in self._callCounts:
self._callCounts[key] = 0
self._callCounts[key] += 1
if key not in self._typeErrors:
self._typeErrors[key] = Counter()
for error in errors:
self._typeErrors[key][error] += 1
def get_summary(self):
lines = []
keys = [
(method, -1 * count, callsite) for ((method, callsite), count) in self._callCounts.items()
]
for method, count, callsite in sorted(keys):
lines.append(f"{method}: {-1 * count} BAD CALLS AT: {callsite.rstrip()}")
for error, errorCount in self._typeErrors[(method, callsite)].items():
lines.append(f" {errorCount:3d}x {error}")
lines.append("")
return "\n".join(lines)
class LogLevel(Enum):
# Raise an error if the expectation is violated
FATAL = 1
# Log to stderr when the expectation is violated
ERROR = 2
# Log to stderr any time the column is observed
INFO = 3
@dataclass
class TypeExpectation:
dtype: type
logLevel: LogLevel
class PandasPatcher(object):
def __init__(
self, fail: bool, typeOverrides: Dict[str, TypeExpectation] = dict(), silent: bool = False
):
"""Initialize a PandasPatcher with particular failure and type expectations.
Args:
fail: Whether to raise errors or log to stderr when expectations are violated.
typeOverrides: Type expectations for select columns.
silent: If True, suppress logging type warnings to stderr.
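For example (illustrative; the column name is hypothetical):
PandasPatcher(fail=True, typeOverrides={"numRatings": TypeExpectation(np.int64, LogLevel.FATAL)})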
"""
self._silent = silent # Set to True to suppress logging to stderr
self._fail = fail
self._counter = TypeErrorCounter()
self._origConcat = pd.concat
self._origJoin = pd.DataFrame.join
self._origMerge = pd.DataFrame.merge
self._origApply = pd.DataFrame.apply
self._origInit = pd.DataFrame.__init__
self._origGetItem = pd.DataFrame.__getitem__
self._origSetItem = pd.DataFrame.__setitem__
self._origLocGetItem = pd.core.indexing._LocationIndexer.__getitem__
self._origLocSetItem = pd.core.indexing._LocationIndexer.__setitem__
self._expectations = {
c.noteIdKey: TypeExpectation(np.int64, LogLevel.ERROR),
}
for column, expectation in typeOverrides.items():
self._expectations[column] = expectation
def get_summary(self) -> str:
return f"\nTYPE WARNING SUMMARY\n{self._counter.get_summary()}"
def _log_errors(self, method: str, callsite: str, lines: List[str]) -> None:
if not lines:
return
self._counter.log_errors(method, callsite, lines)
errorLines = "\n".join([f" PandasTypeError: {l}" for l in lines])
msg = f"\n{method} ERROR(S) AT: {callsite}\n{errorLines}\n"
if not self._silent:
print(msg, file=sys.stderr)
def _get_check(self, lines: List[str], kwargs: Dict) -> Callable:
"""Return a function which will either assert a condition or append to a list of errors.
Note that this function does not actually log to stderr, but rather appends to a list so
that all errors from a single call can be logged together at the callsite.
"""
unsafeAllowed = set()
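# Callers opt out per column by passing e.g. unsafeAllowed="numRatings" (a single column),
# unsafeAllowed=["numRatings", "createdAtMillis"] (a list), or a set of column names to
# merge/join/concat. The column names above are illustrative.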
if "unsafeAllowed" in kwargs:
unsafeAllowedArg = kwargs["unsafeAllowed"]
if isinstance(unsafeAllowedArg, str):
unsafeAllowed = {unsafeAllowedArg}
elif isinstance(unsafeAllowedArg, List):
unsafeAllowed = set(unsafeAllowedArg)
else:
assert isinstance(unsafeAllowedArg, Set)
unsafeAllowed = unsafeAllowedArg
del kwargs["unsafeAllowed"]
def _check(columns: Any, condition: bool, msg: str):
if isinstance(columns, str):
failDisabled = columns in unsafeAllowed
elif isinstance(columns, List):
failDisabled = all(col in unsafeAllowed for col in columns)
else:
# Note there are multiple circumstances where the type of Columns may not be a str
# or List[str], including when we are concatenating a Series (column name will be
# set to None), when there are multi-level column names (column name will be a tuple)
# or when Pandas has set column names to a RangeIndex.
failDisabled = False
if self._fail and not failDisabled:
assert condition, msg
elif not condition:
if failDisabled:
lines.append(f"{msg} (allowed)")
else:
lines.append(f"{msg} (UNALLOWED)")
return _check
def _get_callsite(self) -> str:
"""Return the file, function, line number and pandas API call on a single line."""
for line in traceback.format_stack()[::-1]:
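# Each formatted frame follows the standard traceback layout, e.g. (path illustrative):
# '  File "/path/to/caller.py", line 42, in some_func\n    df = left.merge(right)\n'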
path = line.split(",")[0]
if "/pandas_utils.py" in path:
continue
if "/pandas/" in path:
continue
break
# Handle paths resulting from bazel invocation
match = re.match(r'^ File ".*?/site-packages(/.*?)", (.*?), (.*?)\n (.*)\n$', line)
if match:
return f"{match.group(1)}, {match.group(3)}, at {match.group(2)}: {match.group(4)}"
# Handle paths resulting from pytest invocation
match = re.match(r'^ File ".*?/src/(test|main)/python(/.*?)", (.*?), (.*?)\n (.*)\n$', line)
if match:
return f"{match.group(2)}, {match.group(4)}, at {match.group(3)}: {match.group(5)}"
# Handle other paths (e.g. notebook, public code)
match = re.match(r'^ File "(.*?)", (.*?), (.*?)\n (.*)\n$', line)
if match:
return f"{match.group(1)}, {match.group(3)}, at {match.group(2)}: {match.group(4)}"
else:
stack = "\n\n".join(traceback.format_stack()[::-1])
print(f"parsing error:\n{stack}", file=sys.stderr)
return "parsing error. callsite unknown."
def _check_dtype(self, dtype: Any, expected: type) -> bool:
"""Return True IFF dtype corresponds to expected.
Note that for non-nullable columns, dtype may equal the expected type (e.g. np.int64), but for
nullable columns the column dtype is an instance of a pandas extension type (e.g. pd.Int64Dtype).
"""
assert expected != object, "expectation must be more specific than object"
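# e.g. a non-nullable column has dtype np.dtype("int64"), which compares equal to np.int64,
# while a nullable column has dtype pd.Int64Dtype(), which satisfies isinstance(dtype, pd.Int64Dtype).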
return dtype == expected or isinstance(dtype, expected)
def _check_name_and_type(self, name: str, dtype: Any) -> List[str]:
"""Returns a list of type mismatches if any are found, or raises an error."""
if name not in self._expectations:
return []
typeExpectation = self._expectations[name]
msg = f"Type expectation mismatch on {name}: found={dtype} expected={typeExpectation.dtype.__name__}"
match = self._check_dtype(dtype, typeExpectation.dtype)
if typeExpectation.logLevel == LogLevel.INFO:
return (
[msg]
if not match
else [
f"Type expectation match on {name}: found={dtype} expected={typeExpectation.dtype.__name__}"
]
)
elif typeExpectation.logLevel == LogLevel.ERROR or not self._fail:
return [msg] if not match else []
else:
assert typeExpectation.logLevel == LogLevel.FATAL
assert self._fail
assert match, msg
return []
def _validate_series(self, series: pd.Series) -> List[str]:
assert isinstance(series, pd.Series), f"unexpected type: {type(series)}"
return self._check_name_and_type(series.name, series.dtype)
def _validate_dataframe(self, df: pd.DataFrame) -> List[str]:
"""Returns a list of type mismatches if any are found, or raises an error."""
assert isinstance(df, pd.DataFrame), f"unexpected type: {type(df)}"
lines = []
# Check index types
if type(df.index) == pd.MultiIndex:
for name, dtype in df.index.dtypes.to_dict().items():
lines.extend(self._check_name_and_type(name, dtype))
elif type(df.index) == pd.RangeIndex or df.index.name is None:
# Index is uninteresting - none was specified by the caller.
pass
else:
lines.extend(self._check_name_and_type(df.index.name, df.index.dtype))
# Check column types
for name, dtype in df.dtypes.to_dict().items():
lines.extend(self._check_name_and_type(name, dtype))
return lines
def safe_init(self) -> Callable:
"""Return a modified __init__ function that checks type expectations."""
def _safe_init(*args, **kwargs):
"""Wrapper around pd.DataFrame.__init__
Args:
args: non-keyword arguments to pass through to __init__.
kwargs: keyword arguments to pass through to __init__.
"""
df = args[0]
assert isinstance(df, pd.DataFrame), f"unexpected type: {type(df)}"
retVal = self._origInit(*args, **kwargs)
assert retVal is None
lines = self._validate_dataframe(df)
self._log_errors("INIT", self._get_callsite(), lines)
return retVal
return _safe_init
def safe_concat(self) -> Callable:
"""Return a modified concat function that checks type stability."""
def _safe_concat(*args, **kwargs):
"""Wrapper around pd.concat
Args:
args: non-keyword arguments to pass through to concat.
kwargs: keyword arguments to pass through to concat.
"""
lines = []
check = self._get_check(lines, kwargs)
# Validate that all objects being concatenated are either Series or DataFrames
objs = args[0]
assert type(objs) == list, f"expected first argument to be a list: type={type(objs)}"
assert (
all(type(obj) == pd.Series for obj in objs)
or all(type(obj) == pd.DataFrame for obj in objs)
), f"Expected concat args to be either pd.Series or pd.DataFrame: {[type(obj) for obj in objs]}"
if type(objs[0]) == pd.Series:
if "axis" in kwargs and kwargs["axis"] == 1:
# Since the call is concatenating Series as columns in a DataFrame, validate that the sequence
# of Series dtypes matches the sequence of column dtypes in the dataframe.
result = self._origConcat(*args, **kwargs)
objDtypes = [obj.dtype for obj in objs]
assert len(objDtypes) == len(
result.dtypes
), f"dtype length mismatch: {len(objDtypes)} vs {len(result.dtypes)}"
for col, seriesType, colType in zip(result.columns, objDtypes, result.dtypes):
check(
col,
seriesType == colType,
f"Series concat on {col}: {seriesType} vs {colType}",
)
else:
# If Series, validate that all series were same type and return
seriesTypes = set(obj.dtype for obj in objs)
check(None, len(seriesTypes) == 1, f"More than 1 unique Series type: {seriesTypes}")
result = self._origConcat(*args, **kwargs)
else:
# If DataFrame, validate that all input columns with matching names have the same type
# and build expectation for output column types
assert type(objs[0]) == pd.DataFrame
# Validate all inputs
for dfArg in objs:
lines.extend(self._validate_dataframe(dfArg))
colTypes: Dict[str, List[type]] = dict()
for df in objs:
for col, dtype in df.reset_index(drop=False).dtypes.items():
if col not in colTypes:
colTypes[col] = []
colTypes[col].append(dtype)
# Perform concatenation and validate that there weren't any type changes
result = self._origConcat(*args, **kwargs)
for col, outputType in result.reset_index(drop=False).dtypes.items():
check(
col,
all(inputType == outputType for inputType in colTypes[col]),
f"DataFrame concat on {col}: output={outputType} inputs={colTypes[col]}",
)
if isinstance(result, pd.DataFrame):
lines.extend(self._validate_dataframe(result))
elif isinstance(result, pd.Series):
lines.extend(self._validate_series(result))
self._log_errors("CONCAT", self._get_callsite(), lines)
return result
return _safe_concat
def safe_apply(self) -> Callable:
"""Return a modified apply function that checks type stability."""
def _safe_apply(*args, **kwargs):
"""Wrapper around pd.DataFrame.apply
Args:
args: non-keyword arguments to pass through to apply.
kwargs: keyword arguments to pass through to apply.
"""
# TODO: Flesh this out with additional expectations around input and output types
result = self._origApply(*args, **kwargs)
if isinstance(result, pd.DataFrame):
self._log_errors("APPLY", self._get_callsite(), self._validate_dataframe(result))
elif isinstance(result, pd.Series):
self._log_errors("APPLY", self._get_callsite(), self._validate_series(result))
return result
return _safe_apply
def safe_merge(self) -> Callable:
"""Return a modified merge function that checks type stability."""
def _safe_merge(*args, **kwargs):
"""Wrapper around pd.DataFrame.merge.
Args:
args: non-keyword arguments to pass through to merge.
kwargs: keyword arguments to pass through to merge.
"""
lines = []
check = self._get_check(lines, kwargs)
leftFrame = args[0]
rightFrame = args[1]
# Validate that argument types are as expected
assert type(leftFrame) is pd.DataFrame
assert type(rightFrame) is pd.DataFrame
lines.extend(self._validate_dataframe(leftFrame))
lines.extend(self._validate_dataframe(rightFrame))
# Store dtypes and validate that any common columns have the same type
leftDtypes = dict(leftFrame.reset_index(drop=False).dtypes)
rightDtypes = dict(rightFrame.reset_index(drop=False).dtypes)
for col in set(leftDtypes) & set(rightDtypes):
check(
col,
leftDtypes[col] == rightDtypes[col],
f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
)
# Identify the columns we are merging on, if left_on and right_on are unset
if "on" in kwargs and type(kwargs["on"]) == str:
onCols = set([kwargs["on"]])
elif "on" in kwargs and type(kwargs["on"]) == list:
onCols = set(kwargs["on"])
elif "left_on" in kwargs:
assert "on" not in kwargs, "not expecting both on and left_on"
assert "right_on" in kwargs, "expecting both left_on and right_on to be set"
onCols = set()
else:
assert "on" not in kwargs, f"""unexpected type for on: {type(kwargs["on"])}"""
onCols = set(leftFrame.columns) & set(rightFrame.columns)
# Validate that merge columns have matching types
if "left_on" in kwargs:
assert "right_on" in kwargs
left_on = kwargs["left_on"]
right_on = kwargs["right_on"]
check(
[left_on, right_on],
leftDtypes[left_on] == rightDtypes[right_on],
f"Merge key mismatch on type({left_on})={leftDtypes[left_on]} vs type({right_on})={rightDtypes[right_on]}",
)
else:
assert len(onCols), "expected onCols to be defined since left_on was not"
assert "right_on" not in kwargs, "did not expect onCols and right_on"
for col in onCols:
check(
col,
leftDtypes[col] == rightDtypes[col],
f"Merge key mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
)
# Compute expected column types
leftSuffix, rightSuffix = kwargs.get("suffixes", ("_x", "_y"))
commonCols = set(leftFrame.columns) & set(rightFrame.columns)
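# e.g. with the default suffixes a shared non-key column named "total" (name illustrative)
# appears in the output as "total_x" (left dtype) and "total_y" (right dtype).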
expectedColTypes = dict()
for col in set(leftFrame.columns) | set(rightFrame.columns):
if col in onCols:
# Note that we check above whether leftDtypes[col] == rightDtypes[col] and either raise an
# error or log as appropriate if there is a mismatch.
if leftDtypes[col] == rightDtypes[col]:
expectedColTypes[col] = leftDtypes[col]
else:
# Set expectation to None since we don't know what will happen, but do want to log an
# error later
expectedColTypes[col] = None
elif col in commonCols:
expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col]
expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col]
elif col in leftDtypes:
assert col not in rightDtypes
expectedColTypes[col] = leftDtypes[col]
else:
expectedColTypes[col] = rightDtypes[col]
# Perform merge and validate results
result = self._origMerge(*args, **kwargs)
resultDtypes = dict(result.dtypes)
for col in resultDtypes:
check(
col,
resultDtypes[col] == expectedColTypes[col],
f"Output mismatch on {col}: result={resultDtypes[col]} expected={expectedColTypes[col]}",
)
lines.extend(self._validate_dataframe(result))
self._log_errors("MERGE", self._get_callsite(), lines)
return result
return _safe_merge
def safe_join(self) -> Callable:
"""Return a modified join function that checks type stability."""
def _safe_join(*args, **kwargs):
"""Wrapper around pd.DataFrame.join.
Args:
args: non-keyword arguments to pass through to join.
kwargs: keyword arguments to pass through to join.
"""
lines = []
check = self._get_check(lines, kwargs)
leftFrame = args[0]
rightFrame = args[1]
# Validate arguments are as expected
assert type(leftFrame) is pd.DataFrame
assert type(rightFrame) is pd.DataFrame
lines.extend(self._validate_dataframe(leftFrame))
lines.extend(self._validate_dataframe(rightFrame))
assert len(set(kwargs) - {"lsuffix", "rsuffix", "how"}) == 0, f"unexpected kwargs: {kwargs}"
# Validate the assumption that columns used as the join key in the index have the same type.
# This is analogous to validating that onCols match and have the same types in _safe_merge.
if len(leftFrame.index.names) == 1 and len(rightFrame.index.names) == 1:
match = leftFrame.index.dtype == rightFrame.index.dtype
elif len(leftFrame.index.names) == 1 and len(rightFrame.index.names) > 1:
indexTypes = dict(rightFrame.index.dtypes)
name = leftFrame.index.names[0]
assert name in indexTypes, f"{name} not found in {indexTypes}"
match = indexTypes[name] == leftFrame.index.dtype
elif len(leftFrame.index.names) > 1 and len(rightFrame.index.names) == 1:
indexTypes = dict(leftFrame.index.dtypes)
name = rightFrame.index.names[0]
assert name in indexTypes, f"{name} not found in {indexTypes}"
match = indexTypes[name] == rightFrame.index.dtype
else:
assert (
len(leftFrame.index.names) > 1
), f"unexpected left: {type(leftFrame.index)}, {leftFrame.index}"
assert (
len(rightFrame.index.names) > 1
), f"unexpected right: {type(rightFrame.index)}, {rightFrame.index}"
leftIndexTypes = dict(leftFrame.index.dtypes)
rightIndexTypes = dict(rightFrame.index.dtypes)
match = True
for col in set(leftIndexTypes) & set(rightIndexTypes):
match = match & (leftIndexTypes[col] == rightIndexTypes[col])
check(
list(set(leftFrame.index.names) | set(rightFrame.index.names)),
match,
"Join index mismatch:\nleft:\n{left}\nvs\nright:\n{right}".format(
left=leftFrame.index.dtype if len(leftFrame.index.names) == 1 else leftFrame.index.dtypes,
right=rightFrame.index.dtype
if len(rightFrame.index.names) == 1
else rightFrame.index.dtypes,
),
)
# Validate that input columns with the same name have the same types
leftDtypes = dict(leftFrame.dtypes)
rightDtypes = dict(rightFrame.dtypes)
for col in set(leftDtypes) & set(rightDtypes):
check(
col,
leftDtypes[col] == rightDtypes[col],
f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
)
# Validate that none of the columns in an index have the same name as a non-index column
# in the opposite dataframe
assert (
len(set(leftFrame.index.names) & set(rightFrame.columns)) == 0
), f"left index: {set(leftFrame.index.names)}; right columns {set(rightFrame.columns)}"
assert (
len(set(rightFrame.index.names) & set(leftFrame.columns)) == 0
), f"right index: {set(rightFrame.index.names)}; left columns {set(leftFrame.columns)}"
# Compute expected types for output columns
commonCols = set(leftFrame.columns) & set(rightFrame.columns)
expectedColTypes = dict()
leftSuffix = kwargs.get("lsuffix", "")
rightSuffix = kwargs.get("rsuffix", "")
for col in set(leftFrame.columns) | set(rightFrame.columns):
if col in commonCols:
expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col]
expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col]
elif col in leftDtypes:
assert col not in rightDtypes
expectedColTypes[col] = leftDtypes[col]
else:
expectedColTypes[col] = rightDtypes[col]
# Compute expected types for index columns
leftIndexCols = set(leftFrame.index.names)
rightIndexCols = set(rightFrame.index.names)
if len(leftIndexCols) > 1:
leftDtypes = dict(leftFrame.index.dtypes)
else:
leftDtypes = {leftFrame.index.name: leftFrame.index.dtype}
if len(rightIndexCols) > 1:
rightDtypes = dict(rightFrame.index.dtypes)
else:
rightDtypes = {rightFrame.index.name: rightFrame.index.dtype}
for col in leftIndexCols & rightIndexCols:
# For columns in both indices, type should not change if input types agree. If input types
# disagree, then we have no expectation.
if leftDtypes[col] == rightDtypes[col]:
expectedColTypes[col] = leftDtypes[col]
else:
expectedColTypes[col] = None
for col in (leftIndexCols | rightIndexCols) - (leftIndexCols & rightIndexCols):
# For columns in exactly one index, the expected output type should match the input column type
# and the column name should not change because we have validated that the column does not
# appear in the other dataframe
if col in leftDtypes:
assert col not in rightDtypes, f"unexpected column: {col}"
expectedColTypes[col] = leftDtypes[col]
else:
expectedColTypes[col] = rightDtypes[col]
# Perform join and validate results. Note that we already validated that the indices had the
# same columns and types, and that the "on" argument is unset, so now we only need to check
# the non-index columns.
result = self._origJoin(*args, **kwargs)
# Note that we must reset index to force any NaNs in the index to emerge as float types.
# See example below.
# left = pd.DataFrame({"idx0": [1, 2], "idx1": [11, 12], "val1": [4, 5]}).set_index(["idx0", "idx1"])
# right = pd.DataFrame({"idx0": [1, 2, 3], "idx2": [21, 22, 23], "val2": [7, 8, 9]}).set_index(["idx0", "idx2"])
# print(dict(left.join(right, how="outer").index.dtypes))
# print(dict(left.join(right, how="outer").reset_index(drop=False).dtypes))
# $> {'idx0': dtype('int64'), 'idx1': dtype('int64'), 'idx2': dtype('int64')}
# $> {'idx0': dtype('int64'), 'idx1': dtype('float64'), 'idx2': dtype('int64'), 'val1': dtype('float64'), 'val2': dtype('int64')}
resultDtypes = dict(result.reset_index(drop=False).dtypes)
# Add default type for index
if "index" not in expectedColTypes:
expectedColTypes["index"] = np.int64
for col, dtype in resultDtypes.items():
if len(col) == 2 and col[1] == "":
col = col[0]
check(
col,
dtype == expectedColTypes[col],
f"Output mismatch on {col}: result={dtype} expected={expectedColTypes[col]}",
)
lines.extend(self._validate_dataframe(result))
self._log_errors("JOIN", self._get_callsite(), lines)
return result
return _safe_join
# TODO: restore original functionality before return
# TODO: make enforce_types an explicit argument so this is less error prone
def patch_pandas(main: Callable) -> Callable:
"""Return a decorator for wrapping main with pandas patching and logging
Args:
main: "main" function for program binary
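Example (illustrative; assumes main receives parsed args exposing an enforce_types
attribute, and optionally parallel, as the runners referenced below do):
@patch_pandas
def main(args): ...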
"""
def _inner(*args, **kwargs) -> Any:
"""Determine patching behavior, apply patch and add logging."""
print("Patching pandas")
if "args" in kwargs:
# Handle birdwatch/scoring/src/main/python/public/scoring/runner.py, which expects
# args as a keyword argument and not as a positional argument.
assert len(args) == 0, f"positional arguments not expected, but found {len(args)}"
clArgs = kwargs["args"]
else:
# Handle the following, which expect args as the second positional argument:
# birdwatch/scoring/src/main/python/run_post_selection_similarity.py
# birdwatch/scoring/src/main/python/run_prescoring.py
# birdwatch/scoring/src/main/python/run_final_scoring.py
# birdwatch/scoring/src/main/python/run_contributor_scoring.py
# birdwatch/scoring/src/main/python/run.py
assert len(args) == 1, f"expected 1 positional arg, but found {len(args)}"
assert len(kwargs) == 0, f"expected kwargs to be empty, but found {len(kwargs)}"
clArgs = args[0]
# Apply patches, configured based on whether types should be enforced or logged
patcher = PandasPatcher(clArgs.enforce_types)
pd.concat = patcher.safe_concat()
# Note that this will work when calling df1.merge(df2) because the first argument
# to "merge" is df1 (i.e. self).
pd.DataFrame.merge = patcher.safe_merge()
pd.DataFrame.join = patcher.safe_join()
pd.DataFrame.apply = patcher.safe_apply()
pd.DataFrame.__init__ = patcher.safe_init()
# Run main
retVal = main(*args, **kwargs)
# Log type error summary
if hasattr(clArgs, "parallel") and not clArgs.parallel:
print(patcher.get_summary(), file=sys.stderr)
else:
# Don't show type summary because counters will be inaccurate due to scorers running
# in their own process.
print("Type summary omitted when running in parallel.", file=sys.stderr)
# Return result of main
return retVal
return _inner