override protected def read()

in scio-core/src/main/scala/com/spotify/scio/io/ZstdDictIO.scala [37:126]


  /**
   * Reading is not supported by this IO; it only writes a trained Zstd dictionary.
   *
   * @param sc
   *   unused
   * @param params
   *   unused (there are no read parameters)
   * @throws UnsupportedOperationException
   *   always
   */
  override protected def read(sc: ScioContext, params: Nothing): SCollection[T] = {
    val reason = "ZstdDictIO is write-only"
    throw new UnsupportedOperationException(reason)
  }

  /**
   * Trains a Zstd dictionary from a random sample of `data` and saves it with `saveAsBinaryFile`
   * under `path`.
   *
   * Each element is serialized with the collection's coder, and the serialized bytes are the
   * training samples. The sample rate is estimated from the average encoded size of up to
   * `numElementsForSizeEstimation` elements. The goal is to gather roughly the effective
   * training-bytes target (`trainingBytesTarget`, or 100 * `zstdDictSizeBytes` when absent).
   *
   * @param data
   *   elements to sample for dictionary training
   * @param params
   *   dictionary size, number of elements used for size estimation and optional training target
   * @return
   *   the underlying tap of the saved dictionary output
   * @throws IllegalArgumentException
   *   if `zstdDictSizeBytes`, `numElementsForSizeEstimation` or the effective training bytes
   *   target is not positive, or if the default target (100 * `zstdDictSizeBytes`) overflows an
   *   `Int`
   */
  override protected def write(
    data: SCollection[T],
    params: ZstdDictIO.WriteParam
  ): Tap[Nothing] = {
    val ZstdDictIO.WriteParam(
      zstdDictSizeBytes,
      numElementsForSizeEstimation,
      trainingBytesTarget
    ) = params

    // fail fast: a non-positive dictionary size would otherwise only surface inside the pipeline
    if (zstdDictSizeBytes <= 0) {
      throw new IllegalArgumentException(
        s"zstdDictSizeBytes must be positive, found $zstdDictSizeBytes"
      )
    }
    // see https://github.com/facebook/zstd/issues/3769#issuecomment-1730261489
    if (zstdDictSizeBytes > (ZstdDictIO.DictionarySizeMbWarningThreshold * 1024 * 1024)) {
      logger.warn(
        s"Dictionary sizes over ${ZstdDictIO.DictionarySizeMbWarningThreshold}MB are of " +
          s"questionable utility. Consider reducing zstdDictSizeBytes."
      )
    }
    if (numElementsForSizeEstimation <= 0) {
      throw new IllegalArgumentException(
        s"numElementsForSizeEstimation must be positive, found $numElementsForSizeEstimation"
      )
    }
    // training bytes may not exceed 2GiB a.k.a. the max value of an Int
    val trainingBytesTargetActual: Int = trainingBytesTarget.getOrElse {
      // see https://github.com/facebook/zstd/blob/v1.5.5/lib/zdict.h#L132-L136
      val computed = Try(Math.multiplyExact(zstdDictSizeBytes, 100)).toOption.getOrElse {
        throw new IllegalArgumentException(
          "Using 100 * zstdDictSizeBytes would exceed 2GiB training bytes. " +
            "Reduce dictionary size"
        )
      }
      logger.info(s"No trainingBytesTarget passed, using ${computed} bytes")
      computed
    }
    if (trainingBytesTargetActual <= 0) {
      throw new IllegalArgumentException(
        s"trainingBytesTarget must be positive, found $trainingBytesTargetActual"
      )
    }

    val beamCoder = CoderMaterializer.beam(data.context, data.coder)
    def toBytes(v: T): Array[Byte] = CoderUtils.encodeToByteArray(beamCoder, v)

    data
      .transform("Create Zstd Dictionary") { scoll =>
        // estimate the sample rate we need by examining numElementsForSizeEstimation elements
        val streamsCntSI = scoll.count.asSingletonSideInput(0L)
        val sampleRateSI = scoll
          .take(numElementsForSizeEstimation)
          // sum as Long: many large elements can exceed Int.MaxValue bytes
          .map(v => toBytes(v).length.toLong)
          .sum
          .withSideInputs(streamsCntSI)
          .map { case (totalSize, ctx) =>
            val streamsCnt = ctx(streamsCntSI)
            // `take` yields only min(n, count) elements, so average over what was actually
            // summed; clamp to 1 so empty input or zero-byte encodings cannot divide by zero
            val sampledCnt = math.max(1L, math.min(numElementsForSizeEstimation.toLong, streamsCnt))
            val avgSize = math.max(1L, totalSize / sampledCnt)
            val targetNumElements = trainingBytesTargetActual / avgSize
            val sampleRate = targetNumElements.toDouble / streamsCnt
            logger.info(s"Computed sample rate for Zstd dictionary: ${sampleRate}")
            sampleRate
          }
          .toSCollection
          .asSingletonSideInput

        scoll
          .withSideInputs(sampleRateSI)
          .flatMap {
            // nextDouble() is in [0, 1), so `<` keeps each element with probability sampleRate
            case (s, ctx) if new Random().nextDouble() < ctx(sampleRateSI) =>
              Some(toBytes(s))
            case _ => None
          }
          .toSCollection
          // gather every sample under a single key so one trainer sees the whole training set
          .keyBy(_ => ())
          .groupByKey
          .map { case (_, elements) =>
            // the trainer's buffer size is an Int, so cap the requested size at Int.MaxValue;
            // NOTE(review): samples beyond the cap are presumably rejected by addSample — confirm
            val zstdSampleSize = {
              val sum = elements.map(_.length.toLong).sum
              if (sum > Int.MaxValue.toLong) Int.MaxValue else sum.toInt
            }
            logger.info(s"Training set size for Zstd dictionary: ${zstdSampleSize}")
            val trainer = new ZstdDictTrainer(zstdSampleSize, zstdDictSizeBytes)
            elements.foreach(trainer.addSample)
            trainer.trainSamples()
          }
      }
      .withName("Save Zstd Dictionary")
      .saveAsBinaryFile(path)
      .underlying
  }