in scio-smb/src/main/scala/com/spotify/scio/smb/syntax/SortMergeBucketScioContextSyntax.scala [334:414]
def sortMergeCoGroup[K1: Coder, K2: Coder, A: Coder, B: Coder](
  keyClass: Class[K1],
  keyClassSecondary: Class[K2],
  a: SortedBucketIO.Read[A],
  b: SortedBucketIO.Read[B]
): SCollection[((K1, K2), (Iterable[A], Iterable[B]))] = {
  // Convenience overload: forwards to the explicit-parallelism variant with
  // TargetParallelism.auto() filled in.
  val parallelism = TargetParallelism.auto()
  sortMergeCoGroup(keyClass, keyClassSecondary, a, b, parallelism)
}
/**
 * For each key K in `a` or `b` or `c`, return a resulting SCollection that contains a tuple with
 * the list of values for that key in `a`, `b` and `c`.
 *
 * See note on [[SortedBucketScioContext.sortMergeJoin]] for information on how an SMB cogroup
 * differs from a regular [[org.apache.beam.sdk.transforms.join.CoGroupByKey]] operation.
 *
 * @group cogroup
 *
 * @param keyClass
 *   cogroup key class. Must have a Coder in Beam's default
 *   [[org.apache.beam.sdk.coders.CoderRegistry]] as custom key coders are not supported yet.
 * @param a
 *   first sorted-bucket source to cogroup.
 * @param b
 *   second sorted-bucket source to cogroup.
 * @param c
 *   third sorted-bucket source to cogroup.
 * @param targetParallelism
 *   the desired parallelism of the job. See
 *   [[org.apache.beam.sdk.extensions.smb.TargetParallelism]] for more information.
 * @return
 *   an SCollection pairing each key with the values found for it in `a`, `b` and `c`, in that
 *   order.
 */
@experimental
def sortMergeCoGroup[K: Coder, A: Coder, B: Coder, C: Coder](
  keyClass: Class[K],
  a: SortedBucketIO.Read[A],
  b: SortedBucketIO.Read[B],
  c: SortedBucketIO.Read[C],
  targetParallelism: TargetParallelism
): SCollection[(K, (Iterable[A], Iterable[B], Iterable[C]))] = {
  // The actual work lives in SMBMultiJoin; this overload only forwards its arguments.
  val multiJoin = SMBMultiJoin(self)
  multiJoin.sortMergeCoGroup(keyClass, a, b, c, targetParallelism)
}
/**
 * Three-way SMB cogroup without an explicit parallelism: `targetParallelism` defaults to
 * `TargetParallelism.auto()`.
 *
 * @group cogroup
 */
@experimental
def sortMergeCoGroup[K: Coder, A: Coder, B: Coder, C: Coder](
  keyClass: Class[K],
  a: SortedBucketIO.Read[A],
  b: SortedBucketIO.Read[B],
  c: SortedBucketIO.Read[C]
): SCollection[(K, (Iterable[A], Iterable[B], Iterable[C]))] = {
  // No parallelism is passed here; SMBMultiJoin's four-argument overload supplies the default.
  val multiJoin = SMBMultiJoin(self)
  multiJoin.sortMergeCoGroup(keyClass, a, b, c)
}
/**
 * Secondary keyed variant of the three-way SMB cogroup.
 *
 * Sources are read with both `keyClass` and `keyClassSecondary`, and each output element pairs
 * the composite `(primary, secondary)` key with the values from `a`, `b` and `c`, in that order.
 *
 * @group cogroup
 */
@experimental
def sortMergeCoGroup[K1: Coder, K2: Coder, A: Coder, B: Coder, C: Coder](
  keyClass: Class[K1],
  keyClassSecondary: Class[K2],
  a: SortedBucketIO.Read[A],
  b: SortedBucketIO.Read[B],
  c: SortedBucketIO.Read[C],
  targetParallelism: TargetParallelism
): SCollection[((K1, K2), (Iterable[A], Iterable[B], Iterable[C]))] = self.requireNotClosed {
  // Read the transform name exactly once, before anything is applied: the same name labels
  // both the SMB read below and the follow-up map.
  // NOTE(review): presumably tfName is not a pure accessor; keep this single, early evaluation.
  val stepName = self.tfName
  val cogrouped = if (self.isTest) {
    // Test pipelines bypass the SMB read and delegate to SMBMultiJoin.testCoGroup.
    SMBMultiJoin(self).testCoGroup(a, b, c)
  } else {
    val smbRead = SortedBucketIO
      .read(keyClass, keyClassSecondary)
      .of(a)
      .and(b)
      .and(c)
      .withTargetParallelism(targetParallelism)
    self.wrap(self.pipeline.apply(s"SMB CoGroupForKey@$stepName", smbRead))
  }
  cogrouped
    .withName(stepName)
    .map { entry =>
      // The key is itself a KV of (primary, secondary); flatten it into a Scala tuple.
      val compositeKey = entry.getKey
      val groups = entry.getValue
      val key = (compositeKey.getKey, compositeKey.getValue)
      val values = (
        groups.getAll(a.getTupleTag).asScala,
        groups.getAll(b.getTupleTag).asScala,
        groups.getAll(c.getTupleTag).asScala
      )
      (key, values)
    }
}