private int runEpoch()

in src/main/java/com/twitter/sbf/core/MHAlgorithm.java [279:403]


  /**
   * Runs a single epoch: every vertex id returned by {@code shuffleVertexIds()} is offered one
   * {@code mhStep} proposal, and accepted proposals are written into {@code matrix}.
   *
   * <p>When {@code config.cpu <= 1} the work is delegated to {@code runEpochSerial}. Otherwise the
   * shuffled ids are split into chunks that are run on a work-stealing pool in one of three modes:
   * <ul>
   *   <li>deferred ({@code !config.updateImmediately}): proposals are collected and applied to
   *       {@code matrix} only after the workers have been waited for;</li>
   *   <li>immediate with locking: the columns returned by {@code lockAndReturnLockedColumns} are
   *       held while the row is proposed and updated;</li>
   *   <li>immediate without locking ({@code config.noLocking}): rows are updated with no
   *       synchronization.</li>
   * </ul>
   *
   * <p>If the calling thread is interrupted while waiting for the workers, outstanding work is
   * cancelled, results of already-finished tasks are still counted, and the thread's interrupt
   * status is restored before returning. A failure inside one task is logged to
   * {@code diagnosticsWriter} and does not discard the counts reported by the other tasks.
   *
   * @param matrix the matrix whose rows are proposed and updated
   * @param epoch index of this epoch; used for the temperature schedule and for diagnostics
   * @param doMetrics whether to write a metrics line to {@code diagnosticsWriter} at the end
   * @return the number of accepted proposals in this epoch
   */
  private int runEpoch(SparseBinaryMatrix matrix, final int epoch, boolean doMetrics) {
    double wtCoeff = config.wtCoeff;

    long tic = System.nanoTime();
    emptyProposals = 0;
    sameProposals = 0;
    // Temperature is maxTemperature * temperatureRatio^epoch; the multiplier handed to mhStep is
    // its inverse when the schedule is enabled, and the fixed scaleCoeff otherwise.
    double temp = Math.pow(config.temperatureRatio, epoch) * config.maxTemperature;
    double multiplier = config.useTemperatureSchedule ? 1.0 / temp : config.scaleCoeff;

    int n = g.getNumVertices();

    Integer[] shuffledIds = shuffleVertexIds();

    int acceptCount = 0;
    // Remembered rather than re-asserted on the spot, so that the remaining bookkeeping
    // (collecting futures, writing diagnostics) is not itself aborted by the pending interrupt.
    boolean interrupted = false;
    if (config.cpu <= 1) {
      acceptCount = runEpochSerial(matrix, shuffledIds, multiplier, wtCoeff);
    } else {
      /*
       * All of the below complicated logic is complicated because it's multi-threaded and there's
       * a few different options to enable/disable when running multi-threaded;
       * the essential logic is the same as in runEpochSerial.
       */
      // Only the deferred-update mode stores proposals, so only then is the table sized to n.
      final Hashtable<Integer, Optional<int[]>> updates =
          new Hashtable<>(config.updateImmediately ? 0 : n, 1.0f);

      long maxWork = 0;
      long minWork = Long.MAX_VALUE;
      int numTasks = Math.max(1, n / 1000);
      ArrayList<Future<Integer>> futures = new ArrayList<>(numTasks);
      ExecutorService exec = Executors.newWorkStealingPool(config.cpu);
      for (Range range : Util.chunkArray(n, numTasks)) {
        // Sum of vertex degrees in the chunk, tracked to report imbalance between chunks.
        long workHere = 0;
        for (int i = range.start; i < range.end; i++) {
          workHere += g.getDegree(shuffledIds[i]);
        }
        if (workHere > maxWork) {
          maxWork = workHere;
        }
        // Zero-work chunks are ignored so that minWork stays a valid divisor below.
        if (workHere < minWork && workHere > 0) {
          minWork = workHere;
        }
        Future<Integer> fut = exec.submit(() -> {
          int accepted = 0;
          for (int i = range.start; i < range.end; i++) {
            int vertexId = shuffledIds[i];
            if (!config.updateImmediately) {
              // Deferred mode: record the proposal; the matrix is updated after all tasks finish.
              Optional<int[]> newRow = mhStep(matrix, vertexId, multiplier, wtCoeff);
              if (newRow.isPresent()) {
                accepted++;
              }
              updates.put(vertexId, newRow);
            } else {
              if (!config.noLocking) {
                IntSet colsToLock = lockAndReturnLockedColumns(matrix, vertexId);
                try {
                  Optional<int[]> newRow =
                      mhStep(matrix, vertexId, multiplier, wtCoeff);
                  if (newRow.isPresent()) {
                    matrix.updateRow(vertexId, newRow.get());
                    accepted++;
                  }
                } finally {
                  // Always release the column locks, even if mhStep or updateRow throws.
                  for (int colId : colsToLock) {
                    colLocks[colId].unlock();
                  }
                }
              } else { // HogWild!!
                Optional<int[]> newRow = mhStep(matrix, vertexId, multiplier, wtCoeff);
                if (newRow.isPresent()) {
                  matrix.updateRow(vertexId, newRow.get());
                  accepted++;
                }
              }
            }
          }
          return accepted;
        });
        futures.add(fut);
      }

      if (maxWork / minWork >= 2) {
        diagnosticsWriter.format("epoch is %d, maxWork is %d, minWork is %d\n",
            epoch, maxWork, minWork);
        diagnosticsWriter.flush();
      }

      exec.shutdown();
      try {
        while (!exec.isTerminated()) {
          exec.awaitTermination(1200, TimeUnit.SECONDS);
          if (!exec.isTerminated()) {
            diagnosticsWriter.println("epoch is " + epoch + ", going to wait for 20 more minutes, "
                + "executor.isTerminated " + exec.isTerminated());
            diagnosticsWriter.flush();
          }
        }
      } catch (InterruptedException e) {
        interrupted = true;
        e.printStackTrace(diagnosticsWriter);
        // Stop handing out queued chunks instead of leaving the pool running unattended.
        exec.shutdownNow();
      }

      // Each future is handled on its own so that one failed or cancelled task does not throw
      // away the accept counts of the tasks that did complete.
      for (Future<Integer> fut : futures) {
        if (interrupted && !fut.isDone()) {
          // Never block again once an interrupt has been observed.
          continue;
        }
        try {
          acceptCount += fut.get();
        } catch (InterruptedException e) {
          interrupted = true;
          e.printStackTrace(diagnosticsWriter);
        } catch (ExecutionException | java.util.concurrent.CancellationException e) {
          e.printStackTrace(diagnosticsWriter);
        }
      }

      if (!config.updateImmediately) {
        for (Map.Entry<Integer, Optional<int[]>> update : updates.entrySet()) {
          if (update.getValue().isPresent()) {
            matrix.updateRow(update.getKey(), update.getValue().get());
          }
        }
      }
    }

    if (doMetrics) {
      double acceptRate = acceptCount * 1.0 / n;

      long toc = System.nanoTime();
      double timeInSecs = (toc - tic) * 1e-9;
      diagnosticsWriter.println(
          getMetricsLine(matrix, epoch, timeInSecs, acceptRate, 1.0 / multiplier,
              epoch % config.evalEvery == 0)
      );
      diagnosticsWriter.flush();
    }

    if (interrupted) {
      // Re-assert the interrupt so the caller can see that cancellation was requested.
      Thread.currentThread().interrupt();
    }
    return acceptCount;
  }