private[samplers] def assignRandomRoll[T: ClassTag: Coder, U: ClassTag: Coder]()

in ratatool-sampling/src/main/scala/com/spotify/ratatool/samplers/util/SamplerSCollectionFunctions.scala [46:96]


  /**
   * Keys every element of `s` with `keyFn`, then runs the keyed pairs through a
   * `RandomValueAssigner[U, T]` `ParDo`.
   *
   * @param s     the elements to key
   * @param keyFn extracts the key (stratum) of an element
   * @return whatever `RandomValueAssigner[U, T]` emits for the keyed pairs
   *         (presumably each pair tagged with a random roll — confirm against
   *         `RandomValueAssigner`)
   */
  private[samplers] def assignRandomRoll[T: ClassTag: Coder, U: ClassTag: Coder](
    s: SCollection[T],
    keyFn: T => U
  ) = {
    val keyed: SCollection[(U, T)] = s.keyBy(keyFn)
    val rollAssigner = ParDo.of(new RandomValueAssigner[U, T])
    keyed.applyTransform(rollAssigner)
  }

  /**
   * Measures how far a stratified sample is from the sizes expected when sampling `s`
   * with probability `prob`.
   *
   * Targets come from `s`: the overall count and the per-key count (via `keyFn`), each
   * multiplied by `prob`. For the total and for every key present in `s`, the result
   * holds the signed relative deviation `(target - actual) / target`. A positive value
   * means the sample fell short of its target; a negative value means it overshot.
   * Keys that appear in `sampled` but not in `s` are ignored.
   *
   * @param s       the full collection the sample was drawn from
   * @param sampled the sampled elements, keyed by stratum
   * @param keyFn   extracts the stratum key from an element of `s`
   * @param prob    the sampling probability used to derive the targets
   * @param exact   when true, fail if the total or any per-key deviation exceeds
   *                `errorTolerance` in magnitude (under- or over-sampling)
   * @return a collection carrying the `(totalDiff, keyDiffs)` pair computed from the
   *         summed counts
   * @throws Exception when `exact` is set and a deviation is out of tolerance (raised
   *                   inside the pipeline's `map` step)
   */
  private[samplers] def buildStratifiedDiffs[T: ClassTag: Coder, U: ClassTag: Coder](
    s: SCollection[T],
    sampled: SCollection[(U, T)],
    keyFn: T => U,
    prob: Double,
    exact: Boolean = false
  ): SCollection[(Double, Map[U, Double])] = {
    // Each element adds 1 to the total and 1 to its key's count; the summed counts are
    // then scaled by `prob` to get the expected sample sizes.
    val targets = s
      .map(t => (1L, Map[U, Long](keyFn(t) -> 1L)))
      .sum
      .map { case (total, m) =>
        (total * prob, m.map { case (k, v) => (k, v * prob) })
      }
      .asSingletonSideInput

    sampled.keys
      .map(k => (1L, Map[U, Long](k -> 1L)))
      .sum
      .withSideInputs(targets)
      .map { case ((totalCount, keyCounts), sic) =>
        val (targetTotal, keyTargets) = sic(targets)
        // Signed relative deviation: > 0 means under-sampled, < 0 means over-sampled.
        val totalDiff = (targetTotal - totalCount) / targetTotal
        val keyDiffs = keyTargets.map { case (k, target) =>
          k -> (target - keyCounts.getOrElse(k, 0L)) / target
        }

        if (exact) {
          // An exact sample must be close to its target in both directions, so compare
          // the magnitude of the deviation instead of flagging only under-sampling.
          if (math.abs(totalDiff) > errorTolerance) {
            throw new Exception(
              s"Total elements sampled off by ${totalDiff * 100}% (> ${errorTolerance * 100}%)"
            )
          }
          keyDiffs.foreach { case (k, diff) =>
            if (math.abs(diff) > errorTolerance) {
              throw new Exception(
                s"Elements for key $k sample off by ${diff * 100}% (> ${errorTolerance * 100}%)"
              )
            }
          }
        }
        (totalDiff, keyDiffs)
      }
      .toSCollection
  }