in ratatool-sampling/src/main/scala/com/spotify/ratatool/samplers/util/SamplerSCollectionFunctions.scala [236:275]
/**
 * Picks, for every key, a threshold on the `Double` score carried by each element of `s`.
 *
 * Per key, a `(lower, upper)` pair is computed with `getLowerBound` / `getUpperBound` from the
 * key's probability, the key's element count and `delta`. The emitted threshold is then:
 *  - `lower`, when the number of scores below `lower` already reaches the `popPerKey` value;
 *  - `upper`, when the number of scores below `upper` is still smaller than that value;
 *  - otherwise the score found after skipping `popPerKey - (scores below lower)` entries of the
 *    retained scores in `[lower, upper)`, falling back to `upper` when too few were retained.
 *
 * NOTE(review): the third case relies on `topByKey` handing back the retained scores in
 * ascending order under the reversed ordering - confirm against the Scio/Beam `Top` contract.
 * NOTE(review): keys missing from any of the joins (e.g. a key with no score inside
 * `[lower, upper)`) presumably produce no output, assuming `hashJoin` is an inner join - verify.
 *
 * @param s          keyed elements paired with the score the threshold is compared against
 * @param probByKey  per-key probability passed to the bound helpers
 * @param popPerKey  side input holding the per-key count the threshold should let through
 * @param delta      passed through unchanged to `getLowerBound` / `getUpperBound`
 * @param sizePerKey maximum number of scores in `[lower, upper)` retained per key
 * @return one `(key, threshold)` pair per key that survives all joins
 */
private def uniformThresholdByKey[T: ClassTag: Coder, U: ClassTag: Coder](
  s: SCollection[(U, (T, Double))],
  probByKey: SCollection[(U, Double)],
  popPerKey: SideInput[Double],
  delta: Double,
  sizePerKey: Int
): SCollection[(U, Double)] = {
  val countByKey = s.countByKey
  val boundsByKey = probByKey
    .hashJoin(countByKey)
    .map { case (k, (p, c)) => (k, (getLowerBound(c, p, delta), getUpperBound(c, p, delta))) }

  // (key, (score, (lower, upper))): built once and shared by both consumers below instead of
  // constructing the same map + hashJoin twice.
  val scoresWithBounds = s
    .map { case (k, (_, d)) => (k, d) }
    .hashJoin(boundsByKey)

  // Per key: (number of scores below lower, number of scores below upper).
  val boundCountsByKey = scoresWithBounds
    .filter { case (_, (d, (_, u))) => d < u }
    .map { case (k, (d, (l, _))) => (k, (if (d < l) 1L else 0L, 1L)) }
    .sumByKey

  scoresWithBounds
    .filter { case (_, (d, (l, u))) => d >= l && d < u }
    .map { case (k, (d, _)) => (k, d) }
    // Reversed ordering, so "top" keeps the smallest scores; at most sizePerKey per key.
    .topByKey(sizePerKey)(Ordering.by(identity[Double]).reverse)
    .hashJoin(boundCountsByKey)
    .hashJoin(boundsByKey)
    .withSideInputs(popPerKey)
    .map { case ((k, ((itr, (lCounts, uCounts)), (l, u))), sic) =>
      // Read the side input once per element rather than in every branch.
      val target = sic(popPerKey)
      if (lCounts >= target) {
        (k, l)
      } else if (uCounts < target) {
        (k, u)
      } else {
        // Skip the entries still needed on top of those already below lower; max(0, ...) guards
        // against a negative drop count.
        val threshold = itr.drop(max(0, (target - lCounts).toInt)).headOption
        (k, threshold.getOrElse(u))
      }
    }
    .toSCollection
}