in src/main/scala/com/twitter/stitch/Arrow.scala [1582:1735]
/**
 * Applies, to every input, the arrow carried inside that input, then runs `tail` on the results.
 *
 * Each input is either `Return((arrow, arg))`, meaning "run `arrow` on `arg`", or a `Throw`,
 * which is carried through untouched. How the work is batched depends on how uniform the
 * inputs are:
 *  - all inputs are `Return`s whose arrows are `==` to the first input's arrow: that single
 *    arrow is run over every argument, with `tail` passed to its three-argument `run`;
 *  - all inputs are `Throw`s: they are handed directly to `tail`;
 *  - anything else: inputs are grouped by arrow (every `Throw` goes into the
 *    `identityInstance` group), each group is run once as its own batch, and each result is
 *    written back to the position its input came from before `Stitch.transformSeq` applies
 *    `tail`.
 *
 * @param ts   the inputs; in the mixed case this buffer is overwritten in place with one
 *             `Stitch` per position
 * @param ls   the `Locals` for each input; must be exactly as long as `ts`
 * @param tail the arrow to run on the results
 * @return the `Stitch` produced by running `tail`
 */
override def run[T2 <: (Arrow[Any, Any], Any), V](
  ts: ArrayBuffer[Try[T2]],
  ls: ArrayBuffer[Locals],
  tail: Arrow[Any, V]
): Stitch[ArrayBuffer[Try[V]]] = {
  val total = ts.length
  assert(ls.length == total)

  if (ts.isEmpty) {
    tail.run(ts.asInstanceOf[ArrayBuffer[Try[Any]]], ls)
  } else {
    val inputs = ts.asInstanceOf[ArrayBuffer[Try[(Arrow[Any, Any], Any)]]]

    // The arrow carried by the first input, or null when the first input is a Throw.
    val leadArrow: Arrow[Any, Any] = inputs.head match {
      case Return((arrow, _)) => arrow
      case Throw(_) => null
    }

    // True when every Return's arrow is == leadArrow and every Throw sits alongside a null
    // leadArrow, i.e. all Returns sharing one arrow, or all Throws (assuming no Return
    // carries a null arrow).
    val uniform = inputs.forall {
      case Return((arrow, _)) => leadArrow == arrow
      case Throw(_) => leadArrow == null
    }

    if (uniform) {
      if (leadArrow != null) {
        // One arrow for the whole batch: strip the arrows off and run it once.
        val args = TryBuffer.mapTry(inputs) {
          case Return((_, arg)) => Return(arg)
          case _ =>
            throw new IllegalStateException(
              "This should not be possible since firstArrow should only be populated if all args are Returns"
            )
        }
        leadArrow.run(args, ls, tail)
      } else {
        // Nothing but Throws: there is no arrow to apply.
        tail.run(inputs, ls)
      }
    } else {
      // Mixed batch. Group the inputs so each distinct arrow runs once over all of its
      // inputs (Throws share the identityInstance group), then put every result back at the
      // position its input came from.
      //
      // A single HashMap serves both passes. First it maps each group key to its input count
      // (an Int). Then each count is swapped for pre-sized (indices, args, locals) buffers.
      // The casts below are typed views of that one map.
      val counts = new java.util.HashMap[Arrow[Any, Any], Int]()
      val slots = counts.asInstanceOf[java.util.HashMap[Arrow[Any, Any], Any]]
      val groups = counts.asInstanceOf[java.util.HashMap[
        Arrow[Any, Any],
        (ArrayBuffer[Int], ArrayBuffer[Try[Any]], ArrayBuffer[Locals])
      ]]

      // Group key for an input: its own arrow, or identityInstance for a Throw.
      def keyOf(input: Try[(Arrow[Any, Any], Any)]): Arrow[Any, Any] = input match {
        case Return((arrow, _)) => arrow
        case Throw(_) => identityInstance
      }

      // What an input contributes to its group: the bare argument, or the Throw itself.
      def argOf(input: Try[(Arrow[Any, Any], Any)]): Try[Any] = input match {
        case Return((_, arg)) => Return(arg)
        case thrown @ Throw(_) => thrown
      }

      // Appends (pos, arg, ls(pos)) to the group for `key`. The group's buffers are
      // allocated with the pre-computed count the first time the key is seen.
      def place(key: Arrow[Any, Any], pos: Int, arg: Try[Any]): Unit = {
        slots.get(key) match {
          case expected: Int =>
            groups.put(
              key,
              (
                new ArrayBuffer[Int](expected) += pos,
                new ArrayBuffer[Try[Any]](expected) += arg,
                new ArrayBuffer[Locals](expected) += ls(pos)
              )
            )
          case (
                positions: ArrayBuffer[Int @unchecked],
                args: ArrayBuffer[Try[Any] @unchecked],
                locals: ArrayBuffer[Locals @unchecked]) =>
            positions += pos
            args += arg
            locals += ls(pos)
          case other =>
            throw new IllegalArgumentException(
              s"Expected either Int or (ArrayBuffer, ArrayBuffer, ArrayBuffer) but got $other")
        }
        ()
      }

      // Pass 1: count the inputs per group.
      inputs.foreach(input => counts.merge(keyOf(input), 1, _ + _))

      // Pass 2: fill the groups in input order, so each group's positions are ascending.
      var pos = 0
      while (pos < total) {
        val input = inputs(pos)
        place(keyOf(input), pos, argOf(input))
        pos += 1
      }

      val results = ts.asInstanceOf[ArrayBuffer[Stitch[Any]]]
      val groupIterator = groups.entrySet().iterator()
      while (groupIterator.hasNext) {
        val group = groupIterator.next()
        val (positions, args, locals) = group.getValue
        // Wrapping the group's Stitch in Stitch.ref makes it run only once, however many
        // positions read from it below. Without the ref, each read could re-run the whole
        // Stitch. When the group's arrow is a long chain of not-yet-resolved arrows, one read
        // could resolve it early. A later read, still expecting an unresolved arrow, would
        // then hit a ClassCastException.
        // No Stitch/Arrow unit test reproduces this. The regression test (STTR-6433) lives in
        // strato/src/test/scala/com/twitter/strato/rpc/ServerTest.scala
        val batch = Stitch.ref(group.getKey.run(args, locals))
        // Scatter: result `slot` of this batch belongs at input position positions(slot).
        // `slot` is a fresh val per iteration, so each closure reads its own index.
        positions.indices.foreach { slot =>
          results(positions(slot)) = batch.map(buf => buf(slot)).lowerFromTry
        }
      }

      Stitch.transformSeq(results, ls, tail)
    }
  }
}
}