override def run[T2 <:()

in src/main/scala/com/twitter/stitch/Arrow.scala [1582:1735]


    /**
     * Runs each input's arrow on its paired argument, then passes the per-input results to `tail`.
     *
     * Each input is matched as either `Return((arrow, arg))` or `Throw(_)`. Two cases take a fast path:
     *  - every input is a Return whose arrow is `==` to the first one: that arrow is run once on all args;
     *  - every input is a Throw: the inputs go straight to `tail`.
     * Otherwise inputs are grouped by arrow (via `equals`/`hashCode`) so each distinct arrow is run
     * once on its whole batch. All Throws are grouped under `identityInstance`, as are any Returns
     * whose arrow equals `identityInstance`. Presumably that is an identity arrow that passes its
     * inputs through unchanged; this should be confirmed.
     *
     * NOTE: on the mixed (grouped) path the input buffer `ts` is reused, via a cast, to hold the
     * per-input result Stitches. Its contents are overwritten.
     *
     * @param ts   the inputs, each an `(arrow, argument)` pair or a failure
     * @param ls   the `Locals` for each input; must have the same length as `ts`
     * @param tail the arrow that the per-input results are handed to. On the mixed path the results
     *             are first restored to the inputs' original positions.
     * @return the Stitch produced by running `tail` over the per-input results
     */
    override def run[T2 <: (Arrow[Any, Any], Any), V](
      ts: ArrayBuffer[Try[T2]],
      ls: ArrayBuffer[Locals],
      tail: Arrow[Any, V]
    ): Stitch[ArrayBuffer[Try[V]]] = {
      val len = ts.length
      assert(ls.length == len)

      if (ts.isEmpty) return tail.run(ts.asInstanceOf[ArrayBuffer[Try[Any]]], ls)

      val tts = ts.asInstanceOf[ArrayBuffer[Try[(Arrow[Any, Any], Any)]]]

      // The arrow of the first input if it is a Return, or null if it is a Throw.
      val firstArrow: Arrow[Any, Any] = tts.head match {
        case Return((a, _)) => a
        case Throw(_) => null
      }

      // True if every input is a Return carrying an arrow equal to `firstArrow`,
      // or if every input is a Throw. False otherwise.
      val allSame = tts.forall {
        case Return((a, _)) => firstArrow == a
        case Throw(_) => firstArrow == null
      }

      if (allSame && firstArrow != null) {
        // Fast path: every input is a Return for the same arrow, so run that arrow once on all args.
        val args = TryBuffer.mapTry(tts) {
          case Return((_, arg)) => Return(arg)
          case _ =>
            throw new IllegalStateException(
              "This should not be possible since firstArrow should only be populated if all args are Returns"
            )
        }
        firstArrow.run(args, ls, tail)
      } else if (allSame) {
        // Fast path: every input is a Throw, so hand them straight to `tail`.
        tail.run(tts, ls)
      } else {
        // Mixed path: the inputs use different arrows, or mix Returns and Throws.
        // Inputs for the same arrow are batched together, and all Throws are batched together
        // (keyed by `identityInstance`), so each arrow runs once on its whole batch.
        // Grouping shuffles the inputs, so each group's results are written back to the
        // inputs' original positions afterwards.

        // Pass 1: count the inputs per arrow, so each group's buffers can be pre-sized.
        val countByArrow = new java.util.HashMap[Arrow[Any, Any], Int]()
        tts.foreach {
          case Return((a, _)) => countByArrow.merge(a, 1, _ + _)
          case Throw(_) => countByArrow.merge(identityInstance, 1, _ + _)
        }

        // The same HashMap is reused to hold the groups. Each value starts out as the Int count
        // and is replaced by (Indices, Args, Locals) when that group's first input is seen.
        // Without union types, values are read back as Any and pattern matched.
        val groupedByArrow = countByArrow.asInstanceOf[java.util.HashMap[
          Arrow[Any, Any],
          (ArrayBuffer[Int], ArrayBuffer[Try[Any]], ArrayBuffer[Locals])
        ]]
        val unionOfCountAndGrouped =
          countByArrow.asInstanceOf[java.util.HashMap[Arrow[Any, Any], Any]]

        /**
         * Appends the input at `index` (its value and its Locals) to the group for `key`.
         * If the map still holds the pre-computed count for `key`, the group is created first.
         */
        def addToGroup(key: Arrow[Any, Any], index: Int, input: Try[Any]): Unit =
          unionOfCountAndGrouped.get(key) match {
            case size: Int =>
              groupedByArrow.put(
                key,
                (
                  new ArrayBuffer[Int](size) += index,
                  new ArrayBuffer[Try[Any]](size) += input,
                  new ArrayBuffer[Locals](size) += ls(index)
                )
              )
            case (
                  indices: ArrayBuffer[Int @unchecked],
                  args: ArrayBuffer[Try[Any] @unchecked],
                  locals: ArrayBuffer[Locals @unchecked]) =>
              indices += index
              args += input
              locals += ls(index)
            case other =>
              throw new IllegalArgumentException(
                s"Expected either Int or (ArrayBuffer, ArrayBuffer, ArrayBuffer) but got $other")
          }

        // Pass 2: group the inputs, remembering each input's original index.
        var i = 0
        while (i < len) {
          tts(i) match {
            case Return((arrow, arg)) => addToGroup(arrow, i, Return(arg))
            case t @ Throw(_) => addToGroup(identityInstance, i, t)
          }
          i += 1
        }

        // Pass 3: run each group once and scatter its results back to their original slots.
        // The input buffer is reused to hold the per-input result Stitches.
        val ss = ts.asInstanceOf[ArrayBuffer[Stitch[Any]]]

        val groupedIterator = groupedByArrow.entrySet().iterator()
        while (groupedIterator.hasNext) {
          val entry = groupedIterator.next()
          val arrow = entry.getKey
          val (indices, args, locals) = entry.getValue
          // Run optimal batches.
          // We return a ref here so that the Stitch returned by `arrow` is run only once.
          // That way, accessing the result at different indexes in the loop below will not
          // rerun the entire Stitch.
          // This prevents a very specific class of bugs where `arrow` is a long chain of arrows
          // that have not yet been resolved, and would otherwise be resolved early.
          // Because the grouped arrows are resolved at the same time, the second iteration of
          // the loop (which expects an unresolved arrow) finds a resolved one instead, and you
          // get a ClassCastException. This happens in the while loop below.
          // We were not able to create a unit test in Stitch/Arrow that reproduces this, but
          // there is a test case in Strato that does: STTR-6433 regression test in
          // strato/src/test/scala/com/twitter/strato/rpc/ServerTest.scala
          val stitches = Stitch.ref(arrow.run(args, locals))
          // Reshuffle the results back into the original order.
          var j = 0
          while (j < indices.length) {
            // Snapshot `j`. The var is captured by reference and the lambda below is evaluated
            // asynchronously, so reading `j` there could observe a later value.
            val resultIndex = j
            ss(indices(j)) = stitches.map(buf => buf(resultIndex)).lowerFromTry
            j += 1
          }
        }

        Stitch.transformSeq(ss, ls, tail)
      }
    }