Skip to content

Commit

Permalink
Support ef_search parameter in radial search faiss engine (opensearch…
Browse files Browse the repository at this point in the history
…-project#1790)

Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei authored Jul 3, 2024
1 parent a913082 commit e7d7ec8
Show file tree
Hide file tree
Showing 25 changed files with 508 additions and 119 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.15...2.x)
### Features
* Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783)
* Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790)
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
6 changes: 4 additions & 2 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ namespace knn_jni {
* @param indexPointerJ - pointer to the index
* @param queryVectorJ - the query vector
* @param radiusJ - the radius for the range search
* @param methodParamsJ - the method parameters
* @param maxResultsWindowJ - the maximum number of results to return
* @param filterIdsJ - the filter ids
* @param filterIdsTypeJ - the filter ids type
Expand All @@ -110,21 +111,22 @@ namespace knn_jni {
* @return an array of RangeQueryResults
*/
jobjectArray RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);
jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

/*
* Perform a range search against the index located in memory at indexPointerJ.
*
* @param indexPointerJ - pointer to the index
* @param queryVectorJ - the query vector
* @param radiusJ - the radius for the range search
* @param methodParamsJ - the method parameters
* @param maxResultsWindowJ - the maximum number of results to return
* @param parentIdsJ - the parent ids
*
* @return an array of RangeQueryResults
*/
jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ);
jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jintArray parentIdsJ);
}
}

Expand Down
8 changes: 4 additions & 4 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,18 +150,18 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: rangeSearchIndexWithFilter
* Signature: (J[FJ[I)[Lorg/opensearch/knn/index/query/RangeQueryResult;
* Signature: (J[FJLjava/util/MapI[JII)[Lorg/opensearch/knn/index/query/RangeQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jlongArray, jint, jintArray);
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jobject, jint, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: rangeSearchIndex
* Signature: (J[FJ[I)[Lorg/opensearch/knn/index/query/RangeQueryResult;
* Signature: (J[FJLjava/util/MapII)[Lorg/opensearch/knn/index/query/RangeQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jintArray);
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jobject, jint, jintArray);

#ifdef __cplusplus
}
Expand Down
29 changes: 17 additions & 12 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -716,12 +716,12 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) {
}

jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ,
jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::RangeSearchWithFilter(jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, nullptr, 0, parentIdsJ);
jfloatArray queryVectorJ, jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::RangeSearchWithFilter(jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, methodParamsJ, maxResultWindowJ, nullptr, 0, parentIdsJ);
}

jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ,
jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {
jfloatArray queryVectorJ, jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {
if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
}
Expand All @@ -734,6 +734,11 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter

float *rawQueryVector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr);

std::unordered_map<std::string, jobject> methodParams;
if (methodParamsJ != nullptr) {
methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ);
}

// The res will be freed by ~RangeSearchResult() in FAISS
// The second parameter is always true, as lims is allocated by FAISS
faiss::RangeSearchResult res(1, true);
Expand All @@ -755,9 +760,8 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader) {
// Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default
// value of ef_search = 16 which will then be used.
hnswParams.efSearch = hnswReader->hnsw.efSearch;
// Query param ef_search supersedes ef_search provided during index setting.
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
hnswParams.sel = idSelector.get();
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
Expand Down Expand Up @@ -785,12 +789,13 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader!= nullptr && parentIdsJ != nullptr) {
// Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default
// value of ef_search = 16 which will then be used.
hnswParams.efSearch = hnswReader->hnsw.efSearch;
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
if(hnswReader!= nullptr) {
// Query param ef_search supersedes ef_search provided during index setting.
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
}
searchParameters = &hnswParams;
}
try {
Expand Down
11 changes: 5 additions & 6 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,11 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ,
jintArray parentIdsJ)
jfloat radiusJ, jobject methodParamsJ,
jint maxResultWindowJ, jintArray parentIdsJ)
{
try {
return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, parentIdsJ);
return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, methodParamsJ, maxResultWindowJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand All @@ -255,12 +255,11 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSea
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ,
jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ,
jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ)
{
try {
return knn_jni::faiss_wrapper::RangeSearchWithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ,
maxResultWindowJ, filterIdsJ, filterIdsTypeJ, parentIdsJ);
return knn_jni::faiss_wrapper::RangeSearchWithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, methodParamsJ, maxResultWindowJ, filterIdsJ, filterIdsTypeJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
15 changes: 10 additions & 5 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,11 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) {
faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,Flat";

int efSearch = 20;
std::unordered_map<std::string, jobject> methodParams;
methodParams[knn_jni::EF_SEARCH] = reinterpret_cast<jobject>(&efSearch);
auto methodParamsJ = reinterpret_cast<jobject>(&methodParams);

// Define query data
int numQueries = 100;
std::vector<std::vector<float>> queries;
Expand Down Expand Up @@ -819,7 +824,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) {
knn_jni::faiss_wrapper::RangeSearch(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, maxResultWindow, nullptr)));
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, methodParamsJ, maxResultWindow, nullptr)));

// assert result size is not 0
ASSERT_NE(0, results->size());
Expand Down Expand Up @@ -874,7 +879,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){
knn_jni::faiss_wrapper::RangeSearch(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, maxResultWindow, nullptr)));
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, nullptr, maxResultWindow, nullptr)));

// assert result size is not 0
ASSERT_NE(0, results->size());
Expand Down Expand Up @@ -940,7 +945,7 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) {
knn_jni::faiss_wrapper::RangeSearchWithFilter(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, maxResultWindow,
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, nullptr, maxResultWindow,
reinterpret_cast<jlongArray>(&bitmap), 0, nullptr)));

// assert result size is not 0
Expand Down Expand Up @@ -1015,7 +1020,7 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) {
knn_jni::faiss_wrapper::RangeSearchWithFilter(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, maxResultWindow, nullptr, 0,
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, nullptr, maxResultWindow, nullptr, 0,
reinterpret_cast<jintArray>(&parentIds))));

// assert result size is not 0
Expand All @@ -1032,4 +1037,4 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) {
delete it;
}
}
}
}
Loading

0 comments on commit e7d7ec8

Please sign in to comment.