def sampleDist[U: ClassTag: Coder]()

in ratatool-sampling/src/main/scala/com/spotify/ratatool/samplers/util/SamplerSCollectionFunctions.scala [321:349]


    /**
     * Samples the wrapped collection `s` for the requested distribution, keying every element
     * with `keyFn`.
     *
     * - `StratifiedDistribution`: each keyed element is paired with `prob` itself.
     * - `UniformDistribution`: each keyed element is paired with the per-key probability
     *   obtained from `uniformParams` (joined on the key via `hashJoin`).
     *
     * In both cases the `((key, value), probability)` pairs are run through a
     * `BernoulliValueSampler`, the matching diffs are built (`buildStratifiedDiffs` /
     * `buildUniformDiffs`) and passed to `logDistributionDiffs`, and only the values of the
     * sampled pairs are returned.
     *
     * @param dist  distribution to sample for
     * @param keyFn extracts the key of an element
     * @param prob  probability forwarded to the sampler (stratified) or to `uniformParams`
     *              (uniform); presumably expected to lie in (0, 1] -- not checked here
     * @tparam U    key type
     * @return the sampled elements with their keys dropped
     */
    def sampleDist[U: ClassTag: Coder](dist: SampleDistribution, keyFn: T => U, prob: Double)(
      implicit coder: Coder[T]
    ): SCollection[T] = {
      @transient lazy val diffLogger = LoggerFactory.getLogger(this.getClass)

      val (kept, keptDiffs) = dist match {
        case StratifiedDistribution =>
          val bernoulli: RandomValueSampler[U, T, _] = new BernoulliValueSampler[U, T]
          // Every element gets the same probability, regardless of its key.
          val picked = s
            .keyBy(keyFn)
            .map(keyedElem => (keyedElem, prob))
            .applyTransform(ParDo.of(bernoulli))
          (picked, buildStratifiedDiffs(s, picked, keyFn, prob))

        case UniformDistribution =>
          val bernoulli: RandomValueSampler[U, T, _] = new BernoulliValueSampler[U, T]
          val (populationByKey, probabilityByKey) = uniformParams(s, keyFn, prob)
          // Attach the key-specific probability to each element before sampling.
          val withKeyProbability = s.keyBy(keyFn).hashJoin(probabilityByKey)
          val picked = withKeyProbability
            .map { case (key, (value, keyProbability)) => ((key, value), keyProbability) }
            .applyTransform(ParDo.of(bernoulli))
          (picked, buildUniformDiffs(s, picked, keyFn, prob, populationByKey))
      }

      logDistributionDiffs(keptDiffs, diffLogger)
      kept.values
    }