def rank_specific()

in ml_logging/torch_logging.py


import functools
from typing import Optional

import torch.distributed as dist
from absl import logging as absl_logging


def rank_specific(logger):
  """Makes a logger's methods rank-aware, overriding a given logger only once."""
  if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
    return logger

  def _if_rank(logger_method, limit: Optional[int] = None):
    if limit:
      # When limiting redundant logs, wrap the logging call in an LRU cache so
      # that a call already in the cache is not executed again.
      def _wrap(_call):
        @functools.lru_cache(limit)
        def _logger_method(*args, **kwargs):
          _call(*args, **kwargs)

        return _logger_method

      logger_method = _wrap(logger_method)

    def _inner(msg, *args, rank: int = 0, **kwargs):
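      # Without an initialized torch.distributed process group, always log.
      # Otherwise log only on the requested rank, or on every rank (with a
      # rank prefix) when rank < 0.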
      if not dist.is_initialized():
        logger_method(msg, *args, **kwargs)
      elif dist.get_rank() == rank:
        logger_method(msg, *args, **kwargs)
      elif rank < 0:
        logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs)

    # Register this wrapper frame with absl logging so that reported source
    # locations point to the caller rather than this file.
    absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)

    return _inner

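  # Wrap each logging method. Warnings are additionally de-duplicated with an
  # LRU cache of size 1, so an identical warning repeated back-to-back is
  # logged only once.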
  logger.fatal = _if_rank(logger.fatal)
  logger.error = _if_rank(logger.error)
  logger.warning = _if_rank(logger.warning, limit=1)
  logger.info = _if_rank(logger.info)
  logger.debug = _if_rank(logger.debug)
  logger.exception = _if_rank(logger.exception)

  logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True
  return logger
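
Usage (a minimal sketch, not from the source: the import path mirrors the file
path shown above, absl's logging module stands in for the wrapped logger, and
torch.distributed process-group setup is assumed to happen elsewhere):

from absl import logging

from ml_logging.torch_logging import rank_specific

logger = rank_specific(logging)

# Logged only on rank 0 (the default).
logger.info("starting training loop")

# Logged only on rank 1.
logger.info("loaded shard", rank=1)

# Logged on every rank, prefixed with "Rank<n>: ".
logger.warning("possible OOM on this host", rank=-1)

Because the warning de-duplication relies on functools.lru_cache, arguments
must be hashable, and only repeats of the currently cached call are suppressed.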