def safe_join()

in sourcecode/scoring/pandas_utils.py [0:0]


  def safe_join(self) -> Callable:
    """Return a wrapper around the original DataFrame.join that checks type stability.

    The returned callable forwards its arguments to self._origJoin, and before/after
    the call it validates that index and column dtypes are consistent between the two
    inputs and the result.  Violations are reported through the checker returned by
    self._get_check and finally handed to self._log_errors.

    Returns:
      Callable accepting (leftFrame, rightFrame, **kwargs) and returning the joined frame.
    """

    def _safe_join(*args, **kwargs):
      """Wrapper around pd.DataFrame.join.

      Args:
        args: non-keyword arguments to pass through to join.  args[0] and args[1] must be
          the left and right pd.DataFrame respectively.
        kwargs: keyword arguments to pass through to join.  Only "lsuffix", "rsuffix" and
          "how" are accepted.

      Returns:
        The result of calling the original join with the same arguments.

      Raises:
        AssertionError: if the inputs are not plain DataFrames, unexpected kwargs are passed,
          a single-level index name is missing from the other frame's MultiIndex, or an index
          name collides with a non-index column of the opposite frame.
      """
      lines = []
      check = self._get_check(lines, kwargs)
      leftFrame = args[0]
      rightFrame = args[1]
      # Validate arguments are as expected
      assert type(leftFrame) is pd.DataFrame
      assert type(rightFrame) is pd.DataFrame
      lines.extend(self._validate_dataframe(leftFrame))
      lines.extend(self._validate_dataframe(rightFrame))
      assert len(set(kwargs) - {"lsuffix", "rsuffix", "how"}) == 0, f"unexpected kwargs: {kwargs}"
      # Validate the assumption that columns used as the join key in the index have the same type.
      # This is analogous to validating that onCols match and have the same types in _safe_merge.
      if len(leftFrame.index.names) == 1 and len(rightFrame.index.names) == 1:
        match = leftFrame.index.dtype == rightFrame.index.dtype
      elif len(leftFrame.index.names) == 1 and len(rightFrame.index.names) > 1:
        indexTypes = dict(rightFrame.index.dtypes)
        name = leftFrame.index.names[0]
        assert name in indexTypes, f"{name} not found in {indexTypes}"
        match = indexTypes[name] == leftFrame.index.dtype
      elif len(leftFrame.index.names) > 1 and len(rightFrame.index.names) == 1:
        indexTypes = dict(leftFrame.index.dtypes)
        name = rightFrame.index.names[0]
        assert name in indexTypes, f"{name} not found in {indexTypes}"
        match = indexTypes[name] == rightFrame.index.dtype
      else:
        assert (
          len(leftFrame.index.names) > 1
        ), f"unexpected left: {type(leftFrame.index)}, {leftFrame.index}"
        assert (
          len(rightFrame.index.names) > 1
        ), f"unexpected right: {type(rightFrame.index)}, {rightFrame.index}"
        leftIndexTypes = dict(leftFrame.index.dtypes)
        rightIndexTypes = dict(rightFrame.index.dtypes)
        # Every index level shared by both frames must agree on dtype.
        match = True
        for col in set(leftIndexTypes) & set(rightIndexTypes):
          match = match & (leftIndexTypes[col] == rightIndexTypes[col])
      check(
        list(set(leftFrame.index.names) | set(rightFrame.index.names)),
        match,
        "Join index mismatch:\nleft:\n{left}\nvs\nright:\n{right}".format(
          left=leftFrame.index.dtype if len(leftFrame.index.names) == 1 else leftFrame.index.dtypes,
          right=rightFrame.index.dtype
          if len(rightFrame.index.names) == 1
          else rightFrame.index.dtypes,
        ),
      )
      # Validate that input columns with the same name have the same types
      leftDtypes = dict(leftFrame.dtypes)
      rightDtypes = dict(rightFrame.dtypes)
      for col in set(leftDtypes) & set(rightDtypes):
        check(
          col,
          leftDtypes[col] == rightDtypes[col],
          f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}",
        )
      # Validate that none of the columns in an index have the same name as a non-index column
      # in the opposite dataframe
      assert (
        len(set(leftFrame.index.names) & set(rightFrame.columns)) == 0
      ), f"left index: {set(leftFrame.index.names)}; right columns {set(rightFrame.columns)}"
      assert (
        len(set(rightFrame.index.names) & set(leftFrame.columns)) == 0
      ), f"right index: {set(rightFrame.index.names)}; left columns {set(leftFrame.columns)}"
      # Compute expected types for output columns
      commonCols = set(leftFrame.columns) & set(rightFrame.columns)
      expectedColTypes = dict()
      leftSuffix = kwargs.get("lsuffix", "")
      rightSuffix = kwargs.get("rsuffix", "")
      for col in set(leftFrame.columns) | set(rightFrame.columns):
        if col in commonCols:
          # Columns present on both sides are expected under their suffixed names.
          expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col]
          expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col]
        elif col in leftDtypes:
          assert col not in rightDtypes
          expectedColTypes[col] = leftDtypes[col]
        else:
          expectedColTypes[col] = rightDtypes[col]
      # Compute expected types for index columns.  From here on leftDtypes / rightDtypes hold
      # index level dtypes rather than column dtypes.
      leftIndexCols = set(leftFrame.index.names)
      rightIndexCols = set(rightFrame.index.names)
      if len(leftIndexCols) > 1:
        leftDtypes = dict(leftFrame.index.dtypes)
      else:
        # The left index must be described by the left frame's own dtype.  Using the right
        # frame's index dtype here would hide genuine mismatches and, when the right index is a
        # MultiIndex (whose .dtype is object), would invent a mismatch for a valid join.
        leftDtypes = {leftFrame.index.name: leftFrame.index.dtype}
      if len(rightIndexCols) > 1:
        rightDtypes = dict(rightFrame.index.dtypes)
      else:
        rightDtypes = {rightFrame.index.name: rightFrame.index.dtype}
      for col in leftIndexCols & rightIndexCols:
        # For columns in both indices, type should not change if input types agree.  If input types
        # disagree, then we have no expectation.
        if leftDtypes[col] == rightDtypes[col]:
          expectedColTypes[col] = leftDtypes[col]
        else:
          expectedColTypes[col] = None
      for col in (leftIndexCols | rightIndexCols) - (leftIndexCols & rightIndexCols):
        # For columns in exactly one index, the expected output type should match the input column type
        # and the column name should not change because we have validated that the column does not
        # appear in the other dataframe
        if col in leftDtypes:
          assert col not in rightDtypes, f"unexpected column: {col}"
          expectedColTypes[col] = leftDtypes[col]
        else:
          expectedColTypes[col] = rightDtypes[col]
      # Perform join and validate results.  The index dtypes were compared above and the kwargs
      # assertion guarantees that the "on" argument is unset.
      result = self._origJoin(*args, **kwargs)
      # Note that we must reset index to force any NaNs in the index to emerge as float types.
      # See example below.
      # left = pd.DataFrame({"idx0": [1, 2], "idx1": [11, 12], "val1": [4, 5]}).set_index(["idx0", "idx1"])
      # right = pd.DataFrame({"idx0": [1, 2, 3], "idx2": [21, 22, 23], "val2": [7, 8, 9]}).set_index(["idx0", "idx2"])
      # print(dict(left.join(right, how="outer").index.dtypes))
      # print(dict(left.join(right, how="outer").reset_index(drop=False).dtypes))
      # $> {'idx0': dtype('int64'), 'idx1': dtype('int64'), 'idx2': dtype('int64')}
      # $> {'idx0': dtype('int64'), 'idx1': dtype('float64'), 'idx2': dtype('int64'), 'val1': dtype('float64'), 'val2': dtype('int64')}
      resultDtypes = dict(result.reset_index(drop=False).dtypes)
      # Add default type for index
      if "index" not in expectedColTypes:
        expectedColTypes["index"] = np.int64
      for col, dtype in resultDtypes.items():
        # Collapse 2-tuples of the form (name, "") to the bare name.  The isinstance guard keeps
        # non-sized labels (e.g. integer column names) from raising TypeError in len().
        if isinstance(col, tuple) and len(col) == 2 and col[1] == "":
          col = col[0]
        check(
          col,
          dtype == expectedColTypes[col],
          f"Output mismatch on {col}: result={dtype} expected={expectedColTypes[col]}",
        )
      lines.extend(self._validate_dataframe(result))
      self._log_errors("JOIN", self._get_callsite(), lines)
      return result

    return _safe_join