in scalding-spark/src/main/scala/com/twitter/scalding/spark_backend/SparkBackend.scala [67:238]
def plan(
config: Config,
srcs: Resolver[Input, SparkSource],
counter: SparkCountersInternal
): FunctionK[TypedPipe, Op] =
Memoize.functionK(new Memoize.RecursiveK[TypedPipe, Op] {
import TypedPipe._
def toFunction[A] = {
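// Each case receives the TypedPipe node together with `rec`, the memoized
// recursive planner (Memoize.functionK), so shared subgraphs are planned once.
// The inner `go` helpers re-bind the type parameters that the wildcard
// patterns (e.g. `cp @ CounterPipe(_)`) erase, so the recursive calls
// type-check against the concrete node types.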
case (cp @ CounterPipe(_), rec) =>
def go[A](p: CounterPipe[A]): Op[A] = {
// spark only guarantees accurate accumulators inside of RDD actions,
// so instead of incrementing inside a plain map, we pass every element
// through forEachIdentity to bump the counters as a side effect and
// then map to strip the counter payload back off
val resolve = rec(p.pipe)
resolve
.forEachIdentity { x =>
// increment counter for values in pipe
counter.add(x._2.iterator)
}
.map { x =>
// pass through values in pipe
x._1
}
}
go(cp)
case (cp @ CrossPipe(_, _), rec) =>
def go[A, B](cp: CrossPipe[A, B]): Op[(A, B)] =
rec(cp.viaHashJoin)
go(cp)
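// CrossValue: crossing with an EmptyValue yields the empty pipe, a
// LiteralValue is simply paired onto every element with a map, and a
// ComputedValue is lowered to a CrossPipe (and hence a hash join)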
case (CrossValue(left, EmptyValue), rec) => rec(EmptyTypedPipe)
case (CrossValue(left, LiteralValue(v)), rec) =>
val op = rec(left) // linter:disable:UndesirableTypeInference
op.map((_, v))
case (CrossValue(left, ComputedValue(right)), rec) =>
rec(CrossPipe(left, right))
case (p: DebugPipe[a], rec) =>
// There is little we can do here other than println
rec[a](p.input.map(DebugFn()))
case (EmptyTypedPipe, rec) =>
Op.Empty
case (fk @ FilterKeys(_, _), rec) =>
def go[K, V](node: FilterKeys[K, V]): Op[(K, V)] = {
val FilterKeys(pipe, fn) = node
rec(pipe).filter(FilterKeysToFilter(fn))
}
go(fk)
case (f @ Filter(_, _), rec) =>
def go[T](f: Filter[T]): Op[T] = {
val Filter(p, fn) = f
rec[T](p).filter(fn)
}
go(f)
case (f @ FlatMapValues(_, _), rec) =>
def go[K, V, U](node: FlatMapValues[K, V, U]) = {
val fn = node.fn
rec(node.input).flatMapValues(fn)
}
go(f)
case (FlatMapped(prev, fn), rec) =>
val op = rec(prev) // linter:disable:UndesirableTypeInference
op.concatMap(fn)
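// ForceToDisk becomes a Spark persist; the storage level is configurable,
// defaulting to DISK_ONLY, and NONE turns the persist into a no-op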
case (ForceToDisk(pipe), rec) =>
val sparkPipe = rec(pipe)
config.getForceToDiskPersistMode.getOrElse(StorageLevel.DISK_ONLY) match {
case StorageLevel.NONE => sparkPipe
case notNone => sparkPipe.persist(notNone)
}
case (Fork(pipe), rec) =>
val sparkPipe = rec(pipe)
// just let spark do its default thing on Forks.
// unfortunately, that may mean recomputing the upstream
// multiple times, so users may want to override this,
// or be careful about using forceToDisk
config.getForkPersistMode.getOrElse(StorageLevel.NONE) match {
case StorageLevel.NONE => sparkPipe
case notNone => sparkPipe.persist(notNone)
}
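// in-memory data does not need a SparkSource; it is planned directly
// from the Iterable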
case (IterablePipe(iterable), _) =>
Op.FromIterable(iterable)
case (f @ MapValues(_, _), rec) =>
def go[K, V, U](node: MapValues[K, V, U]): Op[(K, U)] =
rec(node.input).mapValues(node.fn)
go(f)
case (Mapped(input, fn), rec) =>
val op = rec(input) // linter:disable:UndesirableTypeInference
op.map(fn)
case (m @ MergedTypedPipe(_, _), rec) =>
// Spark can union several inputs at once, but it won't
// combine a chain of pairwise merges on its own, so we
// hand it the whole batch in one go
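// unrollMerge flattens a tree of binary merges into a single list of inputs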
OptimizationRules.unrollMerge(m) match {
case Nil => rec(EmptyTypedPipe)
case h :: Nil => rec(h)
case h :: rest =>
val pc = ConfigPartitionComputer(config, None)
Op.Merged(pc, rec(h), rest.map(rec(_)))
}
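// external inputs are resolved to a SparkSource via the supplied Resolver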
case (SourcePipe(src), _) =>
Op.Source(config, src, srcs(src))
case (slk @ SumByLocalKeys(_, _), rec) =>
def sum[K, V](sblk: SumByLocalKeys[K, V]): Op[(K, V)] = {
// we can use Algebird's SummingCache https://github.com/twitter/algebird/blob/develop/algebird-core/src/main/scala/com/twitter/algebird/SummingCache.scala#L36
// plus mapPartitions to implement this
val SumByLocalKeys(p, sg) = sblk
// TODO set a default in a better place
val defaultCapacity = 10000
val capacity = config.getMapSideAggregationThreshold.getOrElse(defaultCapacity)
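// CachingSum uses a bounded cache (see the SummingCache link above) to
// combine values for the same key within a partition using the Semigroup,
// i.e. map-side aggregation before any shuffle happens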
rec(p).mapPartitions(CachingSum(capacity, sg))
}
sum(slk)
case (tp: TrappedPipe[a], rec) =>
// traps can be interpreted as catching any exception
// in the map-phase up to the next partition boundary, so this
// could be made to work by changing Op to also return
// the values that fail on error; for now the trap is ignored
rec[a](tp.input)
case (wd: WithDescriptionTypedPipe[a], rec) =>
// TODO we could optionally print out the descriptions
// after the future completes
rec[a](wd.input)
case (woc: WithOnComplete[a], rec) =>
// TODO
rec[a](woc.input)
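// HashCoGroup: the right side is the HashJoinable (the side scalding
// assumes is small enough to hash join against); it is planned as an
// ordinary reduce step and then hash-joined against the left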
case (hcg @ HashCoGroup(_, _, _), rec) =>
def go[K, V1, V2, W](hcg: HashCoGroup[K, V1, V2, W]): Op[(K, W)] = {
val leftOp = rec(hcg.left)
val rightOp = rec(ReduceStepPipe(HashJoinable.toReduceStep(hcg.right)))
leftOp.hashJoin(rightOp)(hcg.joiner)
}
go(hcg)
case (CoGroupedPipe(cg), rec) =>
planCoGroup(config, cg, rec)
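// For the identity reduce steps there is no reduce function to run; the
// stored `evidence` witnesses that V1 and V2 are the same type, so `subst`
// only recasts the Op, and the sorted variant additionally sorts the
// values within each key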
case (ReduceStepPipe(ir @ IdentityReduce(_, _, _, descriptions, _)), rec) =>
def go[K, V1, V2](ir: IdentityReduce[K, V1, V2]): Op[(K, V2)] = {
type OpT[V] = Op[(K, V)]
val op = rec(ir.mapped)
ir.evidence.subst[OpT](op)
}
go(ir)
case (ReduceStepPipe(uir @ UnsortedIdentityReduce(_, _, _, descriptions, _)), rec) =>
def go[K, V1, V2](uir: UnsortedIdentityReduce[K, V1, V2]): Op[(K, V2)] = {
type OpT[V] = Op[(K, V)]
val op = rec(uir.mapped)
uir.evidence.subst[OpT](op)
}
go(uir)
case (ReduceStepPipe(ivsr @ IdentityValueSortedReduce(_, _, _, _, _, _)), rec) =>
def go[K, V1, V2](ivsr: IdentityValueSortedReduce[K, V1, V2]): Op[(K, V2)] = {
type OpT[V] = Op[(K, V)]
val op = rec(ivsr.mapped)
val pc = ConfigPartitionComputer(config, ivsr.reducers)
val sortedOp = op.sorted(pc)(ivsr.keyOrdering, ivsr.valueSort)
ivsr.evidence.subst[OpT](sortedOp)
}
go(ivsr)
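// The remaining reduce steps shuffle by key using the configured number
// of reducers and then apply the reduce function to each key's values,
// sorting the values first in the ValueSortedReduce case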
case (ReduceStepPipe(ValueSortedReduce(ordK, pipe, ordV, fn, red, _)), rec) =>
val op = rec(pipe)
val pc = ConfigPartitionComputer(config, red)
op.sortedMapGroup(pc)(fn)(ordK, ordV)
case (ReduceStepPipe(IteratorMappedReduce(ordK, pipe, fn, red, _)), rec) =>
val op = rec(pipe)
val pc = ConfigPartitionComputer(config, red)
op.mapGroup(pc)(fn)(ordK)
}
})