in scio-core/src/main/scala/com/spotify/scio/io/ZstdDictIO.scala [37:126]
override protected def read(sc: ScioContext, params: Nothing): SCollection[T] =
  throw new UnsupportedOperationException("ZstdDictIO is write-only")
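
// Trains a Zstd dictionary from a sample of the input: validates the parameters, picks a
// training-set byte target, samples elements down to roughly that many encoded bytes, then
// trains the dictionary on a single worker and writes it out as a binary file.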
override protected def write(
  data: SCollection[T],
  params: ZstdDictIO.WriteParam
): Tap[Nothing] = {
  val ZstdDictIO.WriteParam(
    zstdDictSizeBytes,
    numElementsForSizeEstimation,
    trainingBytesTarget
  ) = params

  // see https://github.com/facebook/zstd/issues/3769#issuecomment-1730261489
  if (zstdDictSizeBytes > (ZstdDictIO.DictionarySizeMbWarningThreshold * 1024 * 1024)) {
    logger.warn(
      s"Dictionary sizes over ${ZstdDictIO.DictionarySizeMbWarningThreshold}MB are of " +
        s"questionable utility. Consider reducing zstdDictSizeBytes."
    )
  }
  if (numElementsForSizeEstimation <= 0) {
    throw new IllegalArgumentException(
      s"numElementsForSizeEstimation must be positive, found $numElementsForSizeEstimation"
    )
  }
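
  // Pick the training-set byte target: either the caller-supplied value or zstd's rule of
  // thumb of roughly 100x the dictionary size (see the zdict.h reference below).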
  // training bytes may not exceed 2GiB a.k.a. the max value of an Int
  val trainingBytesTargetActual: Int = trainingBytesTarget.getOrElse {
    // see https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L132-L136
    val computed = Try(Math.multiplyExact(zstdDictSizeBytes, 100)).toOption.getOrElse {
      throw new IllegalArgumentException(
        "Using 100 * zstdDictSizeBytes would exceed 2GiB training bytes. " +
          "Reduce dictionary size"
      )
    }
    logger.info(s"No trainingBytesTarget passed, using ${computed} bytes")
    computed
  }
  if (trainingBytesTargetActual <= 0) {
    throw new IllegalArgumentException(
      s"trainingBytesTarget must be positive, found $trainingBytesTargetActual"
    )
  }
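
  // Elements are serialized with the SCollection's Beam coder, so the dictionary is trained
  // on the encoded byte representation of each element.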
  val beamCoder = CoderMaterializer.beam(data.context, data.coder)
  def toBytes(v: T): Array[Byte] = CoderUtils.encodeToByteArray(beamCoder, v)

  data
    .transform("Create Zstd Dictionary") { scoll =>
      // estimate the sample rate we need by examining numElementsForSizeEstimation elements
      val streamsCntSI = scoll.count.asSingletonSideInput(0L)
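      // sampleRate = (trainingBytesTarget / avgEncodedSize) / totalElementCount, where
      // avgEncodedSize is measured on the first numElementsForSizeEstimation elements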
      val sampleRateSI = scoll
        .take(numElementsForSizeEstimation)
        .map(v => toBytes(v).length)
        .sum
        .withSideInputs(streamsCntSI)
        .map { case (totalSize, ctx) =>
          val avgSize = totalSize / numElementsForSizeEstimation
          val targetNumElements = trainingBytesTargetActual / avgSize
          val sampleRate = targetNumElements.toDouble / ctx(streamsCntSI)
          logger.info(s"Computed sample rate for Zstd dictionary: ${sampleRate}")
          sampleRate
        }
        .toSCollection
        .asSingletonSideInput
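
      // keep each element with probability sampleRate so the sampled, encoded bytes land
      // close to the training byte target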
      scoll
        .withSideInputs(sampleRateSI)
        .flatMap {
          case (s, ctx) if new Random().nextDouble() <= ctx(sampleRateSI) =>
            Some(toBytes(s))
          case _ => None
        }
        .toSCollection
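        // funnel every sampled byte array onto a single key so one worker sees the whole
        // training set and trains a single dictionary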
        .keyBy(_ => ())
        .groupByKey
        .map { case (_, elements) =>
          val zstdSampleSize = {
            val sum = elements.map(_.length.toLong).sum
            if (sum > Int.MaxValue.toLong) Int.MaxValue else sum.toInt
          }
logger.info(s"Training set size for for Zstd dictionary: ${zstdSampleSize}")
          val trainer = new ZstdDictTrainer(zstdSampleSize, zstdDictSizeBytes)
          elements.foreach(trainer.addSample)
          trainer.trainSamples()
        }
    }
.withName("Save Zstd Dictionary")
.saveAsBinaryFile(path)
.underlying
}
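
// A minimal usage sketch (hypothetical, not part of this file), assuming ZstdDictIO is
// constructed with the output path and written through Scio's generic write(io)(params);
// parameter names are taken from WriteParam above and any defaults are assumed:
//
//   records.write(ZstdDictIO[MyRecord]("gs://bucket/zstd-dict"))(
//     ZstdDictIO.WriteParam(zstdDictSizeBytes = 110 * 1024)
//   )
//
// The trained dictionary written by saveAsBinaryFile can then be loaded as a Zstd
// dictionary when compressing the full dataset.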