scripts/smb_multijoin.py (168 lines of code) (raw):

#!/usr/bin/env python3 # # Copyright 2021 Spotify AB. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import string import sys import textwrap # Utilities def mkTypes(n): return list(string.ascii_uppercase[:n]) def mkVals(n): return [x.lower() for x in mkTypes(n)] def mkRawClassTags(n): return ', '.join(['KEY'] + mkTypes(n)) def mkClassTags(n): arg_list = ['{}: Coder'.format(element) for element in mkTypes(n)] return ', '.join(['KEY: Coder'] + arg_list) def mkReadArgs(n): return ', '.join('%s: SortedBucketIO.Read[%s]' % (x.lower(), x) for x in mkTypes(n)) def mkFnArgs(n): return 'keyClass: Class[KEY], ' + mkReadArgs(n) def fnRetValHelper(n, aWrapper=None, otherWrapper=None): def wrap(wrapper, x): return x if wrapper is None else wrapper + '[%s]' % x vals = (wrap(aWrapper if x == 'A' else otherWrapper, x) for x in mkTypes(n)) return vals def mkCogroupFnRetVal(n, aWrapper=None, otherWrapper=None): vals = fnRetValHelper(n, aWrapper, otherWrapper) return 'SCollection[(KEY, (%s))]' % ', '.join(vals) def mkTransformFnRetVal(n, aWrapper=None, otherWrapper=None): vals = fnRetValHelper(n, aWrapper, otherWrapper) return 'SortMergeTransform.ReadBuilder[KEY, KEY, Void, (%s)]' % ', '.join(vals) def mkTupleTag(n): return ['val tupleTag%s = %s.getTupleTag' % (x, x.lower()) for x in mkTypes(n)] # Functions def sortMergeCoGroup(out, n): types = mkTypes(n) vals = mkVals(n) args = ', '.join(vals) fnArgs = mkFnArgs(n) print('\tdef sortMergeCoGroup[%s](%s): %s = self.requireNotClosed {' % ( mkClassTags(n), fnArgs + ', ' + 'targetParallelism: TargetParallelism', mkCogroupFnRetVal(n, 'Iterable', 'Iterable')), file=out) print('\t\tval tfName = self.tfName''', file=out) print('\t\tval keyed = if (self.isTest) {', file=out) print('\t\t\ttestCoGroup[KEY](%s)' % args, file=out) print('\t\t} else {', file=out) print('\t\t\tval transform = SortedBucketIO', file=out) print('\t\t\t\t.read(keyClass)', file=out) print('\t\t\t\t.of(%s)' % args, file=out) print('\t\t\t\t.withTargetParallelism(targetParallelism)', file=out) print('\t\t\tself.wrap(self.pipeline.apply(s"SMB CoGroupByKey@$tfName", transform))', file=out) print('\t\t}') print('\t\t' + '\n\t\t'.join(mkTupleTag(n)), file=out) print('\t\tkeyed', file=out) print('\t\t\t.withName(tfName)', file=out) print('\t\t\t.map { kv =>', file=out) print('\t\t\t\tval result = kv.getValue', file=out) print('\t\t\t\t(', file=out) print('\t\t\t\t\tkv.getKey(),', file=out) print('\t\t\t\t\t(', file=out) print('\t\t\t\t\t\t' + ',\n\t\t\t\t\t\t'.join('result.getAll(tupleTag%s).asScala' % x for x in types), file=out) print('\t\t\t\t\t)', file=out) print('\t\t\t\t)', file=out) print('\t\t\t}', file=out) print('\t}', file=out) print(file=out) print('\tdef sortMergeCoGroup[%s](%s): %s = {' % ( mkClassTags(n), fnArgs, mkCogroupFnRetVal(n, 'Iterable', 'Iterable')), file=out) print('\t\tsortMergeCoGroup(keyClass, %s, TargetParallelism.auto())' % args, file=out) print('\t}', file=out) print(file=out) def sortMergeTransform(out, n): types = mkTypes(n) vals = mkVals(n) args = ', '.join(vals) fnArgs = mkFnArgs(n) print('\tdef sortMergeTransform[%s](%s): %s = self.requireNotClosed {' % ( mkClassTags(n), fnArgs + ', ' + 'targetParallelism: TargetParallelism', mkTransformFnRetVal(n, 'Iterable', 'Iterable')), file=out) print('\t\t' + '\n\t\t'.join(mkTupleTag(n)), file=out) print('\t\tval fromResult = { (result: CoGbkResult) =>', file=out) print('\t\t\t(', file=out) print('\t\t\t\t' + ',\n\t\t\t\t'.join('result.getAll(tupleTag%s).asScala' % x for x in types), file=out) print('\t\t\t)', file=out) print('\t\t}', file=out) print('\t\tif (self.isTest) {', file=out) print('\t\t\tval result = testCoGroup[KEY](%s)' % args, file=out) print('\t\t\tval keyed = result.map(kv => kv.getKey -> fromResult(kv.getValue))', file=out) print('\t\t\tnew SortMergeTransform.ReadBuilderTest(self, keyed)', file=out) print('\t\t} else {', file=out) print('\t\t\tval transform = SortedBucketIO', file=out) print('\t\t\t\t.read(keyClass)', file=out) print('\t\t\t\t.of(%s)' % args, file=out) print('\t\t\t\t.withTargetParallelism(targetParallelism)', file=out) print('\t\t\tnew SortMergeTransform.ReadBuilderImpl(self, transform, fromResult)', file=out) print('\t\t}', file=out) print('\t}', file=out) print(file=out) print('\tdef sortMergeTransform[%s](%s): %s = {' % ( mkClassTags(n), fnArgs, mkTransformFnRetVal(n, 'Iterable', 'Iterable')), file=out) print('\t\tsortMergeTransform(keyClass, %s, TargetParallelism.auto())' % args, file=out) print('\t}', file=out) print(file=out) def testCoGroup(out): print(''' \tprivate[smb] def testCoGroup[K]( \t\treads: SortedBucketIO.Read[_]* \t): SCollection[KV[K, CoGbkResult]] = { \t\tval testInput = TestDataManager.getInput(self.testId.get) \t\tval read :: rs = reads.asInstanceOf[Seq[SortedBucketIO.Read[Any]]].toList: @nowarn \t\tval test = testInput[(K, Any)](SortedBucketIOUtil.testId(read)).toSCollection(self) \t\tval keyed = rs \t\t\t.foldLeft(KeyedPCollectionTuple.of(read.getTupleTag, test.toKV.internal)) { (kpt, r) => \t\t\t\tval c = testInput[(K, Any)](SortedBucketIOUtil.testId(r)).toSCollection(self) \t\t\t\tkpt.and(r.getTupleTag, c.toKV.internal) \t\t\t} \t\t\t.apply(CoGroupByKey.create()) \t\tself.wrap(keyed) \t}''', file=out) def main(out): print(''' /* * Copyright 2021 Spotify AB. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ // generated with smb-multijoin.py package com.spotify.scio.smb.util import com.spotify.scio.ScioContext import com.spotify.scio.coders.Coder import com.spotify.scio.testing.TestDataManager import com.spotify.scio.values._ import org.apache.beam.sdk.extensions.smb.{SortedBucketIO, SortedBucketIOUtil, TargetParallelism} import org.apache.beam.sdk.transforms.join.{CoGbkResult, CoGroupByKey, KeyedPCollectionTuple} import org.apache.beam.sdk.values.KV import com.spotify.scio.smb.SortMergeTransform import org.typelevel.scalaccompat.annotation.nowarn import scala.jdk.CollectionConverters._ final class SMBMultiJoin(private val self: ScioContext) {'''.lstrip('\n'), file=out) N = 22 for i in range(2, N + 1): sortMergeCoGroup(out, i) for i in range(2, N + 1): sortMergeTransform(out, i) testCoGroup(out) print('}', file=out) print(''' object SMBMultiJoin { \tfinal def apply(sc: ScioContext): SMBMultiJoin = new SMBMultiJoin(sc) }''', file=out) if __name__ == '__main__': main(sys.stdout)