in ratatool-sampling/src/main/scala/com/spotify/ratatool/samplers/util/SamplerSCollectionFunctions.scala [321:349]
/**
 * Samples the wrapped collection `s` according to a key-based distribution.
 *
 * Every element is keyed with `keyFn` and fed, together with an inclusion probability, to a
 * `BernoulliValueSampler`. The probability source depends on `dist`:
 *  - `StratifiedDistribution`: the same `prob` is used for every element.
 *  - `UniformDistribution`: a per-key probability produced by `uniformParams(s, keyFn, prob)` is
 *    joined onto each element by key (NOTE(review): presumably chosen so keys end up with
 *    comparable sample sizes — confirm against `uniformParams`).
 *
 * The difference between the input and the sample is computed with the matching
 * `build*Diffs` helper and handed to `logDistributionDiffs` together with a logger.
 *
 * @param dist  which sampling distribution to apply
 * @param keyFn extracts the stratification key from an element
 * @param prob  sampling probability passed to the sampler / diff helpers
 * @return the sampled elements, with the keys dropped
 */
def sampleDist[U: ClassTag: Coder](dist: SampleDistribution, keyFn: T => U, prob: Double)(
  implicit coder: Coder[T]
): SCollection[T] = {
  @transient lazy val logSerDe = LoggerFactory.getLogger(this.getClass)

  // Both branches run the same Bernoulli sampler; only its (key, value) -> probability input differs.
  val bernoulli: RandomValueSampler[U, T, _] = new BernoulliValueSampler[U, T]

  val (keyedSample, distDiffs) = dist match {
    case StratifiedDistribution =>
      val byKey = s.keyBy(keyFn)
      val picked = byKey
        .map(kv => (kv, prob))
        .applyTransform(ParDo.of(bernoulli))
      (picked, buildStratifiedDiffs(s, picked, keyFn, prob))

    case UniformDistribution =>
      // Per-key parameters must be built before the keyed join that consumes them.
      val (popPerKey, probPerKey) = uniformParams(s, keyFn, prob)
      val joined = s.keyBy(keyFn).hashJoin(probPerKey)
      val picked = joined
        .map { case (key, (value, keyProb)) => ((key, value), keyProb) }
        .applyTransform(ParDo.of(bernoulli))
      (picked, buildUniformDiffs(s, picked, keyFn, prob, popPerKey))
  }

  logDistributionDiffs(distDiffs, logSerDe)
  keyedSample.values
}