in src/main/java/com/twitter/sbf/core/MHAlgorithm.java [479:548]
/**
 * Collects progress metrics and, optionally, prediction-quality metrics for the current state of
 * the sampler.
 *
 * <p>Cheap metrics are always included:
 * <ul>
 *   <li>matrix sparsity and column-size statistics;</li>
 *   <li>timing, acceptance rate and temperature;</li>
 *   <li>empty/same proposal counts, normalized by vertex count.</li>
 * </ul>
 *
 * <p>When {@code doExpensiveEval} is set, prediction metrics are also computed:
 * <ul>
 *   <li>Graphs with fewer than 5000 edges are evaluated by vertex sampling on a fixed-size thread
 *       pool of {@code config.cpu} threads. A fraction {@code config.evalRatio} of vertices is
 *       reservoir-sampled, or all vertices are used when the ratio is at least 1.0.</li>
 *   <li>Larger graphs are evaluated with {@code getPredictionStatEdgeSampling}.</li>
 * </ul>
 * Precision, recall and F1 values (plain and weighted) are reported multiplied by 100.
 *
 * @param matrix the current assignment matrix to evaluate
 * @param epoch current epoch number, reported as "epoch"
 * @param elapsedSeconds total elapsed time, reported as "totalSec"
 * @param acceptRate proposal acceptance rate, reported as "acceptRate"
 * @param temperature current temperature, reported as "temp"
 * @param doExpensiveEval whether to also compute prediction metrics and "evalSec"
 * @return an immutable map from metric name to value
 * @throws RuntimeException wrapping an {@link InterruptedException} if the thread is interrupted
 *     while waiting for the evaluation thread pool to terminate; the interrupt flag is restored
 *     before throwing
 */
private Map<String, Object> getMetrics(SparseBinaryMatrix matrix,
                                       int epoch,
                                       double elapsedSeconds,
                                       double acceptRate,
                                       double temperature,
                                       boolean doExpensiveEval) {
  Graph graph = g;
  ImmutableMap.Builder<String, Object> metricsMap = new ImmutableMap.Builder<>();
  metricsMap.put("epoch", epoch);
  metricsMap.put("nnz/vertex", matrix.nnzPerRow());
  metricsMap.put("emptyRow", matrix.emptyRowProportion());
  metricsMap.put("emptyCol", matrix.emptyColProportion());
  metricsMap.put("minSize", matrix.minColSizeAboveZero());
  metricsMap.put("maxSize", matrix.maxColSize());
  metricsMap.put("totalSec", elapsedSeconds);
  metricsMap.put("acceptRate", acceptRate);
  metricsMap.put("temp", temperature);
  // Proposal counters are reported per vertex; "* 1.0" forces floating-point division.
  metricsMap.put("emptyProposal", emptyProposals * 1.0 / graph.getNumVertices());
  metricsMap.put("sameProposal", sameProposals * 1.0 / graph.getNumVertices());
  if (doExpensiveEval) {
    long tic = System.nanoTime();
    if (graph.getNumEdges() < 5000L) {
      // Small graph: evaluate per vertex, optionally on a random subset of the vertices.
      boolean sampleVertices = config.evalRatio < 1.0;
      int[] evalVertexIds;
      if (sampleVertices) {
        int subsetSize = (int) (graph.getNumVertices() * config.evalRatio);
        evalVertexIds = Util.reservoirSampling(graph.getNumVertices(), subsetSize);
      } else {
        evalVertexIds = graph.getAllVertexIds();
      }
      ExecutorService executor = Executors.newFixedThreadPool(config.cpu);
      PredictionStat stat;
      try {
        stat = getPredictionStatVertexSampling(graph, matrix, executor, evalVertexIds);
      } finally {
        // Always shut the pool down, even if evaluation throws; otherwise its non-daemon
        // worker threads would outlive this call.
        executor.shutdown();
      }
      try {
        while (!executor.isTerminated()) {
          executor.awaitTermination(20, TimeUnit.MINUTES);
        }
      } catch (InterruptedException e) {
        // Preserve the interrupt status for callers further up the stack.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
      metricsMap.put("orphans", stat.orphanRate());
      metricsMap.put("prec", stat.precision() * 100);
      metricsMap.put("rec", stat.recall() * 100);
      metricsMap.put("f1", stat.f1() * 100);
      metricsMap.put("wtPrec", stat.weightedPrecision() * 100);
      metricsMap.put("wtRec", stat.weightedRecall() * 100);
      metricsMap.put("wtF1", stat.weightedF1() * 100);
    } else {
      // Large graph: use the edge-sampling estimator, which returns separate stats keyed by
      // "precision", "recall" and "orphans".
      Map<String, PredictionStat> ss = getPredictionStatEdgeSampling(graph, matrix, config.rng,
          100000, 1000000, 5000);
      double precision = ss.get("precision").precision();
      double recall = ss.get("recall").recall();
      double wtPrec = ss.get("precision").weightedPrecision();
      double wtRec = ss.get("recall").weightedRecall();
      metricsMap.put("orphans", ss.get("orphans").orphanRate());
      metricsMap.put("prec", precision * 100);
      metricsMap.put("rec", recall * 100);
      metricsMap.put("f1", percentF1(precision, recall));
      metricsMap.put("wtPrec", wtPrec * 100);
      metricsMap.put("wtRec", wtRec * 100);
      metricsMap.put("wtF1", percentF1(wtPrec, wtRec));
    }
    long toc = System.nanoTime();
    double evalTime = (toc - tic) * 1e-9;
    metricsMap.put("evalSec", evalTime);
  }
  return metricsMap.build();
}

/**
 * Returns the F1 score (harmonic mean of precision and recall) scaled to a percentage.
 *
 * <p>Returns 0 when precision and recall sum to zero. The unguarded formula would produce NaN
 * (0/0) in that case.
 *
 * @param precision precision in the same units as {@code recall}
 * @param recall recall in the same units as {@code precision}
 * @return {@code 200 * precision * recall / (precision + recall)}, or 0 if the denominator is not
 *     positive
 */
private static double percentF1(double precision, double recall) {
  double sum = precision + recall;
  return sum > 0 ? 200 * precision * recall / sum : 0.0;
}