in util-core/src/main/scala/com/twitter/util/Memoize.scala [73:166]
def apply[A, B](f: A => B): A => B = {
  // The snapshot-capable memoizer produced by `snappable` is itself usable
  // as an `A => B`, so it is handed back directly as a plain function.
  val memoized = snappable(f)
  memoized
}
/**
 * Thread-safe memoization for a Function2.
 *
 * The two arguments are packed into a tuple, which serves as the key for a
 * single-argument memoizer obtained from `apply`; the returned function
 * unpacks nothing and simply forwards `(a, b)` as that tuple.
 */
def function2[A, B, C](f: (A, B) => C): (A, B) => C = {
  val memoizedOnPairs: ((A, B)) => C = apply(f.tupled)
  (first: A, second: B) => memoizedOnPairs((first, second))
}
/**
 * Produces [[com.twitter.util.Memoize.Snappable]], thread-safe
 * memoization for a function.
 *
 * The first caller for a given argument claims it and runs `f` outside the
 * lock. Concurrent callers for the same argument block until that
 * computation finishes. If `f` throws, the claim is released, the exception
 * is rethrown to the computing caller, and any waiting callers retry the
 * lookup, so one of them will run `f` again.
 *
 * NOTE(review): if `f` re-enters this memoizer with the same argument on the
 * same thread, that thread reaches `latch.await()` on the latch it created
 * itself. Whether that fails fast or blocks depends on
 * `GuardedCountDownLatch` (defined elsewhere) — confirm.
 *
 * @param f the function to memoize
 * @return a memoizing wrapper around `f` that can also report its size and
 *         a snapshot of the values computed so far
 */
def snappable[A, B](f: A => B): Snappable[A, B] =
  new Snappable[A, B] {
    // Each key maps either to a latch (computation in progress; counted down
    // when it finishes or fails) or to the final computed value.
    //
    // All writes happen while holding this instance's monitor, but `apply`
    // and `size` read the field without locking. `@volatile` gives those
    // reads a happens-before edge with the latest write. A reader that sees
    // `Right(b)` is therefore guaranteed to see `b` fully constructed. A
    // plain `var` gives no such guarantee under the Java memory model.
    @volatile private[this] var memo = Map.empty[A, Either[GuardedCountDownLatch, B]]

    /** Number of keys currently tracked, including in-progress computations. */
    override def size: Long = memo.size

    /** Snapshot of the keys whose values have been fully computed. */
    def snap: Map[A, B] =
      synchronized(memo) collect {
        case (a, Right(b)) => (a, b)
      }

    /**
     * What to do if we do not find the value already in the memo
     * table.
     */
    @tailrec private[this] def missing(a: A): B =
      synchronized {
        // With the lock, check to see what state the value is in.
        memo.get(a) match {
          case None =>
            // If it's missing, then claim the slot by putting in a
            // latch that will be counted down when the value is
            // available (or the computation fails).
            val latch = new GuardedCountDownLatch(1)
            memo = memo + (a -> Left(latch))
            // The latch wrapped in Left indicates that the value
            // needs to be computed in this thread, and then the
            // latch counted down.
            Left(latch)
          case Some(other) =>
            // This is either the latch that will indicate that the
            // work has been done, or the computed value.
            Right(other)
        }
      } match {
        case Right(Right(b)) =>
          // The computation is already done.
          b
        case Right(Left(latch)) =>
          // Someone else is doing the computation.
          latch.await()
          // Re-check after waking: the slot now holds either the
          // computed value, or nothing at all if the computation threw
          // (in which case this thread may claim it and retry).
          missing(a)
        case Left(latch) =>
          // Compute the value outside of the synchronized block.
          val b =
            try {
              f(a)
            } catch {
              // Deliberately catch Throwable rather than NonFatal. A fatal
              // error must also release the slot and the latch, or every
              // waiter on this key would block forever.
              case t: Throwable =>
                // Drop the claim *before* releasing waiters so that, when
                // they re-check, they find the slot empty rather than a
                // latch that will never be counted down again.
                synchronized { memo = memo - a }
                latch.countDown()
                throw t
            }
          // Update the memo table to indicate that the work has
          // been done, and signal to any waiting threads that the
          // work is complete.
          synchronized { memo = memo + (a -> Right(b)) }
          latch.countDown()
          b
      }

    override def apply(a: A): B =
      // Fast path: read the (possibly stale) memo table without locking.
      // A `Right` entry is never replaced or removed once stored, so it can
      // be returned directly. Anything else is resolved under the lock by
      // `missing`.
      memo.get(a) match {
        case Some(Right(b)) => b
        case _ => missing(a)
      }
  }