sourcecode/scoring/pandas_utils.py (479 lines of code) (raw):

"""This module patches Pandas to alert or fail on unexpected dtype conversions.

The module currently supports the merge, join and concat operations, as these functions can
generate derived dataframes with type conversions. DataFrame.apply and DataFrame.__init__ are also
wrapped so that their results are checked against per-column type expectations. The patch can be
configured to either log to stderr or assert False when an unexpected type conversion is detected.

This module should support type-related work in the scorer, including:
  * Setting all input datatypes to the appropriate np (non-nullable) or pd (nullable) datatype for
    the associated input. For example, noteIds should be np.int64, timestamp of first status
    should be pd.Int64Dtype, etc.
  * Enforcing type expectations on outputs. For example, validating that the participantId is an
    int64 and has not been converted to float.
  * Fixing unexpected type conversion errors by specifying default values for rows that are
    lacking columns during a merge, join or concat. For example, if we generate numRatings and
    then join with noteStatusHistory, we should be able to pass fillna={"numRatings": 0} to
    "merge" so that the resulting column still has type np.int64 where missing values have been
    filled with 0 (as opposed to cast to a float with missing values set to np.NaN). (Planned: the
    wrappers below do not yet accept a fillna argument.)
  * Letting callsites opt out of enforcement: merge, join and concat accept an "unsafeAllowed"
    keyword argument (a column name, or a list/tuple/set of column names) that overrides "fail"
    for those columns and instead logs to stderr. This allows us to default all current and new
    code to enforced safe behavior except for callsites that haven't been fixed yet.
"""

from collections import Counter
from dataclasses import dataclass
from enum import Enum
import functools
from hashlib import sha256
import re
import sys
from threading import Lock
import traceback
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

from . import constants as c

import numpy as np
import pandas as pd


# Patterns for parsing one entry of traceback.format_stack(), which has the shape
#   '  File "/path/to/file.py", line 10, in func\n    source line\n'
# Leading indentation is matched loosely and anything after the source line is ignored.
# Groups (for _ANY_FRAME_RE): path, "line N", "in func", source line.
_SITE_PACKAGES_FRAME_RE = re.compile(r'^ *File ".*?/site-packages(/.*?)", (.*?), (.*?)\n *(.*?)\n')
_SRC_FRAME_RE = re.compile(r'^ *File ".*?/src/(test|main)/python(/.*?)", (.*?), (.*?)\n *(.*?)\n')
_ANY_FRAME_RE = re.compile(r'^ *File "(.*?)", (.*?), (.*?)\n *(.*?)\n')


def get_df_fingerprint(df, cols):
  """Fingerprint the order of select column values within a dataframe.

  Args:
    df: DataFrame containing the columns to fingerprint.
    cols: names of the columns to include, in order.

  Returns:
    Hex-encoded SHA-256 digest built from one digest per column. When every value in every
    selected column can be packed as an 8-byte big-endian unsigned integer, columns are hashed
    from those bytes; otherwise all columns are hashed from the comma-joined str() of their values.
  """
  try:
    strs = [
      sha256(b"".join(map(lambda v: int(v).to_bytes(8, "big"), df[col]))).hexdigest()
      for col in cols
    ]
  except (TypeError, ValueError, OverflowError):
    # int() raises ValueError (e.g. NaN, non-numeric strings) or TypeError (e.g. None), and
    # to_bytes raises OverflowError for negative values or values that do not fit in 8 bytes.
    # Any of these means the integer encoding cannot be used.
    # NOTE(review): int() truncates non-integral floats, so e.g. 1.2 and 1.7 hash identically under
    # the integer encoding; changing that would change existing fingerprints.
    strs = [sha256(",".join(map(str, df[col])).encode("utf-8")).hexdigest() for col in cols]
  return sha256(",".join(strs).encode("utf-8")).hexdigest()


def keep_columns(df: pd.DataFrame, cols: List[str]):
  """Return df restricted to the entries of cols that exist in df, in the order given by cols."""
  cols = [col for col in cols if col in df]
  return df[cols]


def get_df_info(
  df: pd.DataFrame, name: Optional[str] = None, deep: bool = False, counter: bool = False
) -> str:
  """Return a printable summary of dtype and RAM usage for each column of df.

  Args:
    df: DataFrame to describe.
    name: optional label prepended to the total RAM line.
    deep: passed to DataFrame.memory_usage (see comment below).
    counter: when True, append a Counter of the Python types found in each object-dtype column.

  Returns:
    Newline-joined report lines.
  """
  stats = (
    df.dtypes.to_frame().reset_index(drop=False).rename(columns={"index": "column", 0: "dtype"})
  ).merge(
    # deep=True shows memory usage for the entire contained object (e.g. if the type
    # of a column is "object", then deep=True shows the size of the objects instead
    # of the size of the pointers.
    df.memory_usage(index=True, deep=deep)
    .to_frame()
    .reset_index(drop=False)
    .rename(columns={"index": "column", 0: "RAM"})
  )
  # NOTE(review): memory_usage(index=True) reports the index under the label "Index", which has no
  # matching dtype row (unless df has a column named "Index"), so the default inner merge above
  # drops it and the total below excludes index memory. Confirm whether that is intended.
  ramBytes = stats["RAM"].sum()
  if name is not None:
    lines = [f"""{name} total RAM: {ramBytes} bytes ({ramBytes * 1e-9:.3f} GB)"""]
  else:
    lines = [f"""total RAM: {ramBytes} bytes ({ramBytes * 1e-9:.3f} GB)"""]
  lines.extend(str(stats).split("\n"))
  if counter:
    for col, dtype in zip(stats["column"], stats["dtype"]):
      if dtype != object:
        continue
      lines.append(f"{col}: {Counter(type(obj) for obj in df[col])}")
  return "\n".join(lines)


class TypeErrorCounter(object):
  """Lock-protected tally of reported type errors, keyed by (method, callsite)."""

  def __init__(self):
    # Number of log_errors calls per (method, callsite).
    self._callCounts: Dict[Tuple[str, str], int] = dict()
    # Occurrences of each distinct error message per (method, callsite).
    self._typeErrors: Dict[Tuple[str, str], Counter[str]] = dict()
    self._lock = Lock()

  def log_errors(self, method: str, callsite: str, errors: List[str]) -> None:
    """Record one call at (method, callsite) that produced the given error messages."""
    key = (method, callsite)
    with self._lock:
      self._callCounts[key] = self._callCounts.get(key, 0) + 1
      self._typeErrors.setdefault(key, Counter()).update(errors)

  def get_summary(self):
    """Return a report grouped by method, with the most frequent callsites listed first."""
    lines = []
    with self._lock:
      keys = sorted(
        (method, -1 * count, callsite)
        for ((method, callsite), count) in self._callCounts.items()
      )
      for method, negCount, callsite in keys:
        lines.append(f"{method}: {-1 * negCount} BAD CALLS AT: {callsite.rstrip()}")
        for error, errorCount in self._typeErrors[(method, callsite)].items():
          lines.append(f" {errorCount:3d}x {error}")
        lines.append("")
    return "\n".join(lines)


class LogLevel(Enum):
  # Raise an error if the expectation is violated
  FATAL = 1
  # Log to stderr when the expectation is violated
  ERROR = 2
  # Log to stderr any time the column is observed
  INFO = 3


@dataclass
class TypeExpectation:
  # Expected dtype: either a numpy scalar type (e.g. np.int64) or a pandas dtype class
  # (e.g. pd.Int64Dtype); see PandasPatcher._check_dtype.
  dtype: type
  logLevel: LogLevel


def _as_key_list(keys: Any) -> List[Any]:
  """Normalize a merge key argument (a single label or a list/tuple of labels) to a list."""
  if isinstance(keys, (list, tuple)):
    return list(keys)
  return [keys]


class PandasPatcher(object):
  """Builds type-checking wrappers around pandas functions and collects the resulting errors."""

  def __init__(
    self,
    fail: bool,
    typeOverrides: Optional[Dict[str, TypeExpectation]] = None,
    silent: bool = False,
  ):
    """Initialize a PandasPatcher with particular failure and type expectations.

    Args:
      fail: Whether to raise errors or log to stderr when expectations are violated.
      typeOverrides: Type expectations for select columns. These are added to (and may replace)
        the default expectation for c.noteIdKey.
      silent: If True, errors are still counted for the summary but not printed to stderr.
    """
    self._silent = silent  # Set to True to basically disable
    self._fail = fail
    self._counter = TypeErrorCounter()
    # Capture the current pandas implementations so wrappers can delegate to them and
    # remove_patches can restore them.
    self._origConcat = pd.concat
    self._origJoin = pd.DataFrame.join
    self._origMerge = pd.DataFrame.merge
    self._origApply = pd.DataFrame.apply
    self._origInit = pd.DataFrame.__init__
    # Captured but not currently patched.
    self._origGetItem = pd.DataFrame.__getitem__
    self._origSetItem = pd.DataFrame.__setitem__
    self._origLocGetItem = pd.core.indexing._LocationIndexer.__getitem__
    self._origLocSetItem = pd.core.indexing._LocationIndexer.__setitem__
    self._expectations = {
      c.noteIdKey: TypeExpectation(np.int64, LogLevel.ERROR),
    }
    if typeOverrides:
      self._expectations.update(typeOverrides)

  def apply_patches(self) -> None:
    """Replace pandas functions with the checking wrappers produced by this patcher."""
    pd.concat = self.safe_concat()
    # Note that this will work when calling df1.merge(df2) because the first argument
    # to "merge" is df1 (i.e. self).
    pd.DataFrame.merge = self.safe_merge()
    pd.DataFrame.join = self.safe_join()
    pd.DataFrame.apply = self.safe_apply()
    pd.DataFrame.__init__ = self.safe_init()

  def remove_patches(self) -> None:
    """Restore the pandas functions captured when this patcher was constructed."""
    pd.concat = self._origConcat
    pd.DataFrame.merge = self._origMerge
    pd.DataFrame.join = self._origJoin
    pd.DataFrame.apply = self._origApply
    pd.DataFrame.__init__ = self._origInit

  def get_summary(self) -> str:
    """Return a human-readable summary of all type errors logged so far."""
    return f"\nTYPE WARNING SUMMARY\n{self._counter.get_summary()}"

  def _log_errors(self, method: str, callsite: str, lines: List[str]) -> None:
    """Count the errors in lines and, unless silent, print them to stderr. No-op if empty."""
    if not lines:
      return
    self._counter.log_errors(method, callsite, lines)
    errorLines = "\n".join([f" PandasTypeError: {l}" for l in lines])
    msg = f"\n{method} ERROR(S) AT: {callsite}\n{errorLines}\n"
    if not self._silent:
      print(msg, file=sys.stderr)

  def _report_errors(self, method: str, lines: List[str]) -> None:
    """Log lines against the current callsite, computing the callsite only if there are errors.

    Walking the stack is expensive and the wrappers run on every patched call (including every
    DataFrame construction), so the callsite is only resolved when something needs logging.
    """
    if lines:
      self._log_errors(method, self._get_callsite(), lines)

  def _get_check(self, lines: List[str], kwargs: Dict) -> Callable:
    """Return a function which will either assert a condition or append to a list of errors.

    Also removes the "unsafeAllowed" keyword argument from kwargs so it is not forwarded to
    pandas. Checks on columns listed in unsafeAllowed never assert; violations on them are
    recorded with an "(allowed)" suffix instead.

    Note that the returned function does not actually log to stderr, but rather appends to lines
    so that all errors for a single call can be logged together.
    """
    unsafeAllowed: Set[Any] = set()
    if "unsafeAllowed" in kwargs:
      unsafeAllowedArg = kwargs.pop("unsafeAllowed")
      if isinstance(unsafeAllowedArg, str):
        unsafeAllowed = {unsafeAllowedArg}
      else:
        assert isinstance(
          unsafeAllowedArg, (list, tuple, set, frozenset)
        ), f"unexpected type for unsafeAllowed: {type(unsafeAllowedArg)}"
        unsafeAllowed = set(unsafeAllowedArg)

    def _check(columns: Any, condition: bool, msg: str):
      if isinstance(columns, str):
        failDisabled = columns in unsafeAllowed
      elif isinstance(columns, list):
        failDisabled = all(col in unsafeAllowed for col in columns)
      else:
        # Note there are multiple circumstances where the type of Columns may not be a str
        # or List[str], including when we are concatenating a Series (column name will be
        # set to None), when there are multi-level column names (column name will be a tuple)
        # or when Pandas has set column names to a RangeIndex.
        failDisabled = False
      if self._fail and not failDisabled:
        assert condition, msg
      elif not condition:
        if failDisabled:
          lines.append(f"{msg} (allowed)")
        else:
          lines.append(f"{msg} (UNALLOWED)")

    return _check

  def _get_callsite(self) -> str:
    """Return the file, function, line number and pandas API call on a single line.

    Walks the stack from innermost to outermost and reports the first frame that is neither in
    this module nor inside pandas (falling back to the outermost frame if every frame is).
    """
    stack = traceback.format_stack()
    line = ""
    for line in reversed(stack):
      path = line.split(",")[0]
      if "/pandas_utils.py" in path or "/pandas/" in path:
        continue
      break
    # Handle paths resulting from bazel invocation
    match = _SITE_PACKAGES_FRAME_RE.match(line)
    if match:
      return f"{match.group(1)}, {match.group(3)}, at {match.group(2)}: {match.group(4)}"
    # Handle paths resulting from pytest invocation
    match = _SRC_FRAME_RE.match(line)
    if match:
      return f"{match.group(2)}, {match.group(4)}, at {match.group(3)}: {match.group(5)}"
    # Handle other paths (e.g. notebook, public code)
    match = _ANY_FRAME_RE.match(line)
    if match:
      return f"{match.group(1)}, {match.group(3)}, at {match.group(2)}: {match.group(4)}"
    fullStack = "\n\n".join(reversed(stack))
    print(f"parsing error:\n{fullStack}", file=sys.stderr)
    return "parsing error. callsite unknown."

  def _check_dtype(self, dtype: Any, expected: type) -> bool:
    """Return True IFF dtype corresponds to expected.

    Note that for non-nullable columns, dtype may equal type (e.g. np.int64), but for nullable
    columns the column type is actually an instance of a pandas dtype (e.g. pd.Int64Dtype)
    """
    assert expected != object, "expectation must be more specific than object"
    return dtype == expected or isinstance(dtype, expected)

  def _check_name_and_type(self, name: str, dtype: Any) -> List[str]:
    """Returns a list of type mismatches if any are found, or raises an error.

    Columns without an expectation return []. INFO expectations always return one message
    (match or mismatch). ERROR expectations, and FATAL ones when fail is False, return the
    mismatch message only on mismatch. FATAL expectations with fail=True assert on mismatch.
    """
    if name not in self._expectations:
      return []
    typeExpectation = self._expectations[name]
    msg = f"Type expectation mismatch on {name}: found={dtype} expected={typeExpectation.dtype.__name__}"
    match = self._check_dtype(dtype, typeExpectation.dtype)
    if typeExpectation.logLevel == LogLevel.INFO:
      return (
        [msg]
        if not match
        else [
          f"Type expectation match on {name}: found={dtype} expected={typeExpectation.dtype.__name__}"
        ]
      )
    elif typeExpectation.logLevel == LogLevel.ERROR or not self._fail:
      return [msg] if not match else []
    else:
      assert typeExpectation.logLevel == LogLevel.FATAL
      assert self._fail
      assert match, msg
      return []

  def _validate_series(self, series: pd.Series) -> List[str]:
    """Check the Series' name and dtype against the configured expectations."""
    assert isinstance(series, pd.Series), f"unexpected type: {type(series)}"
    return self._check_name_and_type(series.name, series.dtype)

  def _validate_dataframe(self, df: pd.DataFrame) -> List[str]:
    """Returns a list of type mismatches if any are found, or raises an error.

    Checks named index levels and all columns against the configured expectations.
    """
    assert isinstance(df, pd.DataFrame), f"unexpected type: {type(df)}"
    lines = []
    # Check index types
    if isinstance(df.index, pd.MultiIndex):
      for name, dtype in df.index.dtypes.to_dict().items():
        lines.extend(self._check_name_and_type(name, dtype))
    elif isinstance(df.index, pd.RangeIndex) or df.index.name is None:
      # Index is uninteresting - none was specified by the caller.
      pass
    else:
      lines.extend(self._check_name_and_type(df.index.name, df.index.dtype))
    # Check column types
    for name, dtype in df.dtypes.to_dict().items():
      lines.extend(self._check_name_and_type(name, dtype))
    return lines

  def safe_init(self) -> Callable:
    """Return a modified __init__ function that checks type expectations."""

    def _safe_init(*args, **kwargs):
      """Wrapper around pd.DataFrame.__init__.

      Args:
        args: non-keyword arguments to pass through to __init__; args[0] is the DataFrame being
          initialized (self).
        kwargs: keyword arguments to pass through to __init__.
      """
      df = args[0]
      assert isinstance(df, pd.DataFrame), f"unexpected type: {type(df)}"
      retVal = self._origInit(*args, **kwargs)
      assert retVal is None
      self._report_errors("INIT", self._validate_dataframe(df))
      return retVal

    return _safe_init

  def safe_concat(self) -> Callable:
    """Return a modified concat function that checks type stability."""

    def _safe_concat(*args, **kwargs):
      """Wrapper around pd.concat.

      Args:
        args: non-keyword arguments to pass through to concat; args[0] must be a list of either
          all pd.Series or all pd.DataFrame objects.
        kwargs: keyword arguments to pass through to concat, plus optional "unsafeAllowed".
      """
      lines = []
      check = self._get_check(lines, kwargs)
      # Validate that all objects being concatenated are either Series or DataFrames
      objs = args[0]
      assert type(objs) == list, f"expected first argument to be a list: type={type(objs)}"
      if not objs:
        # Nothing to validate; defer to pandas, which rejects an empty list.
        return self._origConcat(*args, **kwargs)
      assert (
        all(type(obj) == pd.Series for obj in objs)
        or all(type(obj) == pd.DataFrame for obj in objs)
      ), f"Expected concat args to be either pd.Series or pd.DataFrame: {[type(obj) for obj in objs]}"
      if type(objs[0]) == pd.Series:
        if kwargs.get("axis") in (1, "columns"):
          # Since the call is concatenating Series as columns in a DataFrame, validate that the sequence
          # of Series dtypes matches the sequence of column dtypes in the dataframe.
          result = self._origConcat(*args, **kwargs)
          objDtypes = [obj.dtype for obj in objs]
          assert len(objDtypes) == len(
            result.dtypes
          ), f"dtype length mismatch: {len(objDtypes)} vs {len(result.dtypes)}"
          for col, seriesType, colType in zip(result.columns, objDtypes, result.dtypes):
            check(
              col,
              seriesType == colType,
              f"Series concat on {col}: {seriesType} vs {colType}",
            )
        else:
          # If Series, validate that all series were same type and return
          seriesTypes = set(obj.dtype for obj in objs)
          check(None, len(seriesTypes) == 1, f"More than 1 unique Series type: {seriesTypes}")
          result = self._origConcat(*args, **kwargs)
      else:
        # If DataFrame, validate that all input columns with matching names have the same type
        # and build expectation for output column types
        assert type(objs[0]) == pd.DataFrame
        # Validate all inputs
        for dfArg in objs:
          lines.extend(self._validate_dataframe(dfArg))
        colTypes: Dict[Any, List[Any]] = dict()
        for df in objs:
          for col, dtype in df.reset_index(drop=False).dtypes.items():
            colTypes.setdefault(col, []).append(dtype)
        # Perform concatenation and validate that there weren't any type changes
        result = self._origConcat(*args, **kwargs)
        for col, outputType in result.reset_index(drop=False).dtypes.items():
          # Output columns with no input counterpart (e.g. index levels created by concat
          # itself) carry no expectation, so they trivially pass.
          inputTypes = colTypes.get(col, [])
          check(
            col,
            all(inputType == outputType for inputType in inputTypes),
            f"DataFrame concat on {col}: output={outputType} inputs={inputTypes}",
          )
      if isinstance(result, pd.DataFrame):
        lines.extend(self._validate_dataframe(result))
      elif isinstance(result, pd.Series):
        lines.extend(self._validate_series(result))
      self._report_errors("CONCAT", lines)
      return result

    return _safe_concat

  def safe_apply(self) -> Callable:
    """Return a modified apply function that checks type stability."""

    def _safe_apply(*args, **kwargs):
      """Wrapper around pd.DataFrame.apply.

      Only the result is validated against the configured type expectations.

      Args:
        args: non-keyword arguments to pass through to apply.
        kwargs: keyword arguments to pass through to apply.
      """
      # TODO: Flesh this out with additional expectations around input and output types
      result = self._origApply(*args, **kwargs)
      if isinstance(result, pd.DataFrame):
        self._report_errors("APPLY", self._validate_dataframe(result))
      elif isinstance(result, pd.Series):
        self._report_errors("APPLY", self._validate_series(result))
      return result

    return _safe_apply

  def safe_merge(self) -> Callable:
    """Return a modified merge function that checks type stability."""

    def _safe_merge(*args, **kwargs):
      """Wrapper around pd.DataFrame.merge.

      Args:
        args: non-keyword arguments to pass through to merge; args[0] is the left DataFrame
          (self) and args[1] is the right DataFrame.
        kwargs: keyword arguments to pass through to merge, plus optional "unsafeAllowed".
      """
      lines = []
      check = self._get_check(lines, kwargs)
      leftFrame = args[0]
      rightFrame = args[1]
      # Validate that argument types are as expected
      assert type(leftFrame) is pd.DataFrame
      assert type(rightFrame) is pd.DataFrame
      lines.extend(self._validate_dataframe(leftFrame))
      lines.extend(self._validate_dataframe(rightFrame))
      # Store dtypes (including index levels, via reset_index) and validate that any common
      # columns have the same type
      leftDtypes = dict(leftFrame.reset_index(drop=False).dtypes)
      rightDtypes = dict(rightFrame.reset_index(drop=False).dtypes)
      for col in set(leftDtypes) & set(rightDtypes):
        check(
          col,
          leftDtypes[col] == rightDtypes[col],
          f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
        )
      # Identify the columns we are merging on. An explicit on=None is treated like an omitted
      # "on" argument.
      on = kwargs.get("on")
      if isinstance(on, str):
        onCols = {on}
      elif isinstance(on, list):
        onCols = set(on)
      elif "left_on" in kwargs:
        assert on is None, "not expecting both on and left_on"
        assert "right_on" in kwargs, "expecting both left_on and right_on to be set"
        onCols = set()
      else:
        assert on is None, f"unexpected type for on: {type(on)}"
        onCols = set(leftFrame.columns) & set(rightFrame.columns)
      # Validate that merge columns have matching types
      if "left_on" in kwargs:
        assert "right_on" in kwargs
        # left_on/right_on may each be a single label or a list of labels.
        keyPairs = list(zip(_as_key_list(kwargs["left_on"]), _as_key_list(kwargs["right_on"])))
        for leftKey, rightKey in keyPairs:
          check(
            [leftKey, rightKey],
            leftDtypes[leftKey] == rightDtypes[rightKey],
            f"Merge key mismatch on type({leftKey})={leftDtypes[leftKey]} vs type({rightKey})={rightDtypes[rightKey]}",
          )
        # When a left key and right key share a name, pandas emits a single output column for
        # them (as with "on") rather than suffixed copies.
        onCols |= {leftKey for leftKey, rightKey in keyPairs if leftKey == rightKey}
      else:
        assert len(onCols), "expected onCols to be defined since left_on was not"
        assert "right_on" not in kwargs, "did not expect onCols and right_on"
        for col in onCols:
          check(
            col,
            leftDtypes[col] == rightDtypes[col],
            f"Merge key mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
          )
      # Compute expected column types
      leftSuffix, rightSuffix = kwargs.get("suffixes", ("_x", "_y"))
      commonCols = set(leftFrame.columns) & set(rightFrame.columns)
      expectedColTypes = dict()
      for col in set(leftFrame.columns) | set(rightFrame.columns):
        if col in onCols:
          # Note that we check above whether leftDtypes[col] == rightDtypes[col] and either raise an
          # error or log as appropriate if there is a mismatch.
          if leftDtypes[col] == rightDtypes[col]:
            expectedColTypes[col] = leftDtypes[col]
          else:
            # Set expectation to None since we don't know what will happen, but do want to log an
            # error later
            expectedColTypes[col] = None
        elif col in commonCols:
          expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col]
          expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col]
        elif col in leftDtypes:
          assert col not in rightDtypes
          expectedColTypes[col] = leftDtypes[col]
        else:
          expectedColTypes[col] = rightDtypes[col]
      # Perform merge and validate results. Output columns without an expectation are reported
      # as mismatches rather than raising KeyError.
      result = self._origMerge(*args, **kwargs)
      for col, dtype in dict(result.dtypes).items():
        expected = expectedColTypes.get(col)
        check(
          col,
          col in expectedColTypes and dtype == expected,
          f"Output mismatch on {col}: result={dtype} expected={expected}",
        )
      lines.extend(self._validate_dataframe(result))
      self._report_errors("MERGE", lines)
      return result

    return _safe_merge

  def safe_join(self) -> Callable:
    """Return a modified join function that checks type stability."""

    def _safe_join(*args, **kwargs):
      """Wrapper around pd.DataFrame.join.

      Args:
        args: non-keyword arguments to pass through to join; args[0] is the left DataFrame
          (self) and args[1] is the right DataFrame.
        kwargs: keyword arguments to pass through to join. Only "lsuffix", "rsuffix" and "how"
          are supported, plus optional "unsafeAllowed".
      """
      lines = []
      check = self._get_check(lines, kwargs)
      leftFrame = args[0]
      rightFrame = args[1]
      # Validate arguments are as expected
      assert type(leftFrame) is pd.DataFrame
      assert type(rightFrame) is pd.DataFrame
      lines.extend(self._validate_dataframe(leftFrame))
      lines.extend(self._validate_dataframe(rightFrame))
      assert len(set(kwargs) - {"lsuffix", "rsuffix", "how"}) == 0, f"unexpected kwargs: {kwargs}"
      leftLevels = leftFrame.index.nlevels
      rightLevels = rightFrame.index.nlevels
      # Validate the assumption that columns used as the join key in the index have the same type.
      # This is analogous to validating that onCols match and have the same types in _safe_merge.
      if leftLevels == 1 and rightLevels == 1:
        match = leftFrame.index.dtype == rightFrame.index.dtype
      elif leftLevels == 1:
        indexTypes = dict(rightFrame.index.dtypes)
        name = leftFrame.index.names[0]
        assert name in indexTypes, f"{name} not found in {indexTypes}"
        match = indexTypes[name] == leftFrame.index.dtype
      elif rightLevels == 1:
        indexTypes = dict(leftFrame.index.dtypes)
        name = rightFrame.index.names[0]
        assert name in indexTypes, f"{name} not found in {indexTypes}"
        match = indexTypes[name] == rightFrame.index.dtype
      else:
        # Both sides have a MultiIndex: shared level names must have matching types.
        leftIndexTypes = dict(leftFrame.index.dtypes)
        rightIndexTypes = dict(rightFrame.index.dtypes)
        match = all(
          leftIndexTypes[col] == rightIndexTypes[col]
          for col in set(leftIndexTypes) & set(rightIndexTypes)
        )
      check(
        list(set(leftFrame.index.names) | set(rightFrame.index.names)),
        match,
        "Join index mismatch:\nleft:\n{left}\nvs\nright:\n{right}".format(
          left=leftFrame.index.dtype if leftLevels == 1 else leftFrame.index.dtypes,
          right=rightFrame.index.dtype if rightLevels == 1 else rightFrame.index.dtypes,
        ),
      )
      # Validate that input columns with the same name have the same types
      leftDtypes = dict(leftFrame.dtypes)
      rightDtypes = dict(rightFrame.dtypes)
      for col in set(leftDtypes) & set(rightDtypes):
        check(
          col,
          leftDtypes[col] == rightDtypes[col],
          f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
        )
      # Validate that none of the columns in an index have the same name as a non-index column
      # in the opposite dataframe
      assert (
        len(set(leftFrame.index.names) & set(rightFrame.columns)) == 0
      ), f"left index: {set(leftFrame.index.names)}; right columns {set(rightFrame.columns)}"
      assert (
        len(set(rightFrame.index.names) & set(leftFrame.columns)) == 0
      ), f"right index: {set(rightFrame.index.names)}; left columns {set(leftFrame.columns)}"
      # Compute expected types for output columns
      commonCols = set(leftFrame.columns) & set(rightFrame.columns)
      expectedColTypes = dict()
      leftSuffix = kwargs.get("lsuffix", "")
      rightSuffix = kwargs.get("rsuffix", "")
      for col in set(leftFrame.columns) | set(rightFrame.columns):
        if col in commonCols:
          expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col]
          expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col]
        elif col in leftDtypes:
          assert col not in rightDtypes
          expectedColTypes[col] = leftDtypes[col]
        else:
          expectedColTypes[col] = rightDtypes[col]
      # Compute expected types for index columns. Each side's index dtypes must come from that
      # side's own index.
      leftIndexCols = set(leftFrame.index.names)
      rightIndexCols = set(rightFrame.index.names)
      if leftLevels > 1:
        leftIndexDtypes = dict(leftFrame.index.dtypes)
      else:
        leftIndexDtypes = {leftFrame.index.name: leftFrame.index.dtype}
      if rightLevels > 1:
        rightIndexDtypes = dict(rightFrame.index.dtypes)
      else:
        rightIndexDtypes = {rightFrame.index.name: rightFrame.index.dtype}
      for col in leftIndexCols & rightIndexCols:
        # For columns in both indices, type should not change if input types agree. If input types
        # disagree, then we have no expectation.
        if leftIndexDtypes[col] == rightIndexDtypes[col]:
          expectedColTypes[col] = leftIndexDtypes[col]
        else:
          expectedColTypes[col] = None
      for col in (leftIndexCols | rightIndexCols) - (leftIndexCols & rightIndexCols):
        # For columns in exactly one index, the expected output type should match the input column type
        # and the column name should not change because we have validated that the column does not
        # appear in the other dataframe
        if col in leftIndexDtypes:
          assert col not in rightIndexDtypes, f"unexpected column: {col}"
          expectedColTypes[col] = leftIndexDtypes[col]
        else:
          expectedColTypes[col] = rightIndexDtypes[col]
      # Perform join and validate results. Note that we already validated that the indices had the
      # same columns and types, and that the "on" argument is unset, so now we only need to check
      # the non-index columns.
      result = self._origJoin(*args, **kwargs)
      # Note that we must reset index to force any NaNs in the index to emerge as float types.
      # See example below.
      # left = pd.DataFrame({"idx0": [1, 2], "idx1": [11, 12], "val1": [4, 5]}).set_index(["idx0", "idx1"])
      # right = pd.DataFrame({"idx0": [1, 2, 3], "idx2": [21, 22, 23], "val2": [7, 8, 9]}).set_index(["idx0", "idx2"])
      # print(dict(left.join(right, how="outer").index.dtypes))
      # print(dict(left.join(right, how="outer").reset_index(drop=False).dtypes))
      # $> {'idx0': dtype('int64'), 'idx1': dtype('int64'), 'idx2': dtype('int64')}
      # $> {'idx0': dtype('int64'), 'idx1': dtype('float64'), 'idx2': dtype('int64'), 'val1': dtype('float64'), 'val2': dtype('int64')}
      resultDtypes = dict(result.reset_index(drop=False).dtypes)
      # Add default type for index
      if "index" not in expectedColTypes:
        expectedColTypes["index"] = np.int64
      for col, dtype in resultDtypes.items():
        # With multi-level column labels, reset_index names index columns like (name, "").
        if isinstance(col, tuple) and len(col) == 2 and col[1] == "":
          col = col[0]
        expected = expectedColTypes.get(col)
        check(
          col,
          col in expectedColTypes and dtype == expected,
          f"Output mismatch on {col}: result={dtype} expected={expected}",
        )
      lines.extend(self._validate_dataframe(result))
      self._report_errors("JOIN", lines)
      return result

    return _safe_join


# TODO: make enforce_types an explicit argument so this is less error prone
def patch_pandas(main: Callable) -> Callable:
  """Wrap main so that pandas is patched with type checks for the duration of the call.

  The patch fails or logs according to the parsed command-line args' enforce_types attribute,
  prints a type error summary to stderr after main returns (when running without parallelism),
  and restores the original pandas functions afterwards, even if main raises.

  Args:
    main: "main" function for program binary
  """

  @functools.wraps(main)
  def _inner(*args, **kwargs) -> Any:
    """Determine patching behavior, apply patch, run main and add logging."""
    print("Patching pandas")
    if "args" in kwargs:
      # Handle birdwatch/scoring/src/main/python/public/scoring/runner.py, which expects
      # args as a keyword argument and not as a positional argument.
      assert len(args) == 0, f"positional arguments not expected, but found {len(args)}"
      clArgs = kwargs["args"]
    else:
      # Handle the following, which expect args as the second positional argument:
      # birdwatch/scoring/src/main/python/run_post_selection_similarity.py
      # birdwatch/scoring/src/main/python/run_prescoring.py
      # birdwatch/scoring/src/main/python/run_final_scoring.py
      # birdwatch/scoring/src/main/python/run_contributor_scoring.py
      # birdwatch/scoring/src/main/python/run.py
      assert len(args) == 1, f"expected 1 positional arg, but found {len(args)}"
      assert len(kwargs) == 0, f"expected kwargs to be empty, but found {len(kwargs)}"
      clArgs = args[0]

    # Apply patches, configured based on whether types should be enforced or logged
    patcher = PandasPatcher(clArgs.enforce_types)
    patcher.apply_patches()
    try:
      # Run main
      retVal = main(*args, **kwargs)
      # Log type error summary
      if hasattr(clArgs, "parallel") and not clArgs.parallel:
        print(patcher.get_summary(), file=sys.stderr)
      else:
        # Don't show type summary because counters will be inaccurate due to scorers running
        # in their own process. (This branch is also taken when clArgs has no "parallel".)
        print("Type summary omitted when running in parallel.", file=sys.stderr)
    finally:
      # Restore the original pandas functions so the patch does not outlive main, and so that
      # wrapping main again does not capture already-patched functions as the "originals".
      patcher.remove_patches()
    # Return result of main
    return retVal

  return _inner