jobjectArray Java_com_spotify_voyager_jni_Index_query___3_3FIIJ()

in java/com_spotify_voyager_jni_Index.cpp [480:547]


/**
 * JNI implementation of the batched Index.query(float[][], int, int, long)
 * native method.
 *
 * Runs Index::query for every row of `queryVectors` against the native index
 * owned by `self`, and wraps each row's labels and distances in a
 * com.spotify.voyager.jni.Index$QueryResults object (constructor "([J[F)V").
 *
 * @param env          JNI environment for the calling thread.
 * @param self         The Java Index object whose native handle is queried.
 * @param queryVectors Java float[][]; each element is one query vector.
 * @param numNeighbors Number of neighbours requested per query vector; also the
 *                     length of each returned labels/distances array.
 * @param numThreads   Passed through to Index::query.
 * @param queryEf      Passed through to Index::query.
 * @return A QueryResults[] with one entry per query vector, or nullptr when a
 *         Java exception is pending (the JVM raises it when this returns):
 *         - NullPointerException if queryVectors is null;
 *         - RecallException for a C++ RecallError;
 *         - RuntimeException for any other C++ exception;
 *         - whatever JNI raised (e.g. OutOfMemoryError) if a JNI allocation or
 *           the QueryResults constructor failed.
 */
jobjectArray Java_com_spotify_voyager_jni_Index_query___3_3FIIJ(
    JNIEnv *env, jobject self, jobjectArray queryVectors, jint numNeighbors,
    jint numThreads, jlong queryEf) {
  // Raise a Java exception of the given class unless one is already pending.
  // FindClass can itself fail (leaving NoClassDefFoundError pending); ThrowNew
  // must never be called with a null class.
  auto throwJavaException = [env](const char *className, const char *message) {
    if (env->ExceptionCheck()) {
      return;
    }
    jclass exceptionClass = env->FindClass(className);
    if (exceptionClass) {
      env->ThrowNew(exceptionClass, message);
    }
  };

  try {
    // GetArrayLength on a null array is undefined behaviour (typically a JVM
    // crash), so reject null up front with a proper Java exception.
    if (!queryVectors) {
      throwJavaException("java/lang/NullPointerException",
                         "queryVectors must not be null.");
      return nullptr;
    }

    std::shared_ptr<Index> index = getHandle<Index>(env, self);

    const jsize numQueries = env->GetArrayLength(queryVectors);

    std::tuple<NDArray<hnswlib::labeltype, 2>, NDArray<float, 2>> queryResults =
        index->query(toNDArray(env, queryVectors), numNeighbors, numThreads,
                     queryEf);
    auto &labelMatrix = std::get<0>(queryResults);
    auto &distanceMatrix = std::get<1>(queryResults);

    jclass queryResultsClass =
        env->FindClass("com/spotify/voyager/jni/Index$QueryResults");
    if (!queryResultsClass) {
      throw std::runtime_error(
          "C++ bindings failed to find QueryResults class.");
    }

    jmethodID constructor =
        env->GetMethodID(queryResultsClass, "<init>", "([J[F)V");
    if (!constructor) {
      throw std::runtime_error(
          "C++ bindings failed to find QueryResults constructor.");
    }

    jobjectArray javaQueryResults =
        env->NewObjectArray(numQueries, queryResultsClass, nullptr);
    if (!javaQueryResults) {
      return nullptr; // JNI has already raised OutOfMemoryError.
    }

    // Labels are unsigned hnswlib::labeltype values, while Java longs are
    // signed 64-bit. Convert element-wise instead of reinterpreting the row
    // pointer, so the copy stays correct even if sizeof(labeltype) differs
    // from sizeof(jlong). Labels above 2^63 - 1 would wrap to negative values.
    // One scratch row is reused for every query.
    // NOTE(review): assumes each result row holds numNeighbors entries.
    const jint rowLength = numNeighbors > 0 ? numNeighbors : 0;
    std::unique_ptr<jlong[]> labelRow = std::make_unique<jlong[]>(rowLength);

    for (jsize i = 0; i < numQueries; i++) {
      auto labelsForQuery = labelMatrix[i];
      for (jint j = 0; j < rowLength; j++) {
        labelRow[j] = static_cast<jlong>(labelsForQuery[j]);
      }

      // After any failed JNI allocation an exception is pending and further
      // JNI calls (other than cleanup) are illegal, so bail out immediately.
      // Remaining local refs are released when this native frame returns.
      jlongArray labels = env->NewLongArray(numNeighbors);
      if (!labels) {
        return nullptr;
      }
      env->SetLongArrayRegion(labels, 0, numNeighbors, labelRow.get());

      jfloatArray distances = env->NewFloatArray(numNeighbors);
      if (!distances) {
        return nullptr;
      }
      env->SetFloatArrayRegion(distances, 0, numNeighbors, distanceMatrix[i]);

      jobject resultForQuery =
          env->NewObject(queryResultsClass, constructor, labels, distances);
      // Release per-iteration local refs promptly so large batches do not
      // exhaust the local reference table. DeleteLocalRef is safe to call
      // with an exception pending.
      env->DeleteLocalRef(labels);
      env->DeleteLocalRef(distances);
      if (!resultForQuery) {
        return nullptr; // Allocation or the constructor threw.
      }

      env->SetObjectArrayElement(javaQueryResults, i, resultForQuery);
      env->DeleteLocalRef(resultForQuery);
    }

    return javaQueryResults;
  } catch (RecallError const &e) {
    throwJavaException("com/spotify/voyager/jni/exception/RecallException",
                       e.what());
    return nullptr;
  } catch (std::exception const &e) {
    throwJavaException("java/lang/RuntimeException", e.what());
    return nullptr;
  } catch (...) {
    // Never let a C++ exception unwind across the JNI boundary into the JVM.
    throwJavaException("java/lang/RuntimeException",
                       "Unknown C++ exception in Index.query.");
    return nullptr;
  }
}