in scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/TreeOrderedBuf.scala [63:312]
/**
 * Assembles the expression for an anonymous `MacroEqualityOrderedSerialization[T]` whose members are
 * filled in from the trees supplied by `t` (`length`, `put`, `get`, `compareBinary`, `hash`, `compare`
 * and `lazyOuterVariables`).
 *
 * Layout written by the generated `write`, as derived from the code below:
 *   - when `t.length` is a `ConstantLengthCalculation`, only the payload is written and the readers use
 *     the constant instead of reading a prefix;
 *   - otherwise a positive var-int length prefix is written, followed by the payload. If the length
 *     cannot be computed up front, the payload is first serialized into a side buffer to measure it.
 *
 * The generated instance keeps counters of length-calculation attempts and failures; once more than 50
 * attempts have been made and over 40% of them failed, it stops attempting the calculation
 * (`skipLenCalc`).
 *
 * @param c macro context, used here to generate fresh term names
 * @param t the buffer description providing the code fragments to splice into the generated class
 * @tparam T the type being serialized; its type string is also used as the generated `uniqueId`
 * @return an expression that evaluates to the generated `OrderedSerialization[T]`
 */
def toOrderedSerialization[T](
    c: Context
)(t: TreeOrderedBuf[c.type])(implicit T: t.ctx.WeakTypeTag[T]): t.ctx.Expr[OrderedSerialization[T]] = {
  import t.ctx.universe._
  def freshT(id: String) = TermName(c.freshName(s"fresh_$id"))

  // Definition of the private `payloadLength` method of the generated class. It only exists when the
  // length has to be computed per element (Fast/Maybe); otherwise a unit tree is spliced in its place.
  val innerLengthFn: Tree = {
    val element = freshT("element")

    val fnBodyOpt = t.length(q"$element") match {
      case _: NoLengthCalculationAvailable[_] => None
      case _: ConstantLengthCalculation[_]    => None
      case f: FastLengthCalculation[_]        =>
        // A fast calculation always succeeds, so it is wrapped as a DynamicLen result.
        Some(q"""
        _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(${f
          .asInstanceOf[FastLengthCalculation[c.type]]
          .t})
        """)
      case m: MaybeLengthCalculation[_] => Some(m.asInstanceOf[MaybeLengthCalculation[c.type]].t)
    }

    fnBodyOpt
      .map { fnBody =>
        q"""
        private[this] def payloadLength($element: $T): _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MaybeLength = {
          lengthCalculationAttempts += 1
          $fnBody
        }
        """
      }
      .getOrElse(q"()")
  }

  // Produces the pair (staticSize definition, dynamicSize definition) for the generated class.
  def binaryLengthGen(typeName: Tree): (Tree, Tree) = {
    val tempLen = freshT("tempLen")
    val lensLen = freshT("lensLen")
    val element = freshT("element")

    // Used when the size depends on the element: delegate to payloadLength and, when it yields a
    // length, add the size of the var-int prefix that precedes the payload.
    val callDynamic = (
      q"""override def staticSize: _root_.scala.Option[_root_.scala.Int] = _root_.scala.None""",
      q"""
      override def dynamicSize($element: $typeName): _root_.scala.Option[_root_.scala.Int] = {
        if(skipLenCalc) _root_.scala.None else {
          val $tempLen = payloadLength($element) match {
            case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation =>
              failedLengthCalc()
              _root_.scala.None
            case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(l) => _root_.scala.Some(l)
            case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(l) => _root_.scala.Some(l)
          }
          (if ($tempLen.isDefined) {
            // Avoid a closure here while we are geeking out
            val innerLen = $tempLen.get
            val $lensLen = posVarIntSize(innerLen)
            _root_.scala.Some(innerLen + $lensLen)
          } else _root_.scala.None): _root_.scala.Option[_root_.scala.Int]
        }
      }
      """
    )

    t.length(q"$element") match {
      case _: NoLengthCalculationAvailable[_] =>
        (
          q"""
          override def staticSize: _root_.scala.Option[_root_.scala.Int] = _root_.scala.None""",
          q"""
          override def dynamicSize($element: $typeName): _root_.scala.Option[_root_.scala.Int] = _root_.scala.None"""
        )
      case const: ConstantLengthCalculation[_] =>
        (
          q"""
          override val staticSize: _root_.scala.Option[_root_.scala.Int] = _root_.scala.Some(${const.toInt})""",
          q"""
          override def dynamicSize($element: $typeName): _root_.scala.Option[_root_.scala.Int] = staticSize"""
        )
      case _: FastLengthCalculation[_]  => callDynamic
      case _: MaybeLengthCalculation[_] => callDynamic
    }
  }

  // Definition of the private `noLengthWrite` method of the generated class.
  def genNoLenCalc = {
    val baos = freshT("baos")
    val element = freshT("element")
    val outerOutputStream = freshT("os")
    val len = freshT("len")

    // This is the worst case: we have to serialize in a side buffer and then see how large it actually
    // is. This happens for cases, like string, where the cost to see the serialized size is not cheaper
    // than directly serializing.
    q"""
    private[this] def noLengthWrite($element: $T, $outerOutputStream: _root_.java.io.OutputStream): Unit = {
      // Start with pretty big buffers because reallocation will be expensive
      val $baos = new _root_.java.io.ByteArrayOutputStream(512)
      ${t.put(baos, element)}
      val $len = $baos.size
      $outerOutputStream.writePosVarInt($len)
      $baos.writeTo($outerOutputStream)
    }
    """
  }

  // Body of the generated `write`: serializes `element` into `outerbaos`, choosing the strategy from
  // the kind of length calculation that `t` offers.
  def putFnGen(outerbaos: TermName, element: TermName) = {
    val len = freshT("len")

    // This is the case where the length is cheap to compute, either constant or easily computable from
    // an instance: write the length prefix, then the payload straight into the output.
    def withLenCalc(lenC: Tree) = q"""
      val $len = $lenC
      $outerbaos.writePosVarInt($len)
      ${t.put(outerbaos, element)}
    """

    t.length(q"$element") match {
      case _: ConstantLengthCalculation[_] =>
        // Constant size: no prefix is written, the readers substitute the constant.
        q"""${t.put(outerbaos, element)}"""
      case f: FastLengthCalculation[_] =>
        withLenCalc(f.asInstanceOf[FastLengthCalculation[c.type]].t)
      case _: MaybeLengthCalculation[_] =>
        val tmpLenRes = freshT("tmpLenRes")
        q"""
        if(skipLenCalc) {
          noLengthWrite($element, $outerbaos)
        } else {
          val $tmpLenRes: _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MaybeLength = payloadLength($element)
          $tmpLenRes match {
            case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation =>
              failedLengthCalc()
              noLengthWrite($element, $outerbaos)
            case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(const) =>
              ${withLenCalc(q"const")}
            case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(s) =>
              ${withLenCalc(q"s")}
          }
        }
        """
      case _ => q"noLengthWrite($element, $outerbaos)"
    }
  }

  // Expression yielding the payload length while comparing: the constant when the size is constant,
  // otherwise the var-int prefix read from the stream. `q"e"` is only a placeholder; just the kind of
  // the returned length calculation is inspected.
  def readLength(inputStream: TermName) =
    t.length(q"e") match {
      case const: ConstantLengthCalculation[_] => q"${const.toInt}"
      case _                                   => q"$inputStream.readPosVarInt"
    }

  // Expression that consumes the length prefix (if one was written) before deserializing.
  def discardLength(inputStream: TermName) =
    t.length(q"e") match {
      case _: ConstantLengthCalculation[_] => q"()"
      case _                               => q"$inputStream.readPosVarInt"
    }

  val lazyVariables = t.lazyOuterVariables.map { case (n, t) =>
    val termName = TermName(n)
    q"""lazy val $termName = $t"""
  }

  val inputStreamA = freshT("inputStreamA")
  val inputStreamB = freshT("inputStreamB")
  val posStreamA = freshT("posStreamA")
  val posStreamB = freshT("posStreamB")
  val lenA = freshT("lenA")
  val lenB = freshT("lenB")

  // Generate the size members once and splice both halves, instead of generating the pair twice and
  // discarding half of each result.
  val (staticSizeDef, dynamicSizeDef) = binaryLengthGen(q"$T")

  t.ctx.Expr[OrderedSerialization[T]](q"""
    new _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MacroEqualityOrderedSerialization[$T] {
      override val uniqueId: _root_.java.lang.String = ${T.tpe.toString}

      private[this] var lengthCalculationAttempts: _root_.scala.Long = 0L
      private[this] var couldNotLenCalc: _root_.scala.Long = 0L
      private[this] var skipLenCalc:_root_.scala.Boolean = false

      import _root_.com.twitter.scalding.serialization.JavaStreamEnrichments._
      ..$lazyVariables

      override def compareBinary($inputStreamA: _root_.java.io.InputStream, $inputStreamB: _root_.java.io.InputStream): _root_.com.twitter.scalding.serialization.OrderedSerialization.Result =
        try _root_.com.twitter.scalding.serialization.OrderedSerialization.resultFrom {
          val $lenA = ${readLength(inputStreamA)}
          val $lenB = ${readLength(inputStreamB)}
          val $posStreamA = _root_.com.twitter.scalding.serialization.PositionInputStream($inputStreamA)
          val initialPositionA = $posStreamA.position
          val $posStreamB = _root_.com.twitter.scalding.serialization.PositionInputStream($inputStreamB)
          val initialPositionB = $posStreamB.position

          val innerR = ${t.compareBinary(posStreamA, posStreamB)}

          // Move both streams to (start of payload + payload length) after the inner comparison.
          $posStreamA.seekToPosition(initialPositionA + $lenA)
          $posStreamB.seekToPosition(initialPositionB + $lenB)
          innerR
        } catch {
          case _root_.scala.util.control.NonFatal(e) =>
            _root_.com.twitter.scalding.serialization.OrderedSerialization.CompareFailure(e)
        }

      override def hash(passedInObjectToHash: $T): _root_.scala.Int = {
        ${t.hash(TermName("passedInObjectToHash"))}
      }

      private[this] def failedLengthCalc(): _root_.scala.Unit = {
        couldNotLenCalc += 1L
        if(lengthCalculationAttempts > 50 && (couldNotLenCalc.toDouble / lengthCalculationAttempts) > 0.4f) {
          skipLenCalc = true
        }
      }

      // What to do if we don't have a length calculation
      $genNoLenCalc

      // defines payloadLength private method
      $innerLengthFn

      // static size:
      $staticSizeDef

      // dynamic size:
      $dynamicSizeDef

      override def read(from: _root_.java.io.InputStream): _root_.scala.util.Try[$T] = {
        try {
          ${discardLength(TermName("from"))}
          _root_.scala.util.Success(${t.get(TermName("from"))})
        } catch { case _root_.scala.util.control.NonFatal(e) =>
          _root_.scala.util.Failure(e)
        }
      }

      override def write(into: _root_.java.io.OutputStream, e: $T): _root_.scala.util.Try[Unit] = {
        try {
          ${putFnGen(TermName("into"), TermName("e"))}
          _root_.com.twitter.scalding.serialization.Serialization.successUnit
        } catch { case _root_.scala.util.control.NonFatal(e) =>
          _root_.scala.util.Failure(e)
        }
      }

      override def compare(x: $T, y: $T): _root_.scala.Int = {
        ${t.compare(TermName("x"), TermName("y"))}
      }
    }
  """)
}