def prepend_transform()

in core/metric_mixin.py [0:0]


def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
  """Builds a class that combines ``MetricMixin`` with ``base_metric``.

  This is a shorthand for writing the subclass out by hand: the result is a
  dynamically created class whose only extra attribute is a ``transform``
  method. Prefer an explicit subclass when additional class attributes are
  needed.

  Args:
    base_metric: The metric to extend. It is used as a base class and its
      ``__name__`` is reused for the generated class, so a class (not an
      instance) is expected here.
    transform: Callable installed as the generated class's ``transform``
      method. It receives the positional and keyword arguments of the call
      unchanged; the instance itself is not forwarded to it.

  Returns:
    A new class named after ``base_metric`` with bases
    ``(MetricMixin, base_metric)``.
    NOTE(review): how ``MetricMixin`` consumes ``transform`` is not visible
    here; presumably it applies it before the base metric's update. Confirm
    against ``MetricMixin``.
  """

  def transform_method(_self, *args, **kwargs):
    # ``transform`` is a plain callable rather than a method, so the bound
    # instance is dropped and only the caller's arguments are passed on.
    return transform(*args, **kwargs)

  class_name = base_metric.__name__
  # MetricMixin comes first so it precedes base_metric in the MRO.
  bases = (MetricMixin, base_metric)
  namespace = {"transform": transform_method}
  return type(class_name, bases, namespace)