in ratatool-sampling/src/main/scala/com/spotify/ratatool/samplers/util/SamplerSCollectionFunctions.scala [46:96]
/**
 * Keys every element of `s` with `keyFn` and runs the keyed collection through a `ParDo`
 * built from a fresh `RandomValueAssigner[U, T]`.
 *
 * NOTE(review): judging by its name, `RandomValueAssigner` presumably attaches a random
 * roll to each keyed element; its exact output type is defined elsewhere, which is why the
 * return type is left to inference here — confirm against its definition.
 *
 * @param s     collection to key
 * @param keyFn derives the key of type `U` from an element
 * @tparam T element type
 * @tparam U key type
 * @return whatever the `RandomValueAssigner` transform emits for the keyed input
 */
private[samplers] def assignRandomRoll[T: ClassTag: Coder, U: ClassTag: Coder](
  s: SCollection[T],
  keyFn: T => U
) = {
  val keyed = s.keyBy(keyFn)
  keyed.applyTransform(ParDo.of(new RandomValueAssigner[U, T]))
}
/**
 * Measures how far a keyed sample is from the size expected for sampling probability `prob`,
 * both overall and per key.
 *
 * Expected sizes ("targets") are the element counts of `s`, in total and per `keyFn` key,
 * each multiplied by `prob`. Actual sizes are the counts of the keys found in `sampled`.
 * Every diff is the signed relative shortfall `(target - actual) / target`: positive when
 * the sample holds fewer elements than the target, negative when it holds more.
 *
 * Only keys present in the targets are reported. A target key missing from `sampled` is
 * treated as a count of 0, and keys that appear only in `sampled` are ignored.
 *
 * NOTE(review): nothing guards against a zero target (for example `prob == 0`), so the
 * floating-point division can then yield NaN or an infinity.
 *
 * @param s       full population the sample was drawn from
 * @param sampled sampled elements paired with their key
 * @param keyFn   derives the stratification key from a population element
 * @param prob    factor applied to the population counts to obtain the target sizes
 * @param exact   when true, a diff above `errorTolerance` raises an exception from inside
 *                the map function; only positive diffs (undersampling) can trigger it
 * @return pairs of (total relative diff, relative diff per key)
 * @throws Exception when `exact` is set and the total diff or any key diff exceeds
 *                   `errorTolerance`
 */
private[samplers] def buildStratifiedDiffs[T: ClassTag: Coder, U: ClassTag: Coder](
  s: SCollection[T],
  sampled: SCollection[(U, T)],
  keyFn: T => U,
  prob: Double,
  exact: Boolean = false
): SCollection[(Double, Map[U, Double])] = {
  // Count the population in total and per key, then scale both by `prob` to get the targets.
  val populationCounts = s.map(t => (1L, Map[U, Long](keyFn(t) -> 1L))).sum
  val targets = populationCounts
    .map { case (total, perKey) =>
      (total * prob, perKey.map { case (k, v) => (k, v * prob) })
    }
    .asSingletonSideInput

  // Same counting scheme for what actually ended up in the sample.
  val sampledCounts = sampled.keys.map(k => (1L, Map[U, Long](k -> 1L))).sum

  sampledCounts
    .withSideInputs(targets)
    .map { case (counts, ctx) =>
      val (totalCount, keyCounts) = counts
      val (targetTotal, keyTargets) = ctx(targets)

      val totalDiff = (targetTotal - totalCount) / targetTotal
      // Driven by the target keys, so keys seen only in the sample never show up here.
      val keyDiffs = keyTargets.map { case (k, target) =>
        k -> (target - keyCounts.getOrElse(k, 0L)) / target
      }

      if (exact) {
        if (totalDiff > errorTolerance) {
          throw new Exception(
            s"Total elements sampled off by ${totalDiff * 100}% (> ${errorTolerance * 100}%)"
          )
        }
        // Fail on the first key (in map iteration order) whose shortfall exceeds the tolerance.
        keyDiffs
          .find { case (_, diff) => diff > errorTolerance }
          .foreach { case (k, diff) =>
            throw new Exception(
              s"Elements for key $k sample off by ${diff * 100}% (> ${errorTolerance * 100}%)"
            )
          }
      }

      (totalDiff, keyDiffs)
    }
    .toSCollection
}