java/com_spotify_voyager_jni_Index.cpp (782 lines of code) (raw):

/*-
 * -\-\-
 * voyager
 * --
 * Copyright (C) 2016 - 2023 Spotify AB
 * --
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * -/-/-
 */

#include "com_spotify_voyager_jni_Index.h"
#include "JavaInputStream.h"
#include "JavaOutputStream.h"
#include <Enums.h>
#include <Index.h>
#include <TypedIndex.h>
#include <cstring>
#include <exception>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <type_traits>
#include <vector>

namespace {

/**
 * RAII wrapper around JNI's MonitorEnter/MonitorExit, so that the Java
 * object's monitor is always released, even if a C++ exception is thrown
 * while it is held.
 */
class JavaMonitorGuard {
public:
  JavaMonitorGuard(JNIEnv *env, jobject obj)
      : env_(env), obj_(obj), locked_(env->MonitorEnter(obj) == JNI_OK) {}

  ~JavaMonitorGuard() {
    // Only exit a monitor that was successfully entered.
    if (locked_) {
      env_->MonitorExit(obj_);
    }
  }

  JavaMonitorGuard(const JavaMonitorGuard &) = delete;
  JavaMonitorGuard &operator=(const JavaMonitorGuard &) = delete;

private:
  JNIEnv *env_;
  jobject obj_;
  bool locked_;
};

/**
 * Raise a Java exception of the given class, carrying the message of the
 * provided C++ exception. Does nothing if a Java exception is already
 * pending, so that the original Java exception is the one that propagates.
 */
void throwJavaException(JNIEnv *env, const char *className,
                        const std::exception &e) {
  if (env->ExceptionCheck()) {
    return;
  }

  jclass exceptionClass = env->FindClass(className);
  // If the class lookup failed, FindClass returned null and there is
  // nothing valid to pass to ThrowNew.
  if (exceptionClass) {
    env->ThrowNew(exceptionClass, e.what());
  }
}

/**
 * Raise a java.lang.RuntimeException carrying the message of the provided
 * C++ exception, unless a Java exception is already pending.
 */
void throwRuntimeException(JNIEnv *env, const std::exception &e) {
  throwJavaException(env, "java/lang/RuntimeException", e);
}

} // namespace

/**
 * Given a Java object, return the field ID for the "native handle" property
 * within that object, which can be used to store a C++ pointer.
 *
 * Throws std::runtime_error if the field cannot be found.
 */
jfieldID getHandleFieldID(JNIEnv *env, jobject obj) {
  jclass c = env->GetObjectClass(obj);
  // J is the type signature for long:
  jfieldID fieldID = env->GetFieldID(c, "nativeHandle", "J");
  if (!fieldID) {
    throw std::runtime_error(
        "C++ bindings could not find the nativeHandle field on the provided "
        "object.");
  }
  return fieldID;
}

/**
 * Fetch a copy of the shared_ptr stored in the "nativeHandle" field of the
 * given Java object.
 *
 * If no handle is stored (the field is zero), this throws std::runtime_error
 * unless allow_missing is true, in which case an empty shared_ptr is
 * returned.
 */
template <typename T>
std::shared_ptr<T> getHandle(JNIEnv *env, jobject obj,
                             bool allow_missing = false) {
  jfieldID fieldID = getHandleFieldID(env, obj);

  std::shared_ptr<T> handleCopy;
  bool missing = false;
  {
    // Hold the object's monitor while reading the field *and* copying the
    // shared_ptr, so that a concurrent deleteHandle() cannot free the
    // heap-allocated shared_ptr between the read and the copy.
    JavaMonitorGuard lock(env, obj);

    // Yes, we're storing a pointer to a shared pointer on a Java object.
    // A bit strange, but totally okay to ensure that we still get shared_ptr
    // semantics while storing a single Long value.
    std::shared_ptr<T> *pointer = reinterpret_cast<std::shared_ptr<T> *>(
        env->GetLongField(obj, fieldID));

    if (pointer) {
      // Take a copy of this shared pointer, thereby ensuring that the
      // underlying object remains alive until the copy goes out of scope.
      handleCopy = *pointer;
    } else {
      missing = true;
    }
  }

  if (missing && !allow_missing) {
    throw std::runtime_error(
        "This Voyager index has been closed and can no longer be used.");
  }

  return handleCopy;
}

/**
 * Store the given object in the "nativeHandle" field of the given Java
 * object. Ownership of `t` is transferred to a newly heap-allocated
 * shared_ptr, whose address is what gets stored in the field.
 */
template <typename T> void setHandle(JNIEnv *env, jobject obj, T *t) {
  std::shared_ptr<T> *sharedPointerForJava = new std::shared_ptr<T>(t);
  jfieldID fieldID = getHandleFieldID(env, obj);

  JavaMonitorGuard lock(env, obj);
  env->SetLongField(obj, fieldID,
                    reinterpret_cast<jlong>(sharedPointerForJava));
}

/**
 * Clear the "nativeHandle" field of the given Java object and delete the
 * shared_ptr that it pointed to. Does nothing if the field is already zero.
 */
template <typename T> void deleteHandle(JNIEnv *env, jobject obj) {
  jfieldID fieldID = getHandleFieldID(env, obj);

  std::shared_ptr<T> *pointer = nullptr;
  {
    // Read and clear the field in a single critical section, so that two
    // concurrent calls can never both observe (and delete) the same pointer.
    JavaMonitorGuard lock(env, obj);
    pointer = reinterpret_cast<std::shared_ptr<T> *>(
        env->GetLongField(obj, fieldID));
    if (!pointer) {
      return;
    }
    env->SetLongField(obj, fieldID, 0);
  }

  // Note: This _may_ trigger the destructor of T, but if any other threads
  // have thread-local copies of this shared_ptr, the destructor will be
  // triggered on the last thread that has one of these shared_ptrs.
  delete pointer;
}

/**
 * Convert a Java string to a std::string containing its modified UTF-8
 * encoding.
 *
 * Throws std::runtime_error if the provided string is null.
 */
std::string toString(JNIEnv *env, jstring js) {
  if (!js) {
    throw std::runtime_error("C++ bindings received a null string.");
  }

  std::string result;
  // GetStringUTFLength gives the number of bytes needed for the output
  // buffer, while GetStringUTFRegion expects its length argument as a number
  // of Unicode (UTF-16) characters. The two differ for non-ASCII strings.
  jsize numBytes = env->GetStringUTFLength(js);
  jsize numCharacters = env->GetStringLength(js);
  result.resize(numBytes);
  if (numCharacters > 0) {
    env->GetStringUTFRegion(js, 0, numCharacters, result.data());
  }
  return result;
}

/**
 * Convert an arbitrary Java object to a std::string by calling its
 * toString() method.
 *
 * Throws std::runtime_error if the object is null, or if its class or
 * toString method cannot be resolved.
 */
std::string toString(JNIEnv *env, jobject object) {
  if (!object) {
    throw std::runtime_error("C++ bindings received a null object.");
  }

  jclass javaClass = env->GetObjectClass(object);
  if (javaClass == 0) {
    throw std::runtime_error(
        "C++ bindings were unable to get the class for the provided object.");
  }

  jmethodID toStringMethod =
      env->GetMethodID(javaClass, "toString", "()Ljava/lang/String;");
  if (!toStringMethod) {
    throw std::runtime_error("C++ bindings were unable to find the toString "
                             "method for the provided object.");
  }

  return toString(env, static_cast<jstring>(
                           env->CallObjectMethod(object, toStringMethod)));
}

/**
 * Convert a Java SpaceType enum value to its C++ equivalent, by comparing
 * the result of the Java object's toString() method against the known names.
 */
SpaceType toSpaceType(JNIEnv *env, jobject enumVal) {
  std::string enumValueName = toString(env, enumVal);

  // TODO: Replace me with a usage of MagicEnum!
  if (enumValueName == "Euclidean") {
    return SpaceType::Euclidean;
  } else if (enumValueName == "InnerProduct") {
    return SpaceType::InnerProduct;
  } else if (enumValueName == "Cosine") {
    return SpaceType::Cosine;
  } else {
    throw std::runtime_error(
        "Voyager C++ bindings received unknown enum value \"" + enumValueName +
        "\".");
  }
}

/**
 * Convert a C++ SpaceType enum value to the corresponding static field of
 * the Java enum com.spotify.voyager.jni.Index$SpaceType.
 */
jobject toSpaceType(JNIEnv *env, SpaceType enumVal) {
  jclass enumClass = env->FindClass("com/spotify/voyager/jni/Index$SpaceType");
  if (!enumClass) {
    throw std::runtime_error(
        "C++ bindings could not find SpaceType Java enum!");
  }

  const char *enumValueName = nullptr;
  switch (enumVal) {
  case SpaceType::Euclidean:
    enumValueName = "Euclidean";
    break;
  case SpaceType::InnerProduct:
    enumValueName = "InnerProduct";
    break;
  case SpaceType::Cosine:
    enumValueName = "Cosine";
    break;
  default:
    throw std::runtime_error(
        "Voyager C++ bindings received unknown enum value.");
  }

  jfieldID fieldID = env->GetStaticFieldID(
      enumClass, enumValueName, "Lcom/spotify/voyager/jni/Index$SpaceType;");
  if (!fieldID) {
    throw std::runtime_error(
        "C++ bindings could not find value in SpaceType Java enum!");
  }

  jobject javaValue = env->GetStaticObjectField(enumClass, fieldID);
  if (!javaValue) {
    throw std::runtime_error("C++ bindings could not find static object field "
                             "for in SpaceType Java enum!");
  }

  return javaValue;
}

/**
 * Convert a Java StorageDataType enum value to its C++ equivalent, by
 * comparing the result of the Java object's toString() method against the
 * known names.
 */
StorageDataType toStorageDataType(JNIEnv *env, jobject enumVal) {
  std::string enumValueName = toString(env, enumVal);

  // TODO: Replace me with a usage of MagicEnum!
  if (enumValueName == "Float8") {
    return StorageDataType::Float8;
  } else if (enumValueName == "Float32") {
    return StorageDataType::Float32;
  } else if (enumValueName == "E4M3") {
    return StorageDataType::E4M3;
  } else {
    throw std::runtime_error(
        "Voyager C++ bindings received unknown enum value \"" + enumValueName +
        "\".");
  }
}

/**
 * Convert a C++ StorageDataType enum value to the corresponding static field
 * of the Java enum com.spotify.voyager.jni.Index$StorageDataType.
 */
jobject toStorageDataType(JNIEnv *env, StorageDataType enumVal) {
  jclass enumClass =
      env->FindClass("com/spotify/voyager/jni/Index$StorageDataType");
  if (!enumClass) {
    throw std::runtime_error(
        "C++ bindings could not find StorageDataType Java enum!");
  }

  const char *enumValueName = nullptr;
  switch (enumVal) {
  case StorageDataType::Float8:
    enumValueName = "Float8";
    break;
  case StorageDataType::Float32:
    enumValueName = "Float32";
    break;
  case StorageDataType::E4M3:
    enumValueName = "E4M3";
    break;
  default:
    throw std::runtime_error(
        "Voyager C++ bindings received unknown enum value.");
  }

  jfieldID fieldID =
      env->GetStaticFieldID(enumClass, enumValueName,
                            "Lcom/spotify/voyager/jni/Index$StorageDataType;");
  if (!fieldID) {
    throw std::runtime_error(
        "C++ bindings could not find value in StorageDataType Java enum!");
  }

  jobject javaValue = env->GetStaticObjectField(enumClass, fieldID);
  if (!javaValue) {
    throw std::runtime_error("C++ bindings could not find static object field "
                             "for in StorageDataType Java enum!");
  }

  return javaValue;
}

/**
 * Convert a Java nested array (array of float arrays) to a 2D NDArray.
 *
 * Returns an empty (0, 0) NDArray if the outer array is empty or if its
 * first element is an empty array. Throws std::runtime_error if the outer
 * array or any sub-array is null, or if the sub-arrays differ in length.
 */
NDArray<float, 2> toNDArray(JNIEnv *env, jobjectArray floatArrays) {
  if (!floatArrays) {
    throw std::runtime_error("C++ bindings received a null array of arrays.");
  }

  jsize numElements = env->GetArrayLength(floatArrays);
  if (numElements == 0) {
    return NDArray<float, 2>({0, 0});
  }

  // The first sub-array determines the expected length of all the others.
  jobject firstElement = env->GetObjectArrayElement(floatArrays, 0);
  if (!firstElement) {
    throw std::runtime_error("When passing an array of arrays, no sub-array "
                             "may be null.");
  }
  jsize numDimensions = env->GetArrayLength((jfloatArray)firstElement);
  env->DeleteLocalRef(firstElement);
  if (numDimensions == 0) {
    return NDArray<float, 2>({0, 0});
  }

  NDArray<float, 2> output = NDArray<float, 2>({numElements, numDimensions});
  float *outputPointer = output.data.data();

  for (int i = 0; i < numElements; i++) {
    jobject element = env->GetObjectArrayElement(floatArrays, i);
    if (!element) {
      throw std::runtime_error("When passing an array of arrays, no "
                               "sub-array may be null.");
    }

    // TODO: Ensure that each element is actually a float array!
    jfloatArray floatArray = (jfloatArray)element;
    jsize numVectorDimensions = env->GetArrayLength(floatArray);
    if (numVectorDimensions != numDimensions) {
      throw std::runtime_error("When passing an array of arrays, all "
                               "sub-arrays must be the same length.");
    }

    env->GetFloatArrayRegion(floatArray, 0, numDimensions, outputPointer);

    // Delete the local reference to the nested float array.
    // Note that this isn't necessary; it merely helps the Java GC
    // identify if/when these elements can be cleaned up earlier than when
    // this function returns.
    // Removing this call would not create a memory leak.
    env->DeleteLocalRef(element);
    outputPointer += numDimensions;
  }

  return output;
}

/**
 * Convert a Java float array to a std::vector<float>.
 *
 * Throws std::runtime_error if the provided array is null.
 */
std::vector<float> toStdVector(JNIEnv *env, jfloatArray floatArray) {
  if (!floatArray) {
    throw std::runtime_error("C++ bindings received a null float array.");
  }

  jsize numElements = env->GetArrayLength(floatArray);
  std::vector<float> input(numElements);
  env->GetFloatArrayRegion(floatArray, 0, numElements, (float *)input.data());
  return input;
}

/**
 * Convert a std::vector<float> to a Java float array.
 *
 * Returns nullptr if the Java array could not be allocated.
 */
jfloatArray toFloatArray(JNIEnv *env, std::vector<float> floatArray) {
  jfloatArray returnArray = env->NewFloatArray(floatArray.size());
  if (!returnArray) {
    // Allocation failed, so there is no array to copy the values into.
    return nullptr;
  }

  env->SetFloatArrayRegion(returnArray, 0, floatArray.size(),
                           floatArray.data());
  return returnArray;
}

/**
 * Convert a Java long array to a std::vector<size_t>.
 * Note that this function will underflow if any elements are negative.
 *
 * Throws std::runtime_error if the provided array is null.
 */
std::vector<size_t> toUnsignedStdVector(JNIEnv *env, jlongArray longArray) {
  // The vector's storage is filled in-place as if it held jlong values.
  static_assert(sizeof(size_t) == sizeof(jlong),
                "toUnsignedStdVector expects size_t to be a 64-bit integer.");

  if (!longArray) {
    throw std::runtime_error("C++ bindings received a null long array.");
  }

  jsize numElements = env->GetArrayLength(longArray);
  std::vector<size_t> input(numElements);
  env->GetLongArrayRegion(longArray, 0, numElements, (jlong *)input.data());
  return input;
}

namespace {

/**
 * Shared implementation of the "load with parameters" entry points.
 *
 * Reads metadata from the provided stream. If metadata is present, the
 * provided parameters are validated against it (throwing std::domain_error
 * on any mismatch) and the index is loaded from the metadata. Otherwise, the
 * index is loaded using the provided parameters. The resulting index is
 * stored as the native handle of `self`.
 */
template <typename StreamPointer>
void loadIndexWithParameters(JNIEnv *env, jobject self,
                             StreamPointer inputStream, jobject spaceType,
                             jint numDimensions, jobject storageDataType) {
  std::unique_ptr<voyager::Metadata::V1> metadata =
      voyager::Metadata::loadFromStream(inputStream);

  const StorageDataType requestedStorageDataType =
      toStorageDataType(env, storageDataType);

  if (metadata) {
    if (metadata->getStorageDataType() != requestedStorageDataType) {
      throw std::domain_error(
          "Provided storage data type (" + toString(requestedStorageDataType) +
          ") does not match the data type used in this file (" +
          toString(metadata->getStorageDataType()) + ").");
    }

    const SpaceType requestedSpaceType = toSpaceType(env, spaceType);
    if (metadata->getSpaceType() != requestedSpaceType) {
      throw std::domain_error(
          "Provided space type (" + toString(requestedSpaceType) +
          ") does not match the space type used in this file (" +
          toString(metadata->getSpaceType()) + ").");
    }

    if (metadata->getNumDimensions() != numDimensions) {
      throw std::domain_error(
          "Provided number of dimensions (" + std::to_string(numDimensions) +
          ") does not match the number of dimensions used in this file (" +
          std::to_string(metadata->getNumDimensions()) + ").");
    }

    setHandle<Index>(
        env, self,
        loadTypedIndexFromMetadata(std::move(metadata), inputStream)
            .release());
    return;
  }

  // No metadata was found in the stream; fall back to the parameters that
  // were provided by the caller.
  switch (requestedStorageDataType) {
  case StorageDataType::Float32:
    setHandle<Index>(env, self,
                     new TypedIndex<float>(inputStream,
                                           toSpaceType(env, spaceType),
                                           numDimensions));
    break;
  case StorageDataType::Float8:
    setHandle<Index>(
        env, self,
        new TypedIndex<float, int8_t, std::ratio<1, 127>>(
            inputStream, toSpaceType(env, spaceType), numDimensions));
    break;
  case StorageDataType::E4M3:
    setHandle<Index>(env, self,
                     new TypedIndex<float, E4M3>(inputStream,
                                                 toSpaceType(env, spaceType),
                                                 numDimensions));
    break;
  }
}

/**
 * Shared implementation of the parameterless "load" entry points.
 *
 * Reads metadata from the provided stream and loads the index from it,
 * storing the result as the native handle of `self`. Throws
 * std::domain_error if the stream contains no metadata.
 */
template <typename StreamPointer>
void loadIndexFromMetadataOnly(JNIEnv *env, jobject self,
                               StreamPointer inputStream) {
  std::unique_ptr<voyager::Metadata::V1> metadata =
      voyager::Metadata::loadFromStream(inputStream);

  if (!metadata) {
    throw std::domain_error(
        "Provided index file has no metadata and no index parameters were "
        "specified. Must either provide an index with metadata or specify "
        "storageDataType, spaceType, and numDimensions.");
  }

  setHandle<Index>(
      env, self,
      loadTypedIndexFromMetadata(std::move(metadata), inputStream).release());
}

} // namespace

////////////////////////////////////////////////////////////////////////////////////////////////////
// Index Construction and Indexing
////////////////////////////////////////////////////////////////////////////////////////////////////

/**
 * Construct a new native index with the given parameters and store it as the
 * native handle of `self`. Any C++ exception is rethrown to Java as a
 * RuntimeException.
 */
void Java_com_spotify_voyager_jni_Index_nativeConstructor(
    JNIEnv *env, jobject self, jobject spaceType, jint numDimensions, jlong M,
    jlong efConstruction, jlong randomSeed, jlong maxElements,
    jobject storageDataType) {
  try {
    switch (toStorageDataType(env, storageDataType)) {
    case StorageDataType::Float32:
      setHandle<Index>(env, self,
                       new TypedIndex<float>(toSpaceType(env, spaceType),
                                             numDimensions, M, efConstruction,
                                             randomSeed, maxElements));
      break;
    case StorageDataType::Float8:
      setHandle<Index>(env, self,
                       new TypedIndex<float, int8_t, std::ratio<1, 127>>(
                           toSpaceType(env, spaceType), numDimensions, M,
                           efConstruction, randomSeed, maxElements));
      break;
    case StorageDataType::E4M3:
      setHandle<Index>(env, self,
                       new TypedIndex<float, E4M3>(
                           toSpaceType(env, spaceType), numDimensions, M,
                           efConstruction, randomSeed, maxElements));
      break;
    }
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

/**
 * Add a single vector to the index without specifying an ID.
 * Returns the value returned by the native addItem call, or -1 if an
 * exception was raised.
 */
jlong Java_com_spotify_voyager_jni_Index_addItem___3F(JNIEnv *env,
                                                      jobject self,
                                                      jfloatArray vector) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    return index->addItem(toStdVector(env, vector), {});
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return -1;
}

/**
 * Add a single vector to the index with the provided ID.
 * Returns the value returned by the native addItem call, or -1 if an
 * exception was raised.
 */
jlong Java_com_spotify_voyager_jni_Index_addItem___3FJ(JNIEnv *env,
                                                       jobject self,
                                                       jfloatArray vector,
                                                       jlong id) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    return index->addItem(toStdVector(env, vector), {id});
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return -1;
}

/**
 * Add multiple vectors to the index without specifying IDs.
 * Returns a Java long array of the IDs returned by the native addItems call,
 * or null if an exception was raised.
 */
jlongArray Java_com_spotify_voyager_jni_Index_addItems___3_3FI(
    JNIEnv *env, jobject self, jobjectArray vectors, jint numThreads) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    std::vector<hnswlib::labeltype> nativeIds =
        index->addItems(toNDArray(env, vectors), {}, numThreads);

    // Allocate a Java long array for the IDs:
    static_assert(
        sizeof(hnswlib::labeltype) == sizeof(jlong),
        "addItems expects hnswlib::labeltype to be a 64-bit integer.");
    jlongArray javaIds = env->NewLongArray(nativeIds.size());
    env->SetLongArrayRegion(javaIds, 0, nativeIds.size(),
                            (jlong *)nativeIds.data());
    return javaIds;
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return nullptr;
}

/**
 * Add multiple vectors to the index with the provided IDs.
 * Returns a Java long array of the IDs returned by the native addItems call,
 * or null if an exception was raised.
 */
jlongArray Java_com_spotify_voyager_jni_Index_addItems___3_3F_3JI(
    JNIEnv *env, jobject self, jobjectArray vectors, jlongArray ids,
    jint numThreads) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    std::vector<hnswlib::labeltype> nativeIds = index->addItems(
        toNDArray(env, vectors), toUnsignedStdVector(env, ids), numThreads);

    // Allocate a Java long array for the IDs:
    static_assert(
        sizeof(hnswlib::labeltype) == sizeof(jlong),
        "addItems expects hnswlib::labeltype to be a 64-bit integer.");
    jlongArray javaIds = env->NewLongArray(nativeIds.size());
    env->SetLongArrayRegion(javaIds, 0, nativeIds.size(),
                            (jlong *)nativeIds.data());
    return javaIds;
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return nullptr;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// Querying
////////////////////////////////////////////////////////////////////////////////////////////////////

/**
 * Query the index with a single vector, returning a Java QueryResults object
 * built from the labels and distances returned by the native query.
 *
 * A RecallError is rethrown to Java as a RecallException; any other C++
 * exception is rethrown as a RuntimeException. Returns null in both cases.
 */
jobject Java_com_spotify_voyager_jni_Index_query___3FIJ(JNIEnv *env,
                                                        jobject self,
                                                        jfloatArray queryVector,
                                                        jint numNeighbors,
                                                        jlong queryEf) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);

    std::tuple<std::vector<hnswlib::labeltype>, std::vector<float>>
        queryResults =
            index->query(toStdVector(env, queryVector), numNeighbors, queryEf);

    jclass queryResultsClass =
        env->FindClass("com/spotify/voyager/jni/Index$QueryResults");
    if (!queryResultsClass) {
      throw std::runtime_error(
          "C++ bindings failed to find QueryResults class.");
    }

    jmethodID constructor =
        env->GetMethodID(queryResultsClass, "<init>", "([J[F)V");
    if (!constructor) {
      throw std::runtime_error(
          "C++ bindings failed to find QueryResults constructor.");
    }

    static_assert(sizeof(hnswlib::labeltype) == sizeof(jlong),
                  "query expects hnswlib::labeltype to be a 64-bit integer.");

    // Allocate a Java long array for the IDs:
    jlongArray labels = env->NewLongArray(numNeighbors);

    // queryResults is a (size_t *), but labels is a signed (long *).
    // This may overflow if we have more than... 2^63 = 9.223372037e18
    // elements. We're probably safe doing this.
    env->SetLongArrayRegion(labels, 0, numNeighbors,
                            (jlong *)std::get<0>(queryResults).data());

    jfloatArray distances = env->NewFloatArray(numNeighbors);
    env->SetFloatArrayRegion(distances, 0, numNeighbors,
                             std::get<1>(queryResults).data());

    return env->NewObject(queryResultsClass, constructor, labels, distances);
  } catch (RecallError const &e) {
    throwJavaException(
        env, "com/spotify/voyager/jni/exception/RecallException", e);
    return nullptr;
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
    return nullptr;
  }
}

/**
 * Query the index with multiple vectors, returning a Java array containing
 * one QueryResults object per query vector.
 *
 * A RecallError is rethrown to Java as a RecallException; any other C++
 * exception is rethrown as a RuntimeException. Returns null in both cases.
 */
jobjectArray Java_com_spotify_voyager_jni_Index_query___3_3FIIJ(
    JNIEnv *env, jobject self, jobjectArray queryVectors, jint numNeighbors,
    jint numThreads, jlong queryEf) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);

    int numQueries = env->GetArrayLength(queryVectors);

    std::tuple<NDArray<hnswlib::labeltype, 2>, NDArray<float, 2>>
        queryResults = index->query(toNDArray(env, queryVectors), numNeighbors,
                                    numThreads, queryEf);

    jclass queryResultsClass =
        env->FindClass("com/spotify/voyager/jni/Index$QueryResults");
    if (!queryResultsClass) {
      throw std::runtime_error(
          "C++ bindings failed to find QueryResults class.");
    }

    jmethodID constructor =
        env->GetMethodID(queryResultsClass, "<init>", "([J[F)V");
    if (!constructor) {
      throw std::runtime_error(
          "C++ bindings failed to find QueryResults constructor.");
    }

    static_assert(sizeof(hnswlib::labeltype) == sizeof(jlong),
                  "query expects hnswlib::labeltype to be a 64-bit integer.");

    jobjectArray javaQueryResults =
        env->NewObjectArray(numQueries, queryResultsClass, NULL);

    for (int i = 0; i < numQueries; i++) {
      // Allocate a Java long array for the indices, and a float array for the
      // distances:
      jlongArray labels = env->NewLongArray(numNeighbors);

      // queryResults is a (size_t *), but labels is a signed (long *).
      // This may overflow if we have more than... 2^63 = 9.223372037e18
      // elements. We're probably safe doing this.
      env->SetLongArrayRegion(labels, 0, numNeighbors,
                              (jlong *)std::get<0>(queryResults)[i]);

      jfloatArray distances = env->NewFloatArray(numNeighbors);
      env->SetFloatArrayRegion(distances, 0, numNeighbors,
                               std::get<1>(queryResults)[i]);

      // Named distinctly from the native `queryResults` tuple to avoid
      // shadowing it within this loop body.
      jobject javaQueryResult =
          env->NewObject(queryResultsClass, constructor, labels, distances);
      env->SetObjectArrayElement(javaQueryResults, i, javaQueryResult);

      // Release the per-iteration local references now that the result
      // array holds the object.
      env->DeleteLocalRef(labels);
      env->DeleteLocalRef(distances);
      env->DeleteLocalRef(javaQueryResult);
    }

    return javaQueryResults;
  } catch (RecallError const &e) {
    throwJavaException(
        env, "com/spotify/voyager/jni/exception/RecallException", e);
    return nullptr;
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
    return nullptr;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// Property Methods
////////////////////////////////////////////////////////////////////////////////////////////////////

/** Return the index's space type as a Java enum value, or null on error. */
jobject Java_com_spotify_voyager_jni_Index_getSpace(JNIEnv *env,
                                                    jobject self) {
  try {
    return toSpaceType(env, getHandle<Index>(env, self)->getSpace());
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return nullptr;
}

/** Return the index's number of dimensions, or 0 on error. */
jint Java_com_spotify_voyager_jni_Index_getNumDimensions(JNIEnv *env,
                                                         jobject self) {
  try {
    return getHandle<Index>(env, self)->getNumDimensions();
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return 0;
}

/** Return the index's M parameter, or 0 on error. */
jlong Java_com_spotify_voyager_jni_Index_getM(JNIEnv *env, jobject self) {
  try {
    return getHandle<Index>(env, self)->getM();
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return 0;
}

/** Return the index's efConstruction parameter, or 0 on error. */
jlong Java_com_spotify_voyager_jni_Index_getEfConstruction(JNIEnv *env,
                                                           jobject self) {
  try {
    return getHandle<Index>(env, self)->getEfConstruction();
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return 0;
}

/** Return the index's maximum number of elements, or 0 on error. */
jlong Java_com_spotify_voyager_jni_Index_getMaxElements(JNIEnv *env,
                                                        jobject self) {
  try {
    return getHandle<Index>(env, self)->getMaxElements();
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return 0;
}

/**
 * Return the index's storage data type as a Java enum value, or null on
 * error.
 */
jobject Java_com_spotify_voyager_jni_Index_getStorageDataType(JNIEnv *env,
                                                              jobject self) {
  try {
    return toStorageDataType(env,
                             getHandle<Index>(env, self)->getStorageDataType());
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return nullptr;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// Index Accessor Methods
////////////////////////////////////////////////////////////////////////////////////////////////////

/** Return the number of elements in the index, or 0 on error. */
jlong Java_com_spotify_voyager_jni_Index_getNumElements(JNIEnv *env,
                                                        jobject self) {
  try {
    return getHandle<Index>(env, self)->getNumElements();
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return 0;
}

/**
 * Return the vector stored under the given ID as a Java float array, or null
 * on error.
 */
jfloatArray Java_com_spotify_voyager_jni_Index_getVector(JNIEnv *env,
                                                         jobject self,
                                                         jlong id) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    return toFloatArray(env, index->getVector(id));
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
    return nullptr;
  }
}

/**
 * Return the vectors stored under the given IDs as a Java array of float
 * arrays, or null on error.
 */
jobjectArray Java_com_spotify_voyager_jni_Index_getVectors(JNIEnv *env,
                                                           jobject self,
                                                           jlongArray ids) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    NDArray<float, 2> vectors =
        index->getVectors(toUnsignedStdVector(env, ids));

    jclass floatArrayClass = env->FindClass("[F");
    if (!floatArrayClass) {
      throw std::runtime_error("C++ bindings failed to find float[] class.");
    }

    jobjectArray javaVectors =
        env->NewObjectArray(vectors.shape[0], floatArrayClass, NULL);

    for (int i = 0; i < vectors.shape[0]; i++) {
      jfloatArray vector = env->NewFloatArray(vectors.shape[1]);
      env->SetFloatArrayRegion(vector, 0, vectors.shape[1], vectors[i]);
      env->SetObjectArrayElement(javaVectors, i, vector);
      env->DeleteLocalRef(vector);
    }

    return javaVectors;
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
    return nullptr;
  }
}

/**
 * Return the IDs returned by the native getIDs call as a Java long array, or
 * null on error.
 */
jlongArray Java_com_spotify_voyager_jni_Index_getIDs(JNIEnv *env,
                                                     jobject self) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    std::vector<hnswlib::labeltype> ids = index->getIDs();

    static_assert(sizeof(hnswlib::labeltype) == sizeof(jlong),
                  "getIDs expects hnswlib::labeltype to be a 64-bit integer.");

    // Allocate a Java long array for the IDs:
    jlongArray javaIds = env->NewLongArray(ids.size());
    env->SetLongArrayRegion(javaIds, 0, ids.size(), (jlong *)ids.data());
    return javaIds;
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
    return nullptr;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// Index Modifier Methods
////////////////////////////////////////////////////////////////////////////////////////////////////

/** Set the index's EF value. */
void Java_com_spotify_voyager_jni_Index_setEf(JNIEnv *env, jobject self,
                                              jlong newEf) {
  try {
    getHandle<Index>(env, self)->setEF(newEf);
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

/** Return the index's EF value, or 0 on error. */
jint Java_com_spotify_voyager_jni_Index_getEf(JNIEnv *env, jobject self) {
  try {
    return getHandle<Index>(env, self)->getEF();
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
  return 0;
}

/** Mark the element with the given label as deleted. */
void Java_com_spotify_voyager_jni_Index_markDeleted(JNIEnv *env, jobject self,
                                                    jlong label) {
  try {
    getHandle<Index>(env, self)->markDeleted(label);
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

/** Un-mark the element with the given label as deleted. */
void Java_com_spotify_voyager_jni_Index_unmarkDeleted(JNIEnv *env,
                                                      jobject self,
                                                      jlong label) {
  try {
    getHandle<Index>(env, self)->unmarkDeleted(label);
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

/** Resize the index to the given size. */
void Java_com_spotify_voyager_jni_Index_resizeIndex(JNIEnv *env, jobject self,
                                                    jlong newSize) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    index->resizeIndex(newSize);
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// Save Index
////////////////////////////////////////////////////////////////////////////////////////////////////

/** Save the index to the file with the given name. */
void Java_com_spotify_voyager_jni_Index_saveIndex__Ljava_lang_String_2(
    JNIEnv *env, jobject self, jstring filename) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    index->saveIndex(toString(env, filename));
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

/** Save the index to the given Java OutputStream. */
void Java_com_spotify_voyager_jni_Index_saveIndex__Ljava_io_OutputStream_2(
    JNIEnv *env, jobject self, jobject outputStream) {
  try {
    std::shared_ptr<Index> index = getHandle<Index>(env, self);
    index->saveIndex(std::make_shared<JavaOutputStream>(env, outputStream));
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// Load Index
////////////////////////////////////////////////////////////////////////////////////////////////////

// TODO: Convert these to static methods

/**
 * Load an index from the file with the given name, validating the provided
 * parameters against the file's metadata if it has any, or using them to
 * load the index if it does not.
 */
void Java_com_spotify_voyager_jni_Index_nativeLoadFromFileWithParameters(
    JNIEnv *env, jobject self, jstring filename, jobject spaceType,
    jint numDimensions, jobject storageDataType) {
  try {
    loadIndexWithParameters(
        env, self, std::make_shared<FileInputStream>(toString(env, filename)),
        spaceType, numDimensions, storageDataType);
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

/**
 * Load an index from the given Java InputStream, validating the provided
 * parameters against the stream's metadata if it has any, or using them to
 * load the index if it does not.
 */
void Java_com_spotify_voyager_jni_Index_nativeLoadFromInputStreamWithParameters(
    JNIEnv *env, jobject self, jobject jInputStream, jobject spaceType,
    jint numDimensions, jobject storageDataType) {
  try {
    loadIndexWithParameters(
        env, self, std::make_shared<JavaInputStream>(env, jInputStream),
        spaceType, numDimensions, storageDataType);
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

/**
 * Load an index from the file with the given name, using only the metadata
 * stored in the file. Raises an exception if the file has no metadata.
 */
void Java_com_spotify_voyager_jni_Index_nativeLoadFromFile(JNIEnv *env,
                                                           jobject self,
                                                           jstring filename) {
  try {
    loadIndexFromMetadataOnly(
        env, self, std::make_shared<FileInputStream>(toString(env, filename)));
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

/**
 * Load an index from the given Java InputStream, using only the metadata
 * stored in the stream. Raises an exception if the stream has no metadata.
 */
void Java_com_spotify_voyager_jni_Index_nativeLoadFromInputStream(
    JNIEnv *env, jobject self, jobject jInputStream) {
  try {
    loadIndexFromMetadataOnly(
        env, self, std::make_shared<JavaInputStream>(env, jInputStream));
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}

/**
 * Release the native handle stored on `self`, if any.
 */
void Java_com_spotify_voyager_jni_Index_nativeDestructor(JNIEnv *env,
                                                         jobject self) {
  try {
    deleteHandle<Index>(env, self);
  } catch (std::exception const &e) {
    throwRuntimeException(env, e);
  }
}