java/thread_pool.h (51 lines of code) (raw):
/**
* Non-metric Space Library
*
* Authors: Bilegsaikhan Naidan (https://github.com/bileg), Leonid Boytsov
* (http://boytsov.info). With contributions from Lawrence Cayton
* (http://lcayton.com/) and others.
*
* For the complete list of contributors and further details see:
* https://github.com/searchivarius/NonMetricSpaceLib
*
* Copyright (c) 2017
*
* This code is released under the
* Apache License Version 2.0 http://www.apache.org/licenses/.
*
*/
#pragma once

#include <atomic>
#include <exception>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>
namespace similarity {

/*
 * Thread-safe pop from a job queue shared by several worker threads.
 *
 * Holds `mtx` for the duration of the call. If `queue` is non-empty, moves
 * its front element into `obj`, removes that element from the queue and
 * returns true. If the queue is empty, returns false and leaves `obj`
 * untouched. Because the element is moved (not copied), move-only element
 * types such as std::unique_ptr are supported.
 *
 * See sample usage below.
 */
template <class T>
bool GetNextQueueObj(std::mutex &mtx, std::queue<T> &queue, T &obj) {
  std::lock_guard<std::mutex> lock(mtx);
  if (queue.empty()) {
    return false;
  }
  // The element is popped right after, so moving from it is safe.
  obj = std::move(queue.front());
  queue.pop();
  return true;
}

/*
  Sample usage of helper function GetNextQueueObj:

  queue<MSWNode*> toPatchQueue; // the job queue
  for (MSWNode* node : toPatchNodes) toPatchQueue.push(node);
  mutex mtx;
  vector<thread> threads;
  for (int i = 0; i < indexThreadQty_; ++i) {
    threads.push_back(thread(
      [&]() {
        MSWNode* node = nullptr;
        // get the next job from the queue
        while(GetNextQueueObj(mtx, toPatchQueue, node)) {
          node->removeGivenFriends(delNodesBitset);
        }
      }
    ));
  }
  // Don't forget to join!
  for (auto& thread : threads) thread.join();
*/

/*
 * Replacement for the OpenMP '#pragma omp parallel for' directive.
 * Only handles a subset of its functionality (no reductions, etc.).
 *
 * Calls fn(id) for every id from start (inclusive) to end (EXCLUSIVE).
 * Ids are handed out dynamically to worker threads through a shared atomic
 * counter. `fn` is shared by reference and invoked concurrently from all
 * workers, so it must be safe to call from several threads at once.
 *
 * numThreads == 0 means "use std::thread::hardware_concurrency()"; when that
 * value is unknown (0), a single worker is used. No more workers than there
 * are ids are started, and nothing is started for an empty range
 * (start >= end).
 *
 * Exceptions: if fn throws, workers stop claiming new ids (calls already in
 * progress run to completion), all workers are joined, and the last recorded
 * exception is rethrown on the calling thread. If a worker thread cannot be
 * created, the workers that did start are stopped and joined before the
 * creation error is rethrown.
 */
template <class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads,
                        Function fn) {
  if (start >= end) {
    return;
  }
  if (numThreads == 0) {
    numThreads = std::thread::hardware_concurrency();
    // hardware_concurrency() may return 0 when the value is not computable;
    // fall back to one worker instead of silently doing no work at all.
    if (numThreads == 0) {
      numThreads = 1;
    }
  }
  // Workers beyond the number of ids would exit immediately; don't start them.
  if (numThreads > end - start) {
    numThreads = end - start;
  }

  std::atomic<size_t> current(start);
  // Keep track of exceptions thrown in worker threads so that they can be
  // rethrown on the calling thread.
  // https://stackoverflow.com/a/32428427/1713196
  std::exception_ptr lastException = nullptr;
  std::mutex lastExceptMutex;

  auto worker = [&] {
    while (true) {
      // Once the range is exhausted (or current was reset to end after a
      // failure), each worker performs at most one more increment before it
      // exits, so current stays <= end + numThreads. A size_t wrap-around is
      // therefore only possible if end is within numThreads of SIZE_MAX.
      const size_t id = current.fetch_add(1);
      if (id >= end) {
        break;
      }
      try {
        fn(id);
      } catch (...) {
        std::lock_guard<std::mutex> lastExcepLock(lastExceptMutex);
        lastException = std::current_exception();
        // Make every worker's next fetch_add return an id >= end so that no
        // new work is started.
        current = end;
        break;
      }
    }
  };

  std::vector<std::thread> threads;
  threads.reserve(numThreads);
  try {
    for (size_t i = 0; i < numThreads; ++i) {
      threads.emplace_back(worker);
    }
  } catch (...) {
    // Thread creation failed (e.g. std::system_error). Destroying a joinable
    // std::thread calls std::terminate, so stop and join the workers that
    // were started before propagating the error.
    current = end;
    for (auto &thread : threads) {
      thread.join();
    }
    throw;
  }

  for (auto &thread : threads) {
    thread.join();
  }
  if (lastException) {
    std::rethrow_exception(lastException);
  }
}

}  // namespace similarity