scripts/multijoin.py (158 lines of code) (raw):

#!/usr/bin/env python3 # # Copyright 2016 Spotify AB. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import string import sys import textwrap # Utilities def mkVals(n): return list(string.ascii_uppercase[:n]) def mkArgs(n): return ', '.join(map(lambda x: x.lower(), mkVals(n))) def mkClassTags(n): return ', '.join(['KEY'] + mkVals(n)) def mkFnArgs(n): return ', '.join( x.lower() + ': SCollection[(KEY, %s)]' % x for x in mkVals(n)) def mkFnRetVal(n, aWrapper=None, otherWrapper=None): def wrap(wrapper, x): return x if wrapper is None else wrapper + '[%s]' % x vals = (wrap(aWrapper if x == 'A' else otherWrapper, x) for x in mkVals(n)) return 'SCollection[(KEY, (%s))]' % ', '.join(vals) def common(out, vals): print(' implicit val keyCoder = a.keyCoder', file=out) print(' implicit val (%s) = (%s)' % ( ', '.join('coder' + x for x in vals), ', '.join('%s.valueCoder' % x.lower() for x in vals)), file=out) print(' val (%s) = (%s)' % ( ', '.join('tag' + x for x in vals), ', '.join('new TupleTag[%s]()' % x for x in vals)), file=out) print(' val keyed = KeyedPCollectionTuple', file=out) print(' .of(tagA, a.toKV.internal)', file=out) for x in vals[1:]: print(' .and(tag%s, %s.toKV.internal)' % (x, x.lower()), file=out) print( ' .apply(s"CoGroupByKey@$tfName", CoGroupByKey.create())', file=out) # Functions def cogroup(out, n): vals = mkVals(n) print(' def cogroup[%s](%s): %s = {' % ( mkClassTags(n), mkFnArgs(n), mkFnRetVal(n, 'Iterable', 'Iterable')), file=out) common(out, vals) print(' a.context.wrap(keyed).withName(tfName).map { kv =>', file=out) print(' val (key, result) = (kv.getKey, kv.getValue)', file=out) print(' (key, (%s))' % ', '.join( 'result.getAll(tag%s).asScala' % x for x in vals), file=out) # NOQA print(' }', file=out) print(' }', file=out) print(file=out) def join(out, n): vals = mkVals(n) print(' def apply[%s](%s): %s = {' % ( mkClassTags(n), mkFnArgs(n), mkFnRetVal(n)), file=out) common(out, vals) print( ' a.context.wrap(keyed).withName(tfName).flatMap { kv =>', file=out) print(' val (key, result) = (kv.getKey, kv.getValue)', file=out) print(' for {', file=out) for x in reversed(vals): print(' %s <- result.getAll(tag%s).asScala.iterator' % ( x.lower(), x), file=out) print(' } yield (key, (%s))' % mkArgs(n), file=out) print(' }', file=out) print(' }', file=out) print(file=out) def left(out, n): vals = mkVals(n) print(' def left[%s](%s): %s = {' % ( mkClassTags(n), mkFnArgs(n), mkFnRetVal(n, None, 'Option')), file=out) common(out, vals) print( ' a.context.wrap(keyed).withName(tfName).flatMap { kv =>', file=out) print(' val (key, result) = (kv.getKey, kv.getValue)', file=out) print(' for {', file=out) for (i, x) in enumerate(reversed(vals)): if (i == n - 1): print(' %s <- result.getAll(tag%s).asScala.iterator' % ( x.lower(), x), file=out) else: print(' %s <- toOptions(result.getAll(tag%s).asScala.iterator)' % ( # NOQA x.lower(), x), file=out) print(' } yield (key, (%s))' % mkArgs(n), file=out) print(' }', file=out) print(' }', file=out) print(file=out) def outer(out, n): vals = mkVals(n) print(' def outer[%s](%s): %s = {' % ( mkClassTags(n), mkFnArgs(n), mkFnRetVal(n, 'Option', 'Option')), file=out) common(out, vals) print( ' a.context.wrap(keyed).withName(tfName).flatMap { kv =>', file=out) print(' val (key, result) = (kv.getKey, kv.getValue)', file=out) print(' for {', file=out) for (i, x) in enumerate(reversed(vals)): print( ' %s <- toOptions(result.getAll(tag%s).asScala.iterator)' % (x.lower(), x), # NOQA file=out) print(' } yield (key, (%s))' % mkArgs(n), file=out) print(' }', file=out) print(' }', file=out) print(file=out) def main(out): print(textwrap.dedent(''' /* * Copyright 2019 Spotify AB. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ // generated with multijoin.py package com.spotify.scio.util import com.spotify.scio.values.SCollection import org.apache.beam.sdk.transforms.join.{CoGroupByKey, KeyedPCollectionTuple} # NOQA import org.apache.beam.sdk.values.TupleTag import scala.jdk.CollectionConverters._ trait MultiJoin extends Serializable { protected def tfName: String = CallSites.getCurrent def toOptions[T](xs: Iterator[T]): Iterator[Option[T]] = if (xs.isEmpty) Iterator(None) else xs.map(Option(_)) ''').replace(' # NOQA', '').lstrip('\n'), file=out) N = 22 for i in range(2, N + 1): cogroup(out, i) for i in range(2, N + 1): join(out, i) for i in range(2, N + 1): left(out, i) for i in range(2, N + 1): outer(out, i) print('}', file=out) print(textwrap.dedent(''' object MultiJoin extends MultiJoin { def withName(name: String): MultiJoin = new NamedMultiJoin(name) } private class NamedMultiJoin(val name: String) extends MultiJoin { override def tfName: String = name } ''').rstrip('\n'), file=out) if __name__ == '__main__': main(sys.stdout)