jobject Java_com_spotify_voyager_jni_Index_query___3FIJ()

in java/com_spotify_voyager_jni_Index.cpp [424:478]


/**
 * Native implementation backing a Java `Index` query taking
 * (float[], int, long).
 *
 * Runs `Index::query` on the native index bound to `self` and wraps the
 * returned labels and distances in a new
 * `com.spotify.voyager.jni.Index$QueryResults` object (constructor `([J[F)V`).
 *
 * @param env           JNI environment for the calling thread.
 * @param self          The Java object whose native handle is resolved via
 *                      getHandle<Index>.
 * @param queryVector   Query vector, converted with toStdVector.
 * @param numNeighbors  Forwarded to Index::query as the neighbour count.
 * @param queryEf       Forwarded to Index::query.
 * @return A new QueryResults object, or nullptr when a Java exception is
 *         pending (RecallException for RecallError, RuntimeException for any
 *         other std::exception, or whatever a failing JNI call raised).
 */
jobject Java_com_spotify_voyager_jni_Index_query___3FIJ(JNIEnv *env,
                                                        jobject self,
                                                        jfloatArray queryVector,
                                                        jint numNeighbors,
                                                        jlong queryEf) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);

    std::tuple<std::vector<hnswlib::labeltype>, std::vector<float>>
        queryResults =
            index->query(toStdVector(env, queryVector), numNeighbors, queryEf);
    const std::vector<hnswlib::labeltype> &resultLabels =
        std::get<0>(queryResults);
    const std::vector<float> &resultDistances = std::get<1>(queryResults);

    jclass queryResultsClass =
        env->FindClass("com/spotify/voyager/jni/Index$QueryResults");
    if (!queryResultsClass) {
      throw std::runtime_error(
          "C++ bindings failed to find QueryResults class.");
    }

    jmethodID constructor =
        env->GetMethodID(queryResultsClass, "<init>", "([J[F)V");

    if (!constructor) {
      throw std::runtime_error(
          "C++ bindings failed to find QueryResults constructor.");
    }

    // Size the Java arrays from what the query actually returned rather than
    // from numNeighbors, so a shorter result can never cause an out-of-bounds
    // read of the native vectors.
    const jint labelCount = static_cast<jint>(resultLabels.size());
    const jint distanceCount = static_cast<jint>(resultDistances.size());

    // Labels are unsigned (size_t) on the C++ side, but Java only has signed
    // longs. Convert element-by-element instead of reinterpreting the buffer,
    // which would silently assume labeltype and jlong have the same size and
    // representation. Labels above 2^63 - 1 are not representable as jlong;
    // that is not expected in practice.
    std::vector<jlong> javaLabels;
    javaLabels.reserve(resultLabels.size());
    for (hnswlib::labeltype label : resultLabels) {
      javaLabels.push_back(static_cast<jlong>(label));
    }

    jlongArray labels = env->NewLongArray(labelCount);
    if (!labels) {
      // Allocation failed; the JVM already has an exception pending.
      return nullptr;
    }
    env->SetLongArrayRegion(labels, 0, labelCount, javaLabels.data());

    jfloatArray distances = env->NewFloatArray(distanceCount);
    if (!distances) {
      // Allocation failed; the JVM already has an exception pending.
      return nullptr;
    }
    env->SetFloatArrayRegion(distances, 0, distanceCount,
                             resultDistances.data());

    // Most JNI functions, NewObject included, must not be called while an
    // exception is pending; surface any error from the region copies instead.
    if (env->ExceptionCheck()) {
      return nullptr;
    }

    return env->NewObject(queryResultsClass, constructor, labels, distances);
  } catch (RecallError const &e) {
    if (!env->ExceptionCheck()) {
      jclass recallExceptionClass =
          env->FindClass("com/spotify/voyager/jni/exception/RecallException");
      // A failed lookup returns null and leaves its own error pending;
      // never hand a null class to ThrowNew.
      if (recallExceptionClass) {
        env->ThrowNew(recallExceptionClass, e.what());
      }
    }
    return nullptr;
  } catch (std::exception const &e) {
    if (!env->ExceptionCheck()) {
      jclass runtimeExceptionClass =
          env->FindClass("java/lang/RuntimeException");
      // A failed lookup returns null and leaves its own error pending;
      // never hand a null class to ThrowNew.
      if (runtimeExceptionClass) {
        env->ThrowNew(runtimeExceptionClass, e.what());
      }
    }
    return nullptr;
  }
}