def safe_merge()

in sourcecode/scoring/pandas_utils.py [0:0]


  def safe_merge(self) -> Callable:
    """Return a modified merge function that checks type stability.

    The returned wrapper forwards all arguments to ``self._origMerge`` and, around that call:
      * validates both inputs and the result with ``self._validate_dataframe``;
      * checks that columns present in both inputs (index levels included) share a dtype;
      * checks that each pair of merge keys shares a dtype;
      * predicts the dtype of every output column and compares the result against it.
    Each check goes through the checker built by ``self._get_check``; the collected lines are
    passed to ``self._log_errors("MERGE", ...)`` before the merged frame is returned.
    """

    def _as_list(keys):
      """Normalize a merge-key argument to a list (only a `list` counts as several keys)."""
      return list(keys) if isinstance(keys, list) else [keys]

    def _safe_merge(*args, **kwargs):
      """Wrapper around pd.DataFrame.merge.

      Args:
        args: non-keyword arguments to pass through to merge. The first two must be the left
          and right DataFrames.
        kwargs: keyword arguments to pass through to merge. `on`, `left_on`/`right_on` (single
          labels or lists), `suffixes` (entries may be None) and `indicator` are understood
          when predicting output column types.

      Returns:
        The DataFrame produced by the wrapped merge.

      Raises:
        AssertionError: if either frame is not exactly a pd.DataFrame, or the key arguments are
          inconsistent (e.g. `on` combined with `left_on`, or key lists of unequal length).
      """
      lines = []
      # Built before kwargs are inspected or forwarded because _get_check is handed kwargs.
      # NOTE(review): presumably it may read/strip wrapper-only options from kwargs — confirm.
      check = self._get_check(lines, kwargs)
      leftFrame = args[0]
      rightFrame = args[1]
      # Validate that argument types are as expected
      assert type(leftFrame) is pd.DataFrame
      assert type(rightFrame) is pd.DataFrame
      lines.extend(self._validate_dataframe(leftFrame))
      lines.extend(self._validate_dataframe(rightFrame))
      # Store dtypes (reset_index exposes index levels as columns too) and validate that any
      # common columns have the same type.
      leftDtypes = dict(leftFrame.reset_index(drop=False).dtypes)
      rightDtypes = dict(rightFrame.reset_index(drop=False).dtypes)
      for col in set(leftDtypes) & set(rightDtypes):
        check(
          col,
          leftDtypes[col] == rightDtypes[col],
          f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
        )
      # Identify the merge keys and validate that each key pair has matching types. onCols holds
      # the keys that end up as a single, unsuffixed output column.
      on = kwargs.get("on")
      if "left_on" in kwargs or "right_on" in kwargs:
        assert on is None, "not expecting both on and left_on"
        assert "left_on" in kwargs and "right_on" in kwargs, (
          "expecting both left_on and right_on to be set"
        )
        leftKeys = _as_list(kwargs["left_on"])
        rightKeys = _as_list(kwargs["right_on"])
        assert len(leftKeys) == len(rightKeys), "expecting left_on and right_on of equal length"
        for leftKey, rightKey in zip(leftKeys, rightKeys):
          check(
            [leftKey, rightKey],
            leftDtypes[leftKey] == rightDtypes[rightKey],
            f"Merge key mismatch on type({leftKey})={leftDtypes[leftKey]} "
            f"vs type({rightKey})={rightDtypes[rightKey]}",
          )
        # pandas keeps only one output column for a key pair whose labels are identical, just
        # as it does for `on`, so such keys are not suffixed.
        onCols = {lk for lk, rk in zip(leftKeys, rightKeys) if lk == rk}
      else:
        if on is None:
          # Neither on nor left_on/right_on: pandas merges on all shared columns.
          onCols = set(leftFrame.columns) & set(rightFrame.columns)
        else:
          assert isinstance(on, (str, list)), f"unexpected type for on: {type(on)}"
          onCols = set(_as_list(on))
        assert len(onCols), "expected onCols to be defined since left_on was not"
        for col in onCols:
          check(
            col,
            leftDtypes[col] == rightDtypes[col],
            f"Merge key mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
          )
      # Compute expected column types
      leftSuffix, rightSuffix = kwargs.get("suffixes", ("_x", "_y"))
      # pandas leaves a column name unchanged when its suffix is None.
      leftSuffix = "" if leftSuffix is None else leftSuffix
      rightSuffix = "" if rightSuffix is None else rightSuffix
      commonCols = set(leftFrame.columns) & set(rightFrame.columns)
      expectedColTypes = dict()
      for col in set(leftFrame.columns) | set(rightFrame.columns):
        if col in onCols:
          # Key dtypes were compared above; on a mismatch the output type is unknown, so the
          # expectation is None and the output check below reports it.
          if leftDtypes[col] == rightDtypes[col]:
            expectedColTypes[col] = leftDtypes[col]
          else:
            expectedColTypes[col] = None
        elif col in commonCols:
          expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col]
          expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col]
        elif col in leftDtypes:
          assert col not in rightDtypes
          expectedColTypes[col] = leftDtypes[col]
        else:
          expectedColTypes[col] = rightDtypes[col]
      # Perform merge and validate results
      result = self._origMerge(*args, **kwargs)
      indicator = kwargs.get("indicator", False)
      # The indicator column's categorical dtype is fixed by pandas rather than derived from the
      # inputs, so there is no input-based expectation to compare it with.
      if isinstance(indicator, str):
        indicatorCol = indicator
      else:
        indicatorCol = "_merge" if indicator else None
      for col, resultDtype in dict(result.dtypes).items():
        if indicatorCol is not None and col == indicatorCol:
          continue
        # A missing or None expectation means the output type is unknown and must be reported.
        # Test for None explicitly: `dtype == None` may coerce None to NumPy's default float64.
        expected = expectedColTypes.get(col)
        check(
          col,
          expected is not None and resultDtype == expected,
          f"Output mismatch on {col}: result={resultDtype} expected={expected}",
        )
      lines.extend(self._validate_dataframe(result))
      self._log_errors("MERGE", self._get_callsite(), lines)
      return result

    return _safe_merge