in java/com_spotify_voyager_jni_Index.cpp [480:547]
/**
 * JNI implementation of Index.query(float[][] queryVectors, int numNeighbors,
 * int numThreads, long queryEf): runs a batched nearest-neighbour query on the
 * native index and converts the results into a Java QueryResults[] array.
 *
 * @param env          JNI environment of the calling thread.
 * @param self         The Java Index object whose native handle is queried.
 * @param queryVectors Java float[][]; one QueryResults is produced per row.
 * @param numNeighbors Number of neighbours requested per query; also the length
 *                     of every labels/distances array handed to Java.
 * @param numThreads   Forwarded unchanged to Index::query.
 * @param queryEf      Forwarded unchanged to Index::query.
 * @return A QueryResults[] with one element per query vector, or nullptr when a
 *         Java exception has been raised: RecallException for a RecallError,
 *         RuntimeException for any other std::exception, or whichever JNI
 *         exception (e.g. OutOfMemoryError) was already pending.
 */
jobjectArray Java_com_spotify_voyager_jni_Index_query___3_3FIIJ(
    JNIEnv *env, jobject self, jobjectArray queryVectors, jint numNeighbors,
    jint numThreads, jlong queryEf) {
  // Labels are copied straight out of the native result buffer into a Java
  // long[], which is only valid when both element types have the same width.
  // (labeltype is unsigned while jlong is signed, so labels above 2^63 - 1
  // would show up as negative numbers on the Java side.)
  static_assert(sizeof(hnswlib::labeltype) == sizeof(jlong),
                "hnswlib::labeltype must be as wide as jlong to be copied "
                "into a Java long[] without conversion.");

  try {
    if (!queryVectors) {
      // GetArrayLength on a null reference is undefined behaviour in JNI.
      throw std::invalid_argument("queryVectors must not be null.");
    }

    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    const jsize numQueries = env->GetArrayLength(queryVectors);
    std::tuple<NDArray<hnswlib::labeltype, 2>, NDArray<float, 2>> queryResults =
        index->query(toNDArray(env, queryVectors), numNeighbors, numThreads,
                     queryEf);
    auto &resultLabels = std::get<0>(queryResults);
    auto &resultDistances = std::get<1>(queryResults);

    jclass queryResultsClass =
        env->FindClass("com/spotify/voyager/jni/Index$QueryResults");
    if (!queryResultsClass) {
      throw std::runtime_error(
          "C++ bindings failed to find QueryResults class.");
    }
    jmethodID constructor =
        env->GetMethodID(queryResultsClass, "<init>", "([J[F)V");
    if (!constructor) {
      throw std::runtime_error(
          "C++ bindings failed to find QueryResults constructor.");
    }

    // Every allocation below may fail with a Java exception left pending. No
    // further JNI calls are allowed in that state, so we bail out via a C++
    // exception; the handlers below see ExceptionCheck() == true and leave the
    // original Java exception in place for the caller.
    jobjectArray javaQueryResults =
        env->NewObjectArray(numQueries, queryResultsClass, nullptr);
    if (!javaQueryResults) {
      throw std::runtime_error(
          "C++ bindings failed to allocate QueryResults array.");
    }

    for (jsize i = 0; i < numQueries; i++) {
      // Allocate a Java long[] for the labels and a float[] for the distances.
      jlongArray labels = env->NewLongArray(numNeighbors);
      if (!labels) {
        throw std::runtime_error("C++ bindings failed to allocate labels.");
      }
      env->SetLongArrayRegion(
          labels, 0, numNeighbors,
          reinterpret_cast<const jlong *>(resultLabels[i]));

      jfloatArray distances = env->NewFloatArray(numNeighbors);
      if (!distances) {
        throw std::runtime_error("C++ bindings failed to allocate distances.");
      }
      env->SetFloatArrayRegion(distances, 0, numNeighbors, resultDistances[i]);

      jobject result =
          env->NewObject(queryResultsClass, constructor, labels, distances);
      if (!result || env->ExceptionCheck()) {
        throw std::runtime_error(
            "C++ bindings failed to construct QueryResults.");
      }
      env->SetObjectArrayElement(javaQueryResults, i, result);

      // Release per-iteration local references so large batches do not
      // exhaust the JNI local reference table.
      env->DeleteLocalRef(labels);
      env->DeleteLocalRef(distances);
      env->DeleteLocalRef(result);
    }
    return javaQueryResults;
  } catch (RecallError const &e) {
    if (!env->ExceptionCheck()) {
      jclass exceptionClass =
          env->FindClass("com/spotify/voyager/jni/exception/RecallException");
      // If the lookup fails, FindClass has already raised a Java error.
      if (exceptionClass) {
        env->ThrowNew(exceptionClass, e.what());
      }
    }
    return nullptr;
  } catch (std::exception const &e) {
    if (!env->ExceptionCheck()) {
      jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
      // If the lookup fails, FindClass has already raised a Java error.
      if (exceptionClass) {
        env->ThrowNew(exceptionClass, e.what());
      }
    }
    return nullptr;
  }
}