in scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/TraversablesOrderedBuf.scala [90:327]
/**
 * Builds the [[TreeOrderedBuf]] for a collection-like `outerType`.
 *
 * The generated code writes the element count with `writePosVarInt` followed by every element
 * (written with the element type's own buffer), and reads it back symmetrically.
 *
 * @param c               the macro context the trees are built in
 * @param buildDispatcher by-name dispatcher; it is evaluated once and applied to the element
 *                        type to obtain the buffer used for the elements
 * @param outerType       the collection type; it is cast to `TypeRefApi` and its first type
 *                        argument is used as element type. With exactly two type arguments
 *                        (e.g. a Map) the element type is the `Tuple2` of both arguments
 * @param maybeSort       `DoSort` sorts the elements before writing/comparing and hashes them
 *                        order-independently; `NoSort` keeps iteration order
 * @param maybeArray      `IsArray` rebuilds the value with `new Array`, `NotArray` through the
 *                        companion's `newBuilder`
 * @return the buffer describing how to put/get/compare/hash values of `outerType`
 */
def apply(c: Context)(
    buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]],
    outerType: c.Type,
    maybeSort: ShouldSort,
    maybeArray: MaybeArray
): TreeOrderedBuf[c.type] = {
  import c.universe._
  def freshT(id: String) = TermName(c.freshName(s"fresh_$id"))

  val dispatcher = buildDispatcher
  val companionSymbol = outerType.typeSymbol.companionSymbol

  // Type arguments of the collection, extracted once and reused below.
  val innerTypes = outerType.asInstanceOf[TypeRefApi].args

  // When dealing with a map we have 2 type args, and need to generate the tuple type
  // it would correspond to if we .toList the Map.
  val innerType = innerTypes match {
    case List(tpe1, tpe2) =>
      val containerType = typeOf[Tuple2[Any, Any]].asInstanceOf[TypeRef]
      import compat._
      TypeRef.apply(containerType.pre, containerType.sym, List(tpe1, tpe2))
    case _ =>
      // Fails with NoSuchElementException when outerType has no type arguments.
      innerTypes.head
  }

  val innerBuf: TreeOrderedBuf[c.type] = dispatcher(innerType)

  // TODO it would be nice to capture one instance of this rather
  // than allocate in every call in the materialized class
  val ioa = freshT("ioa")
  val iob = freshT("iob")
  // In-memory Ordering of the elements, implemented with the element buffer's compare tree.
  val innerOrd = q"""
    new _root_.scala.math.Ordering[${innerBuf.tpe}] {
      def compare(a: ${innerBuf.tpe}, b: ${innerBuf.tpe}) = {
        val $ioa: ${innerBuf.tpe} = a
        val $iob: ${innerBuf.tpe} = b
        ${innerBuf.compare(ioa, iob)}
      }
    }
  """

  new TreeOrderedBuf[c.type] {
    override val ctx: c.type = c
    override val tpe = outerType

    /**
     * Compares two serialized collections by handing both streams and an element-level
     * binary comparison function to `TraversableHelpers.rawCompare`.
     */
    override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = {
      val innerCompareFn = freshT("innerCompareFn")
      val a = freshT("a")
      val b = freshT("b")
      q"""
        val $innerCompareFn = { (a: _root_.java.io.InputStream, b: _root_.java.io.InputStream) =>
          val $a: _root_.java.io.InputStream = a
          val $b: _root_.java.io.InputStream = b
          ${innerBuf.compareBinary(a, b)}
        };
        _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.TraversableHelpers.rawCompare($inputStreamA, $inputStreamB)($innerCompareFn)
      """
    }

    /**
     * Writes the size as a positive var-int followed by every element. With `DoSort` the
     * elements are first copied to an array and sorted with the element ordering.
     */
    override def put(inputStream: ctx.TermName, element: ctx.TermName) = {
      val len = freshT("len")
      val innerElement = freshT("innerElement")
      maybeSort match {
        case DoSort =>
          val asArray = freshT("asArray")
          val pos = freshT("pos")
          q"""
            val $len = $element.size
            $inputStream.writePosVarInt($len)
            if($len > 0) {
              val $asArray = $element.toArray[${innerBuf.tpe}]
              // Sorting on the in-memory is the same as binary
              _root_.scala.util.Sorting.quickSort[${innerBuf.tpe}]($asArray)($innerOrd)
              var $pos = 0
              while($pos < $len) {
                val $innerElement = $asArray($pos)
                ${innerBuf.put(inputStream, innerElement)}
                $pos += 1
              }
            }
          """
        case NoSort =>
          q"""
            val $len: Int = $element.size
            $inputStream.writePosVarInt($len)
            $element.foreach { case $innerElement =>
              ${innerBuf.put(inputStream, innerElement)}
            }
          """
      }
    }

    /**
     * Hashes every element and folds the element hashes together, then finishes with
     * `fmix` over the accumulated hash and the number of elements seen.
     */
    override def hash(element: ctx.TermName): ctx.Tree = {
      val currentHash = freshT("currentHash")
      val len = freshT("len")
      val target = freshT("target")
      val elementHash = innerBuf.hash(target)
      // Only the way an element hash is folded into the accumulator depends on the sort mode.
      val combined = maybeSort match {
        case NoSort =>
          q"_root_.com.twitter.scalding.serialization.MurmurHashUtils.mixH1($currentHash, $elementHash)"
        case DoSort =>
          // We actually don't sort here, which would be expensive, but combine with a commutative operation
          // so the order that we see items won't matter. For this we use XOR
          q"$currentHash ^ $elementHash"
      }
      q"""
        var $currentHash: Int = _root_.com.twitter.scalding.serialization.MurmurHashUtils.seed
        var $len = 0
        $element.foreach { t =>
          val $target = t
          $currentHash = $combined
          // go ahead and compute the length so we don't traverse twice for lists
          $len += 1
        }
        // Might as well be fancy when we mix in the length
        _root_.com.twitter.scalding.serialization.MurmurHashUtils.fmix($currentHash, $len)
      """
    }

    /**
     * Reads the size, then rebuilds the collection: the companion's `empty` for size 0, the
     * companion's `apply` for a single element, and an array or builder loop otherwise.
     */
    override def get(inputStream: ctx.TermName): ctx.Tree = {
      val len = freshT("len")
      val firstVal = freshT("firstVal")
      val travBuilder = freshT("travBuilder")
      val iter = freshT("iter")
      // innerBuf.get is invoked separately for each place an element is read, so every
      // splice receives its own tree.
      val extractionTree = maybeArray match {
        case IsArray =>
          q"""val $travBuilder = new _root_.scala.Array[..$innerTypes]($len)
            var $iter = 0
            while($iter < $len) {
              $travBuilder($iter) = ${innerBuf.get(inputStream)}
              $iter = $iter + 1
            }
            $travBuilder : $outerType
            """
        case NotArray =>
          q"""val $travBuilder = $companionSymbol.newBuilder[..$innerTypes]
            $travBuilder.sizeHint($len)
            var $iter = 0
            while($iter < $len) {
              $travBuilder += ${innerBuf.get(inputStream)}
              $iter = $iter + 1
            }
            $travBuilder.result : $outerType
            """
      }
      q"""
        val $len: _root_.scala.Int = $inputStream.readPosVarInt
        if($len > 0) {
          if($len == 1) {
            val $firstVal: $innerType = ${innerBuf.get(inputStream)}
            $companionSymbol.apply($firstVal) : $outerType
          } else {
            $extractionTree : $outerType
          }
        } else {
          $companionSymbol.empty : $outerType
        }
      """
    }

    /**
     * Compares two in-memory collections with the element ordering, delegating to
     * `TraversableHelpers.sortedCompare` (`DoSort`) or `iteratorCompare` (`NoSort`).
     */
    override def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree =
      maybeSort match {
        case DoSort =>
          q"""
            _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.TraversableHelpers.sortedCompare[${innerBuf.tpe}]($elementA, $elementB)($innerOrd)
          """
        case NoSort =>
          q"""
            _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.TraversableHelpers.iteratorCompare[${innerBuf.tpe}]($elementA.iterator, $elementB.iterator)($innerOrd)
          """
      }

    // The collection adds no outer variables of its own; it forwards the element buffer's.
    override val lazyOuterVariables: Map[String, ctx.Tree] = innerBuf.lazyOuterVariables

    /**
     * Derives the serialized length from the length calculation of the element buffer,
     * which is asked for the length of `element.head`.
     *
     * NOTE(review): `posVarIntSize` is emitted unqualified, so it has to resolve at the
     * expansion site — confirm against the code that embeds these trees.
     */
    override def length(element: Tree): CompileTimeLengthTypes[c.type] =
      innerBuf.length(q"$element.head") match {
        // Constant element size: size prefix plus size * element size.
        case const: ConstantLengthCalculation[_] =>
          FastLengthCalculation(c)(q"""{
            posVarIntSize($element.size) + $element.size * ${const.toInt}
          }""")
        case m: MaybeLengthCalculation[_] =>
          // Fresh name for the runtime result so it cannot clash with names in spliced trees.
          val maybeRes = freshT("maybeRes")
          MaybeLengthCalculation(c)(q"""
            if($element.isEmpty) {
              val sizeOfZero = 1 // writing the constant 0, for length, takes 1 byte
              _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(sizeOfZero)
            } else {
              val $maybeRes = ${m.asInstanceOf[MaybeLengthCalculation[c.type]].t}
              $maybeRes match {
                case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(constSize) =>
                  val sizeOverhead = posVarIntSize($element.size)
                  _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(constSize * $element.size + sizeOverhead)
                // todo maybe we should support this case
                // where we can visit every member of the list relatively fast to ask
                // its length. Should we care about sizes instead maybe?
                case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(_) =>
                  _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation
                case _ => _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation
              }
            }
          """)
        // Something we can't workout the size of ahead of time
        case _ =>
          MaybeLengthCalculation(c)(q"""
            if($element.isEmpty) {
              val sizeOfZero = 1 // writing the constant 0, for length, takes 1 byte
              _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(sizeOfZero)
            } else {
              _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation
            }
          """)
      }
  }
}