diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index d15069b63..1749b34af 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -78,7 +78,7 @@ namespace knn_jni { // // Return an array of KNNQueryResults jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); // Free the index located in memory at indexPointerJ void Free(jlong indexPointer); diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp index c759f3c6a..13f59194e 100644 --- a/jni/src/commons.cpp +++ b/jni/src/commons.cpp @@ -71,4 +71,5 @@ int knn_jni::commons::getIntegerMethodParameter(JNIEnv * env, knn_jni::JNIUtilIn } return defaultValue; +} #endif //OPENSEARCH_KNN_COMMONS_H diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index f62c967d3..3eda03b41 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -490,10 +490,12 @@ jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUti std::unique_ptr idGrouper; std::vector idGrouperBitmap; auto hnswReader = dynamic_cast(indexReader->index); - if(hnswReader!= nullptr) { + // TODO currently, search parameter is not supported in binary index + // To avoid test failure, we skip setting ef search when methodPramsJ is null temporary + if(hnswReader!= nullptr && (methodParamsJ != nullptr || parentIdsJ != nullptr)) { // Query param efsearch supersedes ef_search provided during index setting. hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); - if(parentIdsJ != nullptr) { + if (parentIdsJ != nullptr) { idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); hnswParams.grp = idGrouper.get(); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index c5ba503fc..030e10f75 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -425,7 +425,7 @@ TEST(FaissQueryBinaryIndexTest, BasicAssertions) { knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), k, nullptr, 0, nullptr))); + reinterpret_cast(&query), k, nullptr, nullptr, 0, nullptr))); ASSERT_EQ(k, results->size()); @@ -556,6 +556,10 @@ TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) { // Setup jni JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(&parentIds))) + .WillRepeatedly(Return(parentIds.size())); for (auto query : queries) { std::unique_ptr *>> results( reinterpret_cast *> *>( @@ -635,13 +639,13 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; std::vector ids; - auto *vectors = new std::vector(); + std::vector vectors; int dim = 2; - vectors->reserve(dim * numIds); + vectors.reserve(dim * numIds); for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); for (int j = 0; j < dim; ++j) { - vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); } } @@ -660,14 +664,14 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { EXPECT_CALL(mockJNIUtil, GetJavaObjectArrayLength( jniEnv, reinterpret_cast(&vectors))) - .WillRepeatedly(Return(vectors->size())); + .WillRepeatedly(Return(vectors.size())); // Create the index std::unique_ptr faissMethods(new FaissMethods()); knn_jni::faiss_wrapper::IndexService IndexService(std::move(faissMethods)); knn_jni::faiss_wrapper::CreateIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong)vectors, dim, (jstring)&indexPath, + (jlong)&vectors, dim, (jstring)&indexPath, (jobject)¶metersMap, &IndexService); // Make sure index can be loaded diff --git a/jni/tests/faiss_wrapper_unit_test.cpp b/jni/tests/faiss_wrapper_unit_test.cpp index ea9131dd7..77b38e383 100644 --- a/jni/tests/faiss_wrapper_unit_test.cpp +++ b/jni/tests/faiss_wrapper_unit_test.cpp @@ -22,16 +22,15 @@ #include "faiss/IndexIDMap.h" using ::testing::NiceMock; - using idx_t = faiss::idx_t; -struct MockIndex : faiss::IndexHNSW { - explicit MockIndex(idx_t d) : faiss::IndexHNSW(d, 32) { +struct FaissMockIndex : faiss::IndexHNSW { + explicit FaissMockIndex(idx_t d) : faiss::IndexHNSW(d, 32) { } }; -struct MockIdMap : faiss::IndexIDMap { +struct FaissMockIdMap : faiss::IndexIDMap { mutable idx_t nCalled; mutable const float *xCalled; mutable idx_t kCalled; @@ -39,7 +38,7 @@ struct MockIdMap : faiss::IndexIDMap { mutable idx_t *labelsCalled; mutable const faiss::SearchParametersHNSW *paramsCalled; - explicit MockIdMap(MockIndex *index) : faiss::IndexIDMapTemplate(index) { + explicit FaissMockIdMap(FaissMockIndex *index) : faiss::IndexIDMapTemplate(index) { } void search( @@ -85,8 +84,8 @@ class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam