in src/main/java/com/twitter/sbf/core/MHAlgorithm.java [279:403]
/**
 * Runs a single epoch: one {@code mhStep} proposal for every vertex of {@code g}, visiting the
 * vertices in the order returned by {@code shuffleVertexIds()}.
 *
 * <p>With {@code config.cpu <= 1} the work is delegated to {@code runEpochSerial}. Otherwise the
 * shuffled vertices are split into chunks that run on a work-stealing pool. The mode is read
 * from {@code config} once, before any task is submitted:
 * <ul>
 *   <li>{@code updateImmediately == false}: every proposal is buffered and the accepted rows
 *       are written to {@code matrix} after the wait on the pool has finished;</li>
 *   <li>{@code updateImmediately} with locking: the column locks returned by
 *       {@code lockAndReturnLockedColumns} are held while the row is proposed and written;</li>
 *   <li>{@code updateImmediately} with {@code noLocking}: rows are written with no locking
 *       ("HogWild").</li>
 * </ul>
 *
 * <p>A task that fails is reported to {@code diagnosticsWriter} and contributes nothing to the
 * returned count; the counts of the remaining tasks are still included. If the calling thread
 * is interrupted while waiting, the interrupt status is restored before continuing.
 *
 * @param matrix the matrix handed to {@code mhStep} and updated with accepted rows
 * @param epoch index of this epoch; drives the temperature schedule and labels diagnostics
 * @param doMetrics whether to write a metrics line to {@code diagnosticsWriter} afterwards
 * @return the number of accepted proposals counted in this epoch
 */
private int runEpoch(SparseBinaryMatrix matrix, final int epoch, boolean doMetrics) {
  double wtCoeff = config.wtCoeff;
  long tic = System.nanoTime();
  emptyProposals = 0;
  sameProposals = 0;
  // temperature = maxTemperature * temperatureRatio^epoch; only used when the schedule is on.
  double temp = Math.pow(config.temperatureRatio, epoch) * config.maxTemperature;
  double multiplier = config.useTemperatureSchedule ? 1.0 / temp : config.scaleCoeff;
  int n = g.getNumVertices();
  Integer[] shuffledIds = shuffleVertexIds();
  int acceptCount = 0;
  if (config.cpu <= 1) {
    acceptCount = runEpochSerial(matrix, shuffledIds, multiplier, wtCoeff);
  } else {
    /*
     * Multi-threaded variant of runEpochSerial. The per-vertex logic is the same; the extra
     * code only selects between buffered, locked and lock-free (HogWild) row updates.
     */
    // Read the mode flags once so every task and the post-processing step agree on them.
    final boolean deferUpdates = !config.updateImmediately;
    final boolean useLocks = !config.noLocking;
    // Only reserve n slots when proposals are actually buffered.
    final Hashtable<Integer, Optional<int[]>> updates =
        new Hashtable<>(deferUpdates ? n : 1, 1.0f);
    long maxWork = 0;
    long minWork = Long.MAX_VALUE;
    int numTasks = Math.max(1, n / 1000);
    ArrayList<Future<Integer>> futures = new ArrayList<>(numTasks);
    ExecutorService exec = Executors.newWorkStealingPool(config.cpu);
    try {
      for (Range range : Util.chunkArray(n, numTasks)) {
        // Sum of degrees in the chunk, used only for the load-imbalance diagnostic below.
        long workHere = 0;
        for (int i = range.start; i < range.end; i++) {
          workHere += g.getDegree(shuffledIds[i]);
        }
        if (workHere > maxWork) {
          maxWork = workHere;
        }
        // Chunks with zero work are ignored so minWork can never become a zero divisor.
        if (workHere < minWork && workHere > 0) {
          minWork = workHere;
        }
        Future<Integer> fut = exec.submit(() -> {
          int accepted = 0;
          for (int i = range.start; i < range.end; i++) {
            int vertexId = shuffledIds[i];
            if (deferUpdates) {
              Optional<int[]> newRow = mhStep(matrix, vertexId, multiplier, wtCoeff);
              if (newRow.isPresent()) {
                accepted++;
              }
              updates.put(vertexId, newRow);
            } else if (useLocks) {
              IntSet colsToLock = lockAndReturnLockedColumns(matrix, vertexId);
              try {
                Optional<int[]> newRow = mhStep(matrix, vertexId, multiplier, wtCoeff);
                if (newRow.isPresent()) {
                  matrix.updateRow(vertexId, newRow.get());
                  accepted++;
                }
              } finally {
                for (int colId : colsToLock) {
                  colLocks[colId].unlock();
                }
              }
            } else { // HogWild!!
              Optional<int[]> newRow = mhStep(matrix, vertexId, multiplier, wtCoeff);
              if (newRow.isPresent()) {
                matrix.updateRow(vertexId, newRow.get());
                accepted++;
              }
            }
          }
          return accepted;
        });
        futures.add(fut);
      }
    } finally {
      // Always stop accepting work, even if building or submitting a chunk failed.
      exec.shutdown();
    }
    if (maxWork / minWork >= 2) {
      diagnosticsWriter.format("epoch is %d, maxWork is %d, minWork is %d\n",
          epoch, maxWork, minWork);
      diagnosticsWriter.flush();
    }
    try {
      while (!exec.isTerminated()) {
        exec.awaitTermination(1200, TimeUnit.SECONDS);
        if (!exec.isTerminated()) {
          diagnosticsWriter.println("epoch is " + epoch + ", going to wait for 20 more minutes, "
              + "executor.isTerminated " + exec.isTerminated());
          diagnosticsWriter.flush();
        }
      }
      for (Future<Integer> fut : futures) {
        try {
          acceptCount += fut.get();
        } catch (ExecutionException e) {
          // One failed chunk must not discard the counts of the chunks that succeeded.
          e.printStackTrace(diagnosticsWriter);
        }
      }
    } catch (InterruptedException e) {
      // Restore the interrupt status so callers can still observe the interruption.
      Thread.currentThread().interrupt();
      e.printStackTrace(diagnosticsWriter);
    }
    if (deferUpdates) {
      for (Map.Entry<Integer, Optional<int[]>> update : updates.entrySet()) {
        if (update.getValue().isPresent()) {
          matrix.updateRow(update.getKey(), update.getValue().get());
        }
      }
    }
  }
  if (doMetrics) {
    double acceptRate = acceptCount * 1.0 / n;
    long toc = System.nanoTime();
    double timeInSecs = (toc - tic) * 1e-9;
    diagnosticsWriter.println(
        getMetricsLine(matrix, epoch, timeInSecs, acceptRate, 1.0 / multiplier,
            epoch % config.evalEvery == 0)
    );
    diagnosticsWriter.flush();
  }
  return acceptCount;
}