private def stratifiedThresholdByKey[T: ClassTag: Coder, U: ClassTag: Coder]()

in ratatool-sampling/src/main/scala/com/spotify/ratatool/samplers/util/SamplerSCollectionFunctions.scala [193:234]


  /**
   * Computes one `Double` threshold per key from the `Double` carried by each record of `s`.
   *
   * For every key with `c` records:
   *  - `target` is `(c * prob).toLong`;
   *  - `(lower, upper)` is `(getLowerBound(c, prob, delta), getUpperBound(c, prob, delta))`;
   *  - `lCounts` / `uCounts` are the numbers of records whose value is `< lower` / `< upper`.
   *
   * The emitted threshold is:
   *  - `lower` when `lCounts >= target`;
   *  - `upper` when `uCounts < target`;
   *  - otherwise the candidate at index `target - lCounts` among the retained values in
   *    `[lower, upper)`, falling back to `upper` when fewer candidates were retained.
   *
   * NOTE(review): every join below is a `hashJoin`, which is presumably an inner join; a key
   * with no value in `[lower, upper)` would then have no candidate entry and be absent from
   * the result. Confirm that callers tolerate missing keys.
   *
   * @param s          keyed records; only the key and the `Double` of each record are read
   * @param prob       fraction multiplied by each key's record count to obtain `target`; also
   *                   forwarded to `getLowerBound` / `getUpperBound`
   * @param delta      forwarded unchanged to `getLowerBound` / `getUpperBound`
   * @param sizePerKey maximum number of candidate values retained per key by `topByKey`
   * @return one `(key, threshold)` pair per key that survives all joins
   */
  private def stratifiedThresholdByKey[T: ClassTag: Coder, U: ClassTag: Coder](
    s: SCollection[(U, (T, Double))],
    prob: Double,
    delta: Double,
    sizePerKey: Int
  ): SCollection[(U, Double)] = {
    val countByKey = s.countByKey
    val targetByKey = countByKey.map { case (k, c) => (k, (c * prob).toLong) }
    // Bounds are derived once per key here and reused everywhere below, rather than being
    // recomputed for every record.
    val boundsByKey = countByKey
      .map { case (k, c) =>
        (k, (getLowerBound(c, prob, delta), getUpperBound(c, prob, delta)))
      }

    // Each record's value paired with its key's (lower, upper) bounds. Shared by the
    // bound counts and the candidate selection so the projection and join are built once.
    val valuesWithBounds = s
      .map { case (k, (_, d)) => (k, d) }
      .hashJoin(boundsByKey)

    // Per key: (number of values < lower, number of values < upper).
    val boundCountsByKey = valuesWithBounds
      .filter { case (_, (d, (_, u))) => d < u }
      .map { case (k, (d, (l, _))) => (k, (if (d < l) 1L else 0L, 1L)) }
      .sumByKey

    valuesWithBounds
      .filter { case (_, (d, (l, u))) => d < u && d >= l }
      .map { case (k, (d, _)) => (k, d) }
      // The reversed ordering makes topByKey retain the `sizePerKey` smallest values.
      // NOTE(review): the indexing below assumes the returned values are sorted
      // ascending - confirm against the topByKey contract.
      .topByKey(sizePerKey)(Ordering.by(identity[Double]).reverse)
      .hashJoin(boundCountsByKey)
      .hashJoin(boundsByKey)
      .hashJoin(targetByKey)
      .map { case (k, (((candidates, (lCounts, uCounts)), (l, u)), target)) =>
        if (lCounts >= target) {
          (k, l)
        } else if (uCounts < target) {
          (k, u)
        } else {
          // Skip the candidates needed to reach `target` on top of those already below
          // `lower`; the next candidate is the threshold, or `upper` if none is left.
          val threshold = candidates.drop(max(0, (target - lCounts).toInt)).headOption
          (k, threshold.getOrElse(u))
        }
      }
  }