def toOrderedSerialization[T]()

in scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/TreeOrderedBuf.scala [63:312]


  /**
   * Assembles the final `OrderedSerialization[T]` instance from the code fragments produced by `t`.
   *
   * The generated `MacroEqualityOrderedSerialization[T]` frames each written value with a positive var-int
   * length prefix, except when `t` reports a `ConstantLengthCalculation`, in which case the length is implied
   * and no prefix is written. When no cheap length calculation exists, the value is first serialized into a
   * side `ByteArrayOutputStream` so its size can be measured (`noLengthWrite`). For `MaybeLengthCalculation`
   * types the instance counts failed length calculations and, once more than 40% of more than 50 attempts
   * have failed, sets `skipLenCalc` and always falls back to the side buffer (and reports no dynamic size).
   *
   * `compareBinary` reads both length prefixes, delegates to `t.compareBinary`, then seeks both streams past
   * the framed values; any non-fatal exception is returned as `CompareFailure`.
   *
   * @param c the macro context
   * @param t the ordered-buf tree describing how to compare, hash, read, write and size `T`
   * @tparam T the type being serialized
   * @return an expression constructing the ordered serialization for `T`
   */
  def toOrderedSerialization[T](
      c: Context
  )(t: TreeOrderedBuf[c.type])(implicit T: t.ctx.WeakTypeTag[T]): t.ctx.Expr[OrderedSerialization[T]] = {
    import t.ctx.universe._
    // All identifiers introduced into generated code go through here to stay hygienic w.r.t. spliced trees.
    def freshT(id: String) = TermName(c.freshName(s"fresh_$id"))

    // Defines the private `payloadLength` helper. Only fast and maybe-length types need it: the constant case
    // never calls it and the no-length case always goes through `noLengthWrite`.
    val innerLengthFn: Tree = {
      val element = freshT("element")

      val fnBodyOpt = t.length(q"$element") match {
        case _: NoLengthCalculationAvailable[_] => None
        case _: ConstantLengthCalculation[_]    => None
        case f: FastLengthCalculation[_] =>
          Some(q"""
        _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(${f
              .asInstanceOf[FastLengthCalculation[c.type]]
              .t})
        """)
        case m: MaybeLengthCalculation[_] => Some(m.asInstanceOf[MaybeLengthCalculation[c.type]].t)
      }

      fnBodyOpt
        .map { fnBody =>
          q"""
        private[this] def payloadLength($element: $T): _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MaybeLength = {
          lengthCalculationAttempts += 1
          $fnBody
        }
        """
        }
        .getOrElse(q"()")
    }

    /** Builds the `(staticSize, dynamicSize)` member definitions for the generated class. */
    def binaryLengthGen(typeName: Tree): (Tree, Tree) = {
      val tempLen = freshT("tempLen")
      val lensLen = freshT("lensLen")
      val innerLen = freshT("innerLen")
      val element = freshT("element")

      // Fast/maybe lengths: the serialized size is the payload length plus the size of its var-int prefix.
      // A `def` so these trees are only built when actually needed.
      def callDynamic: (Tree, Tree) = (
        q"""override def staticSize: _root_.scala.Option[_root_.scala.Int] = _root_.scala.None""",
        q"""

      override def dynamicSize($element: $typeName): _root_.scala.Option[_root_.scala.Int] = {
        if(skipLenCalc) _root_.scala.None else {
          val $tempLen = payloadLength($element) match {
            case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation =>
              failedLengthCalc()
              _root_.scala.None
            case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(l) => _root_.scala.Some(l)
            case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(l) => _root_.scala.Some(l)
          }
          (if ($tempLen.isDefined) {
            // Unpack the Option by hand to avoid allocating a closure on this hot path
            val $innerLen = $tempLen.get
            val $lensLen = posVarIntSize($innerLen)
            _root_.scala.Some($innerLen + $lensLen)
         } else _root_.scala.None): _root_.scala.Option[_root_.scala.Int]
      }
     }
      """
      )

      t.length(q"$element") match {
        case _: NoLengthCalculationAvailable[_] =>
          (
            q"""
          override def staticSize: _root_.scala.Option[_root_.scala.Int] = _root_.scala.None""",
            q"""
          override def dynamicSize($element: $typeName): _root_.scala.Option[_root_.scala.Int] = _root_.scala.None"""
          )
        case const: ConstantLengthCalculation[_] =>
          // Constant-length values are written without a length prefix, so the size is exactly the constant.
          (
            q"""
          override val staticSize: _root_.scala.Option[_root_.scala.Int] = _root_.scala.Some(${const.toInt})""",
            q"""
          override def dynamicSize($element: $typeName): _root_.scala.Option[_root_.scala.Int] = staticSize"""
          )
        case _: FastLengthCalculation[_] | _: MaybeLengthCalculation[_] => callDynamic
      }
    }

    def genNoLenCalc = {
      val baos = freshT("baos")
      val element = freshT("element")
      val outerOutputStream = freshT("os")
      val len = freshT("len")

      /**
       * This is the worst case: we have to serialize into a side buffer and then see how large it actually is.
       * This happens for cases, like String, where computing the serialized size is no cheaper than
       * serializing directly.
       */
      q"""
      private[this] def noLengthWrite($element: $T, $outerOutputStream: _root_.java.io.OutputStream): _root_.scala.Unit = {
        // Start with pretty big buffers because reallocation will be expensive
        val $baos = new _root_.java.io.ByteArrayOutputStream(512)
        ${t.put(baos, element)}
        val $len = $baos.size
        $outerOutputStream.writePosVarInt($len)
        $baos.writeTo($outerOutputStream)
      }
      """
    }

    /** Generates the body of `write`: an optional length prefix followed by the payload. */
    def putFnGen(outerbaos: TermName, element: TermName) = {
      val len = freshT("len")

      /**
       * This is the case where the length is cheap to compute, either constant or easily computable from an
       * instance.
       */
      def withLenCalc(lenC: Tree) = q"""
        val $len = $lenC
        $outerbaos.writePosVarInt($len)
        ${t.put(outerbaos, element)}
      """

      t.length(q"$element") match {
        case _: ConstantLengthCalculation[_] =>
          // No prefix: readers recover the length from the constant.
          q"""${t.put(outerbaos, element)}"""
        case f: FastLengthCalculation[_] =>
          withLenCalc(f.asInstanceOf[FastLengthCalculation[c.type]].t)
        case _: MaybeLengthCalculation[_] =>
          val tmpLenRes = freshT("tmpLenRes")
          q"""
            if(skipLenCalc) {
              noLengthWrite($element, $outerbaos)
            } else {
              val $tmpLenRes: _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MaybeLength = payloadLength($element)
              $tmpLenRes match {
                case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation =>
                  failedLengthCalc()
                  noLengthWrite($element, $outerbaos)
                case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(const) =>
                  ${withLenCalc(q"const")}
                case _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(s) =>
                  ${withLenCalc(q"s")}
              }
            }
        """
        case _ => q"noLengthWrite($element, $outerbaos)"
      }
    }

    // The element tree passed to `t.length` below is only used to classify the length kind; it is discarded.
    def readLength(inputStream: TermName) =
      t.length(q"e") match {
        case const: ConstantLengthCalculation[_] => q"${const.toInt}"
        case _                                   => q"$inputStream.readPosVarInt"
      }

    def discardLength(inputStream: TermName) =
      t.length(q"e") match {
        case _: ConstantLengthCalculation[_] => q"()"
        case _                               => q"$inputStream.readPosVarInt"
      }

    val lazyVariables = t.lazyOuterVariables.map { case (name, initTree) =>
      val termName = TermName(name)
      q"""lazy val $termName = $initTree"""
    }

    val inputStreamA = freshT("inputStreamA")
    val inputStreamB = freshT("inputStreamB")
    val posStreamA = freshT("posStreamA")
    val posStreamB = freshT("posStreamB")
    val initialPositionA = freshT("initialPositionA")
    val initialPositionB = freshT("initialPositionB")
    val innerR = freshT("innerR")

    val lenA = freshT("lenA")
    val lenB = freshT("lenB")

    // Built once: the two member definitions come from the same length classification.
    val (staticSizeDef, dynamicSizeDef) = binaryLengthGen(q"$T")

    t.ctx.Expr[OrderedSerialization[T]](q"""
      new _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MacroEqualityOrderedSerialization[$T] {
        override val uniqueId: _root_.java.lang.String = ${T.tpe.toString}

        private[this] var lengthCalculationAttempts: _root_.scala.Long = 0L
        private[this] var couldNotLenCalc: _root_.scala.Long = 0L
        private[this] var skipLenCalc:_root_.scala.Boolean = false

        import _root_.com.twitter.scalding.serialization.JavaStreamEnrichments._
        ..$lazyVariables

        override def compareBinary($inputStreamA: _root_.java.io.InputStream, $inputStreamB: _root_.java.io.InputStream): _root_.com.twitter.scalding.serialization.OrderedSerialization.Result =
          try _root_.com.twitter.scalding.serialization.OrderedSerialization.resultFrom {
            val $lenA = ${readLength(inputStreamA)}
            val $lenB = ${readLength(inputStreamB)}
            val $posStreamA = _root_.com.twitter.scalding.serialization.PositionInputStream($inputStreamA)
            val $initialPositionA = $posStreamA.position
            val $posStreamB = _root_.com.twitter.scalding.serialization.PositionInputStream($inputStreamB)
            val $initialPositionB = $posStreamB.position

            val $innerR = ${t.compareBinary(posStreamA, posStreamB)}

            // Always leave both streams just past their framed values, however far the comparison read.
            $posStreamA.seekToPosition($initialPositionA + $lenA)
            $posStreamB.seekToPosition($initialPositionB + $lenB)
            $innerR
          } catch {
            case _root_.scala.util.control.NonFatal(e) =>
              _root_.com.twitter.scalding.serialization.OrderedSerialization.CompareFailure(e)
          }

        override def hash(passedInObjectToHash: $T): _root_.scala.Int = {
          ${t.hash(TermName("passedInObjectToHash"))}
        }

        // Once more than 40% of (more than 50) length calculations have failed, stop attempting them.
        private[this] def failedLengthCalc(): _root_.scala.Unit = {
          couldNotLenCalc += 1L
          if(lengthCalculationAttempts > 50 && (couldNotLenCalc.toDouble / lengthCalculationAttempts) > 0.4) {
            skipLenCalc = true
          }
        }

        // What to do if we don't have a length calculation
        $genNoLenCalc

        // defines payloadLength private method
        $innerLengthFn

        // static size:
        $staticSizeDef

        // dynamic size:
        $dynamicSizeDef

        override def read(from: _root_.java.io.InputStream): _root_.scala.util.Try[$T] = {
          try {
              ${discardLength(TermName("from"))}
             _root_.scala.util.Success(${t.get(TermName("from"))})
          } catch { case _root_.scala.util.control.NonFatal(e) =>
            _root_.scala.util.Failure(e)
          }
        }

        override def write(into: _root_.java.io.OutputStream, e: $T): _root_.scala.util.Try[_root_.scala.Unit] = {
          try {
              ${putFnGen(TermName("into"), TermName("e"))}
              _root_.com.twitter.scalding.serialization.Serialization.successUnit
          } catch { case _root_.scala.util.control.NonFatal(writeErr) =>
            _root_.scala.util.Failure(writeErr)
          }
        }

        override def compare(x: $T, y: $T): _root_.scala.Int = {
          ${t.compare(TermName("x"), TermName("y"))}
        }
      }
    """)
  }