java/JavaInputStream.h (102 lines of code) (raw):

/*-
 * -\-\-
 * voyager
 * --
 * Copyright (C) 2016 - 2023 Spotify AB
 * --
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * -/-/-
 */

#pragma once

#include <StreamUtils.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <jni.h>
#include <stdexcept>
#include <string>
#include <vector>

/**
 * An InputStream that pulls bytes from a Java `java.io.InputStream` object
 * through JNI.
 *
 * The stream is not seekable and reports an unknown total length (-1).
 * Bytes examined with peek() are buffered on the C++ side and handed back
 * by the next read().
 *
 * NOTE(review): the JNIEnv* and jobject are stored as-is. JNIEnv pointers
 * are thread-local and `inputStream` may be a JNI local reference, so this
 * object presumably must not outlive (or leave the thread of) the native
 * call that created it — confirm at the call sites.
 */
class JavaInputStream : public InputStream {
public:
  // This input stream stores a temporary buffer to copy between Java and C++;
  // if we don't set a maximum buffer size here, the C++ side could read
  // hundreds of GB at once, which would allocate 2x that amount.
  static constexpr long long MAX_BUFFER_SIZE = 1024 * 1024 * 100;

  /**
   * @param env JNI environment used for every call into Java.
   * @param inputStream Java object that must be a java.io.InputStream.
   * @throws std::runtime_error if java.io.InputStream or its
   *         read(byte[], int, int) method cannot be resolved, or if
   *         `inputStream` is not an InputStream.
   */
  JavaInputStream(JNIEnv *env, jobject inputStream)
      : env(env), inputStream(inputStream) {
    jclass inputStreamClass = env->FindClass("java/io/InputStream");
    if (!inputStreamClass) {
      throw std::runtime_error("Native code failed to find InputStream class!");
    }

    const bool isInputStream =
        env->IsInstanceOf(inputStream, inputStreamClass);

    // A jmethodID stays valid until its declaring class is unloaded, so it
    // is resolved once here instead of calling FindClass/GetMethodID (and
    // leaking a class local reference) on every read().
    readMethod = isInputStream
                     ? env->GetMethodID(inputStreamClass, "read", "([BII)I")
                     : nullptr;
    env->DeleteLocalRef(inputStreamClass);

    if (!isInputStream) {
      throw std::runtime_error(
          "Provided Java object is not a java.io.InputStream!");
    }
    if (!readMethod) {
      throw std::runtime_error("Native code failed to find "
                               "java.io.InputStream#read(byte[]) method!");
    }
  }

  virtual bool isSeekable() { return false; }

  virtual long long getTotalLength() { return -1; }

  /**
   * Copies up to `bytesToRead` bytes into `buffer`.
   *
   * Bytes previously buffered by peek() are served first. The remainder is
   * fetched with java.io.InputStream#read(byte[], int, int) through a
   * temporary Java array of at most MAX_BUFFER_SIZE bytes.
   *
   * @return the number of bytes copied into `buffer`. This is fewer than
   *         `bytesToRead` only if the Java stream reached end-of-stream.
   *         Returns the value from Java's read() (e.g. -1) if end-of-stream
   *         was reached before a single byte was copied, 0 if `bytesToRead`
   *         is not positive, and 0 if a Java exception is pending (the
   *         exception is left pending for the JNI caller).
   * @throws std::domain_error if the temporary Java array cannot be
   *         allocated, or if Java's read() reports more bytes than requested.
   */
  virtual long long read(char *buffer, long long bytesToRead) {
    if (bytesToRead <= 0) {
      return 0;
    }

    long long bytesCopied = 0;

    // Serve bytes buffered by peek() first, in their original order.
    if (!peekValue.empty()) {
      const long long fromPeek =
          std::min(bytesToRead, static_cast<long long>(peekValue.size()));
      std::memcpy(buffer, peekValue.data(), fromPeek);
      peekValue.erase(peekValue.begin(), peekValue.begin() + fromPeek);
      bytesCopied += fromPeek;
      buffer += fromPeek;
    }

    if (bytesCopied == bytesToRead) {
      // Fully satisfied from the peek buffer; no Java call needed.
      return bytesCopied;
    }

    const long long bufferSize =
        std::min(MAX_BUFFER_SIZE, bytesToRead - bytesCopied);
    jbyteArray byteArray = env->NewByteArray(static_cast<jsize>(bufferSize));
    if (!byteArray) {
      throw std::domain_error(
          "Failed to instantiate Java byte array of size: " +
          std::to_string(bufferSize));
    }

    // Releases the temporary array's local reference on every exit path
    // (normal return, pending Java exception, or thrown C++ exception).
    // DeleteLocalRef is safe to call while a Java exception is pending.
    struct LocalRefGuard {
      JNIEnv *env;
      jobject ref;
      ~LocalRefGuard() { env->DeleteLocalRef(ref); }
    } byteArrayGuard{env, byteArray};

    while (bytesCopied < bytesToRead) {
      const long long requested =
          std::min(bufferSize, bytesToRead - bytesCopied);
      const jint readResult =
          env->CallIntMethod(inputStream, readMethod, byteArray, 0,
                             static_cast<jint>(requested));
      if (env->ExceptionCheck()) {
        return 0;
      }

      if (readResult <= 0) {
        // End of stream (Java returns -1). Report the bytes already copied
        // so they are not lost; only surface Java's result when nothing was
        // copied at all.
        if (bytesCopied == 0) {
          bytesCopied = readResult;
        }
        break;
      }

      if (bytesCopied + readResult > bytesToRead) {
        throw std::domain_error("java.io.InputStream#read(byte[]) returned " +
                                std::to_string(readResult) + ", but only " +
                                std::to_string(bytesToRead - bytesCopied) +
                                " bytes were required.");
      }

      if (readResult > bufferSize) {
        throw std::domain_error("java.io.InputStream#read(byte[]) returned " +
                                std::to_string(readResult) +
                                ", but buffer is only " +
                                std::to_string(bufferSize) + " bytes.");
      }

      env->GetByteArrayRegion(byteArray, 0, readResult,
                              reinterpret_cast<jbyte *>(buffer));
      bytesCopied += readResult;
      bytesReadFromJava += readResult;
      buffer += readResult;
    }

    return bytesCopied;
  }

  virtual bool isExhausted() { return false; }

  /**
   * @return the number of bytes handed to callers via read() so far.
   *         Bytes pulled from Java but still buffered by peek() are not
   *         counted.
   */
  virtual long long getPosition() {
    return bytesReadFromJava - static_cast<long long>(peekValue.size());
  }

  virtual bool setPosition(long long position) { return false; }

  virtual ~JavaInputStream() {}

  /**
   * Returns the next sizeof(uint32_t) bytes, reinterpreted as a uint32_t,
   * without consuming them: the next read() returns the same bytes again.
   *
   * @throws std::runtime_error if fewer than sizeof(uint32_t) bytes are
   *         available. Any bytes that were obtained are still pushed back.
   */
  virtual uint32_t peek() {
    uint32_t result = 0;
    const long long lastPosition = getPosition();
    const long long got =
        read(reinterpret_cast<char *>(&result), sizeof(result));

    if (got > 0) {
      // Put the bytes back in front of anything still buffered, so the
      // stream's byte order is unchanged.
      const char *bytes = reinterpret_cast<const char *>(&result);
      peekValue.insert(peekValue.begin(), bytes, bytes + got);
    }

    if (got != static_cast<long long>(sizeof(result))) {
      throw std::runtime_error("Failed to peek " +
                               std::to_string(sizeof(result)) +
                               " bytes from JavaInputStream at index " +
                               std::to_string(lastPosition) + ".");
    }
    return result;
  }

private:
  JNIEnv *env;
  jobject inputStream;
  // java.io.InputStream#read(byte[], int, int), resolved in the constructor.
  jmethodID readMethod = nullptr;
  // Bytes already pulled from Java by peek() that read() must return next.
  std::vector<char> peekValue;
  // Total bytes pulled from the Java stream, including any still held in
  // peekValue.
  long long bytesReadFromJava = 0;
};