def safe_concat()

in sourcecode/scoring/pandas_utils.py


  def safe_concat(self) -> Callable:
    """Return a modified concat function that checks type stability."""

    def _safe_concat(*args, **kwargs):
      """Wrapper around pd.concat

      Args:
        args: non-keyword arguments to pass through to merge.
        kwargs: keyword arguments to pass through to merge.
      """
      lines = []
      check = self._get_check(lines, kwargs)
      # Validate that all objects being concatenated are either Series or DataFrames
      objs = args[0]
      assert type(objs) == list, f"expected first argument to be a list: type={type(objs)}"
      assert (
        all(type(obj) == pd.Series for obj in objs)
        or all(type(obj) == pd.DataFrame for obj in objs)
      ), f"Expected concat args to be either pd.Series or pd.DataFrame: {[type(obj) for obj in objs]}"
      if type(objs[0]) == pd.Series:
        if "axis" in kwargs and kwargs["axis"] == 1:
          # Since the call is concatenating Series as columns in a DataFrame, validate that the sequence
          # of Series dtypes matches the sequence of column dtypes in the dataframe.
          result = self._origConcat(*args, **kwargs)
          objDtypes = [obj.dtype for obj in objs]
          assert len(objDtypes) == len(
            result.dtypes
          ), f"dtype length mismatch: {len(objDtypes)} vs {len(result.dtypes)}"
          for col, seriesType, colType in zip(result.columns, objDtypes, result.dtypes):
            check(
              col,
              seriesType == colType,
              f"Series concat on {col}: {seriesType} vs {colType}",
            )
        else:
          # If Series, validate that all series were same type and return
          seriesTypes = set(obj.dtype for obj in objs)
          check(None, len(seriesTypes) == 1, f"More than 1 unique Series type: {seriesTypes}")
          result = self._origConcat(*args, **kwargs)
      else:
        # If DataFrame, validate that all input columns with matching names have the same type
        # and build expectation for output column types
        assert type(objs[0]) == pd.DataFrame
        # Validate all inputs
        for dfArg in objs:
          lines.extend(self._validate_dataframe(dfArg))
        colTypes: Dict[str, List[type]] = dict()
        for df in objs:
          for col, dtype in df.reset_index(drop=False).dtypes.items():
            if col not in colTypes:
              colTypes[col] = []
            colTypes[col].append(dtype)
        # Perform concatenation and validate that there weren't any type changes
        result = self._origConcat(*args, **kwargs)
        for col, outputType in result.reset_index(drop=False).dtypes.items():
          check(
            col,
            all(inputType == outputType for inputType in colTypes[col]),
            f"DataFrame concat on {col}: output={outputType} inputs={colTypes[col]}",
          )
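      # Run the standard validation on the concat result and log any accumulated
      # messages for this callsite.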
      if isinstance(result, pd.DataFrame):
        lines.extend(self._validate_dataframe(result))
      elif isinstance(result, pd.Series):
        lines.extend(self._validate_series(result))
      self._log_errors("CONCAT", self._get_callsite(), lines)
      return result

    return _safe_concat
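
For illustration, below is a minimal sketch of the kind of silent dtype drift this wrapper is meant to surface. The DataFrames and column names are hypothetical, and the snippet calls plain pd.concat directly rather than the patched version.

import pandas as pd

a = pd.DataFrame({"noteId": [1, 2], "score": [10, 20]})  # both columns are int64
b = pd.DataFrame({"noteId": [3, 4]})                      # "score" column is absent

result = pd.concat([a, b])
print(result.dtypes)
# noteId      int64
# score     float64   <- silently upcast: the missing column is filled with NaN

With _safe_concat installed in place of pd.concat (the patching mechanism lives in the enclosing class and is not shown here), the dtype recorded for "score" in colTypes (int64) would not match the output dtype (float64), so check() would report a message of the form "DataFrame concat on score: output=float64 inputs=[...]", either logged or raised depending on how _get_check is configured.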