scripts/multijoin.py (158 lines of code) (raw):
#!/usr/bin/env python3
#
# Copyright 2016 Spotify AB.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import string
import sys
import textwrap
# Utilities
def mkVals(n):
    """Return the first ``n`` upper-case letters as a list, e.g. ``['A', 'B']``.

    The letters name the Scala value type parameters of a generated method.
    Only 26 letters exist, so any ``n`` above 26 yields all of them.
    """
    letters = list(string.ascii_uppercase)
    return letters[:n]
def mkArgs(n):
    """Return the comma-separated lower-case value names, e.g. ``'a, b, c'``.

    Used to build the Scala ``yield (key, (a, b, ...))`` tuple.
    """
    return ', '.join(letter.lower() for letter in mkVals(n))
def mkClassTags(n):
    """Return the Scala type-parameter list ``'KEY, A, B, ...'`` for ``n`` inputs."""
    type_params = ['KEY', *mkVals(n)]
    return ', '.join(type_params)
def mkFnArgs(n):
    """Return the Scala parameter list for ``n`` keyed inputs.

    For example, ``mkFnArgs(2)`` gives
    ``'a: SCollection[(KEY, A)], b: SCollection[(KEY, B)]'``.
    """
    params = [
        letter.lower() + ': SCollection[(KEY, %s)]' % letter
        for letter in mkVals(n)
    ]
    return ', '.join(params)
def mkFnRetVal(n, aWrapper=None, otherWrapper=None):
    """Return the Scala result type of an ``n``-way multi-join method.

    ``aWrapper`` wraps the first value type (``A``), and ``otherWrapper``
    wraps every other value type. A wrapper of ``None`` leaves the type bare,
    so ``mkFnRetVal(2, None, 'Option')`` gives
    ``'SCollection[(KEY, (A, Option[B]))]'``.
    """
    wrapped = []
    for letter in mkVals(n):
        wrapper = aWrapper if letter == 'A' else otherWrapper
        if wrapper is None:
            wrapped.append(letter)
        else:
            wrapped.append(wrapper + '[%s]' % letter)
    return 'SCollection[(KEY, (%s))]' % ', '.join(wrapped)
def common(out, vals):
    """Write the Scala prologue shared by every generated multi-join method.

    The prologue consists of:
    - the implicit key and value coders;
    - one ``TupleTag`` per input;
    - the ``KeyedPCollectionTuple ... CoGroupByKey`` pipeline bound to
      ``keyed``.

    Args:
        out: text stream the Scala lines are printed to.
        vals: upper-case value type names, e.g. ``['A', 'B']``. Each is
            paired with the lower-case Scala parameter of the same letter.

    Raises:
        ValueError: if ``vals`` is empty, because there would be no first
            input to take the key coder from.
    """
    if not vals:
        raise ValueError('common() needs at least one value type')
    # The first input supplies the key coder and seeds the tuple. Derive it
    # from vals rather than hard-coding 'A', so the prologue always refers
    # to parameters the caller actually declared.
    first, rest = vals[0], vals[1:]
    print(' implicit val keyCoder = %s.keyCoder' % first.lower(), file=out)
    print(' implicit val (%s) = (%s)' % (
        ', '.join('coder' + x for x in vals),
        ', '.join('%s.valueCoder' % x.lower() for x in vals)),
        file=out)
    print(' val (%s) = (%s)' % (
        ', '.join('tag' + x for x in vals),
        ', '.join('new TupleTag[%s]()' % x for x in vals)),
        file=out)
    print(' val keyed = KeyedPCollectionTuple', file=out)
    print(' .of(tag%s, %s.toKV.internal)' % (first, first.lower()), file=out)
    for x in rest:
        print(' .and(tag%s, %s.toKV.internal)' % (x, x.lower()), file=out)
    print(
        ' .apply(s"CoGroupByKey@$tfName", CoGroupByKey.create())',
        file=out)
# Functions
def cogroup(out, n):
    """Write a Scala ``cogroup`` overload that takes ``n`` keyed SCollections.

    The generated method's result type wraps every value type in
    ``Iterable``. For each key it emits a tuple of
    ``result.getAll(tagX).asScala``, one entry per input.
    """
    vals = mkVals(n)
    header = ' def cogroup[%s](%s): %s = {' % (
        mkClassTags(n), mkFnArgs(n), mkFnRetVal(n, 'Iterable', 'Iterable'))
    print(header, file=out)
    common(out, vals)
    grouped = ', '.join('result.getAll(tag%s).asScala' % v for v in vals)
    body = (
        ' a.context.wrap(keyed).withName(tfName).map { kv =>',
        ' val (key, result) = (kv.getKey, kv.getValue)',
        ' (key, (%s))' % grouped,
        ' }',
        ' }',
        '',  # blank separator line between generated methods
    )
    for line in body:
        print(line, file=out)
def join(out, n):
    """Write a Scala ``apply`` (inner join) overload for ``n`` inputs.

    The generated body is a ``for`` comprehension over every input's values
    for a key. It therefore yields the cross product of those values, which
    is empty when any input has no values for the key.
    """
    vals = mkVals(n)
    print(' def apply[%s](%s): %s = {' % (
        mkClassTags(n), mkFnArgs(n), mkFnRetVal(n)), file=out)
    common(out, vals)
    lines = [
        ' a.context.wrap(keyed).withName(tfName).flatMap { kv =>',
        ' val (key, result) = (kv.getKey, kv.getValue)',
        ' for {',
    ]
    # Generators are emitted last input first, so 'a' ends up innermost.
    lines.extend(
        ' %s <- result.getAll(tag%s).asScala.iterator' % (v.lower(), v)
        for v in vals[::-1])
    lines.append(' } yield (key, (%s))' % mkArgs(n))
    lines.extend([' }', ' }', ''])  # '' -> blank separator line
    for line in lines:
        print(line, file=out)
def left(out, n):
    """Write a Scala ``left`` join overload for ``n`` inputs.

    The generator for the first input (``a``) reads its values directly.
    Every other input goes through ``toOptions``, and its type in the result
    is wrapped in ``Option``.
    """
    vals = mkVals(n)
    print(' def left[%s](%s): %s = {' % (
        mkClassTags(n), mkFnArgs(n), mkFnRetVal(n, None, 'Option')),
        file=out)
    common(out, vals)
    required = ' %s <- result.getAll(tag%s).asScala.iterator'
    optional = ' %s <- toOptions(result.getAll(tag%s).asScala.iterator)'  # NOQA
    lines = [
        ' a.context.wrap(keyed).withName(tfName).flatMap { kv =>',
        ' val (key, result) = (kv.getKey, kv.getValue)',
        ' for {',
    ]
    # Inputs are emitted in reverse order, so the generator at position
    # n - 1 (the last one emitted) is input 'A', the required side.
    for position, v in enumerate(reversed(vals)):
        template = required if position == n - 1 else optional
        lines.append(template % (v.lower(), v))
    lines.append(' } yield (key, (%s))' % mkArgs(n))
    lines.extend([' }', ' }', ''])  # '' -> blank separator line
    for line in lines:
        print(line, file=out)
def outer(out, n):
    """Write a Scala ``outer`` (full outer) join overload for ``n`` inputs.

    Every input goes through ``toOptions``, and every value type in the
    result is wrapped in ``Option``. No input is required for a row to be
    produced.

    Args:
        out: text stream the Scala lines are printed to.
        n: number of joined inputs.
    """
    vals = mkVals(n)
    print(' def outer[%s](%s): %s = {' % (
        mkClassTags(n), mkFnArgs(n), mkFnRetVal(n, 'Option', 'Option')),
        file=out)
    common(out, vals)
    print(
        ' a.context.wrap(keyed).withName(tfName).flatMap { kv =>',
        file=out)
    print(' val (key, result) = (kv.getKey, kv.getValue)', file=out)
    print(' for {', file=out)
    # Generators are emitted last input first, so 'a' ends up innermost.
    for x in reversed(vals):
        print(
            ' %s <- toOptions(result.getAll(tag%s).asScala.iterator)' % (x.lower(), x),  # NOQA
            file=out)
    print(' } yield (key, (%s))' % mkArgs(n), file=out)
    print(' }', file=out)
    print(' }', file=out)
    print(file=out)
def main(out):
    """Write the complete generated Scala ``MultiJoin`` source to ``out``.

    Output order:
    1. A license header and the opening of ``trait MultiJoin``, including
       its ``tfName`` and ``toOptions`` members.
    2. Every ``cogroup``, ``apply`` (inner), ``left`` and ``outer`` overload
       for arities 2 through 22, one family at a time.
    3. The trait's closing brace, the ``MultiJoin`` companion object and the
       private ``NamedMultiJoin`` subclass.

    Args:
        out: text stream the Scala source is printed to, e.g. ``sys.stdout``.
    """
    preamble = textwrap.dedent('''
/*
* Copyright 2019 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
// generated with multijoin.py
package com.spotify.scio.util
import com.spotify.scio.values.SCollection
import org.apache.beam.sdk.transforms.join.{CoGroupByKey, KeyedPCollectionTuple} # NOQA
import org.apache.beam.sdk.values.TupleTag
import scala.jdk.CollectionConverters._
trait MultiJoin extends Serializable {
protected def tfName: String = CallSites.getCurrent
def toOptions[T](xs: Iterator[T]): Iterator[Option[T]] = if (xs.isEmpty) Iterator(None) else xs.map(Option(_))
''')
    # ' # NOQA' only silences the Python linter on the long import line;
    # strip it so it never reaches the Scala output.
    print(preamble.replace(' # NOQA', '').lstrip('\n'), file=out)

    # NOTE(review): 22 presumably mirrors Scala 2's 22-element tuple limit.
    highest_arity = 22
    for emit in (cogroup, join, left, outer):
        for arity in range(2, highest_arity + 1):
            emit(out, arity)
    print('}', file=out)

    companion = textwrap.dedent('''
object MultiJoin extends MultiJoin {
def withName(name: String): MultiJoin = new NamedMultiJoin(name)
}
private class NamedMultiJoin(val name: String) extends MultiJoin {
override def tfName: String = name
}
''')
    print(companion.rstrip('\n'), file=out)
if __name__ == '__main__':
    # Script entry point: print the generated Scala source to stdout.
    main(out=sys.stdout)