def patch_pandas()

in sourcecode/scoring/pandas_utils.py [0:0]


def patch_pandas(main: Callable) -> Callable:
  """Decorate main so that pandas is patched before it runs and a type summary is logged after.

  The returned wrapper replaces pd.concat and the DataFrame merge, join, apply and __init__
  attributes with the versions supplied by PandasPatcher, calls main with the original
  arguments, and then prints a type summary (or a notice that it was omitted) to stderr.
  The patches are applied globally and are not reverted after main returns.

  Args:
    main: "main" function for program binary.  It must be invoked either with a single
      positional argument or with an "args" keyword argument; in both cases that object
      must expose an enforce_types attribute.

  Returns:
    A wrapper around main that accepts the same arguments and returns main's result.
  """
  # Imported locally so this function does not depend on a module-level import.
  import functools

  # Preserve main's __name__, __doc__, etc. on the wrapper and expose main as __wrapped__.
  @functools.wraps(main)
  def _inner(*args, **kwargs) -> Any:
    """Determine patching behavior, apply patch and add logging."""
    print("Patching pandas")
    if "args" in kwargs:
      # Handle birdwatch/scoring/src/main/python/public/scoring/runner.py, which expects
      # args as a keyword argument and not as a positional argument.
      assert len(args) == 0, f"positional arguments not expected, but found {len(args)}"
      clArgs = kwargs["args"]
    else:
      # Handle the following, which pass args as the only positional argument:
      # birdwatch/scoring/src/main/python/run_post_selection_similarity.py
      # birdwatch/scoring/src/main/python/run_prescoring.py
      # birdwatch/scoring/src/main/python/run_final_scoring.py
      # birdwatch/scoring/src/main/python/run_contributor_scoring.py
      # birdwatch/scoring/src/main/python/run.py
      assert len(args) == 1, f"expected 1 positional args, but found {len(args)}"
      assert len(kwargs) == 0, f"expected kwargs to be empty, but found {len(kwargs)}"
      clArgs = args[0]
    # Apply patches, configured based on whether types should be enforced or logged
    patcher = PandasPatcher(clArgs.enforce_types)
    pd.concat = patcher.safe_concat()
    # Note that this will work when calling df1.merge(df2) because the first argument
    # to "merge" is df1 (i.e. self).
    pd.DataFrame.merge = patcher.safe_merge()
    pd.DataFrame.join = patcher.safe_join()
    pd.DataFrame.apply = patcher.safe_apply()
    pd.DataFrame.__init__ = patcher.safe_init()
    # Run main
    retVal = main(*args, **kwargs)
    # Log type error summary.  The summary is only printed when clArgs has a "parallel"
    # attribute and it is falsy; if the attribute is missing the summary is omitted too.
    if hasattr(clArgs, "parallel") and not clArgs.parallel:
      print(patcher.get_summary(), file=sys.stderr)
    else:
      # Don't show type summary because counters will be inaccurate due to scorers running
      # in their own process.
      print("Type summary omitted when running in parallel.", file=sys.stderr)
    # Return result of main
    return retVal

  return _inner