in ratatool-sampling/src/main/scala/com/spotify/ratatool/samplers/util/SamplerSCollectionFunctions.scala [193:234]
/**
 * Computes a per-key threshold on the elements' `Double` values for stratified sampling.
 *
 * For every key, a lower bound `l` and an upper bound `u` are derived from the key's element
 * count, `prob` and `delta` (via `getLowerBound` / `getUpperBound`). The threshold is chosen so
 * that about `target = (count * prob).toLong` values of the key fall strictly below it:
 *  - `l` if at least `target` values already fall below `l`;
 *  - `u` if fewer than `target` values fall below `u`;
 *  - otherwise the in-window value (from `[l, u)`) whose rank brings the number of values below
 *    it up to `target`, falling back to `u` when that value is not among the `sizePerKey`
 *    smallest in-window values that are kept.
 *
 * Every distinct key of `s` gets a threshold, including keys with no value inside `[l, u)`.
 *
 * @param s          keyed elements paired with a `Double` (presumably a uniform random draw
 *                   assigned upstream -- confirm at the call site)
 * @param prob       sampling fraction per key; the per-key target is `(count * prob).toLong`
 * @param delta      forwarded to `getLowerBound` / `getUpperBound`
 * @param sizePerKey maximum number of in-window candidate values kept per key
 * @return one `(key, threshold)` pair per distinct key in `s`
 */
private def stratifiedThresholdByKey[T: ClassTag: Coder, U: ClassTag: Coder](
  s: SCollection[(U, (T, Double))],
  prob: Double,
  delta: Double,
  sizePerKey: Int
): SCollection[(U, Double)] = {
  val valuesByKey = s.map { case (k, (_, d)) => (k, d) }
  val countByKey = s.countByKey
  val targetByKey = countByKey.map { case (k, c) => (k, (c * prob).toLong) }
  val boundsByKey = countByKey
    .map { case (k, c) =>
      (k, (getLowerBound(c, prob, delta), getUpperBound(c, prob, delta)))
    }

  // (#values below l, #values below u) per key. Every element contributes a pair instead of
  // filtering to `d < u` first, so that every key is present here, even with zero counts.
  val boundCountsByKey = valuesByKey
    .hashJoin(boundsByKey)
    .map { case (k, (d, (l, u))) =>
      val belowUpper = d < u
      (k, (if (belowUpper && d < l) 1L else 0L, if (belowUpper) 1L else 0L))
    }
    .sumByKey

  // The `sizePerKey` smallest values inside [l, u). The rank lookup below relies on topByKey
  // returning them ordered by the supplied (reversed) ordering, i.e. ascending values.
  val candidatesByKey = valuesByKey
    .hashJoin(boundsByKey)
    .filter { case (_, (d, (l, u))) => d >= l && d < u }
    .map { case (k, (d, _)) => (k, d) }
    .topByKey(sizePerKey)(Ordering.by(identity[Double]).reverse)

  boundCountsByKey
    .hashJoin(boundsByKey)
    .hashJoin(targetByKey)
    // Outer join: keys with no value in [l, u) have no candidates, and an inner join would
    // silently drop their threshold (and with it their stratum) from the result.
    .leftOuterJoin(candidatesByKey)
    .map { case (k, ((((lCounts, uCounts), (l, u)), target), candidates)) =>
      val threshold =
        if (lCounts >= target) {
          l
        } else if (uCounts < target) {
          u
        } else {
          // 0-based rank of the candidate that lifts the count below the threshold to `target`.
          // Clamped to sizePerKey so the Long -> Int conversion cannot overflow; a rank that
          // large is past the kept candidates and falls back to `u` either way.
          val rank = math.min(target - lCounts, sizePerKey.toLong).toInt
          candidates.flatMap(_.drop(rank).headOption).getOrElse(u)
        }
      (k, threshold)
    }
}