private def uniformThresholdByKey[T: ClassTag: Coder, U: ClassTag: Coder]()

in ratatool-sampling/src/main/scala/com/spotify/ratatool/samplers/util/SamplerSCollectionFunctions.scala [236:275]


  /**
   * Picks one Double threshold per key from the values attached to that key's elements.
   *
   * Steps, per key:
   *  1. Derive a `(lower, upper)` pair with `getLowerBound` / `getUpperBound` from the key's
   *     element count, its entry in `probByKey`, and `delta`.
   *  2. Count how many values fall below `lower` and how many fall below `upper`.
   *  3. Retain at most `sizePerKey` values from the half-open range `[lower, upper)`.
   *  4. Emit `lower` when the below-`lower` count already reaches `popPerKey`, `upper` when
   *     the below-`upper` count is still short of it, and otherwise the retained value found
   *     after skipping `popPerKey - belowLower` entries (falling back to `upper` if none).
   *
   * NOTE(review): every stage is combined with `hashJoin`, presumably an inner join, so a key
   * with no value inside `[lower, upper)` would produce no output at all -- confirm callers
   * tolerate that.
   *
   * @param s          keyed elements, each paired with the Double that is compared to the bounds
   * @param probByKey  per-key probability handed to the bound helpers
   * @param popPerKey  side input holding the per-key count the below-bound tallies are compared to
   * @param delta      forwarded unchanged to `getLowerBound` / `getUpperBound`
   * @param sizePerKey maximum number of in-range values retained per key
   * @return one `(key, threshold)` pair per key that survives all joins
   */
  private def uniformThresholdByKey[T: ClassTag: Coder, U: ClassTag: Coder](
    s: SCollection[(U, (T, Double))],
    probByKey: SCollection[(U, Double)],
    popPerKey: SideInput[Double],
    delta: Double,
    sizePerKey: Int
  ): SCollection[(U, Double)] = {
    // (lower, upper) per key, computed from that key's probability and element count.
    val rangeByKey = probByKey
      .hashJoin(s.countByKey)
      .map { case (key, (prob, count)) =>
        val lower = getLowerBound(count, prob, delta)
        val upper = getUpperBound(count, prob, delta)
        (key, (lower, upper))
      }

    // Only the Double matters from here on; the payload is dropped. The same projection,
    // joined with the per-key range, feeds both the tallies and the candidate list below.
    val valueWithRange = s
      .map { case (key, (_, value)) => (key, value) }
      .hashJoin(rangeByKey)

    // Per key: (number of values < lower, number of values < upper).
    val talliesByKey = valueWithRange
      .filter { case (_, (value, (_, upper))) => value < upper }
      .map { case (key, (value, (lower, _))) =>
        val belowLower = if (value < lower) 1L else 0L
        (key, (belowLower, 1L))
      }
      .sumByKey

    // In-range values per key, capped at sizePerKey. The reversed ordering makes topByKey
    // favour the smallest values (presumably yielded in ascending order -- TODO confirm).
    val candidatesByKey = valueWithRange
      .filter { case (_, (value, (lower, upper))) => lower <= value && value < upper }
      .map { case (key, (value, _)) => (key, value) }
      .topByKey(sizePerKey)(Ordering.by(identity[Double]).reverse)

    candidatesByKey
      .hashJoin(talliesByKey)
      .hashJoin(rangeByKey)
      .withSideInputs(popPerKey)
      .map { case ((key, ((candidates, (belowLower, belowUpper)), (lower, upper))), ctx) =>
        val target = ctx(popPerKey)
        val chosen =
          if (belowLower >= target) {
            // Enough values already sit under the lower bound.
            lower
          } else if (belowUpper < target) {
            // Even everything under the upper bound falls short of the target.
            upper
          } else {
            // Skip as many candidates as are still needed beyond those below `lower`.
            val toSkip = max(0, (target - belowLower).toInt)
            candidates.drop(toSkip).headOption.getOrElse(upper)
          }
        (key, chosen)
      }
      .toSCollection
  }