Skip to content

Commit

Permalink
Add binary format support with HNSW method in Faiss Engine
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Jul 9, 2024
1 parent 5139b16 commit b2a1332
Show file tree
Hide file tree
Showing 70 changed files with 2,144 additions and 318 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783)
* Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790)
* Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781)
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
2 changes: 1 addition & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ namespace knn_jni {
jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
void Free(jlong indexPointer, jboolean isBinaryIndexJ);

// Free shared index state in memory at shareIndexStatePointerJ
void FreeSharedIndexState(jlong shareIndexStatePointerJ);
Expand Down
4 changes: 2 additions & 2 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
* Signature: (J)V
* Signature: (JZ)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free
(JNIEnv *, jclass, jlong);
(JNIEnv *, jclass, jlong, jboolean);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down
13 changes: 10 additions & 3 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,16 @@ jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUti
return results;
}

void knn_jni::faiss_wrapper::Free(jlong indexPointer) {
auto *indexWrapper = reinterpret_cast<faiss::Index*>(indexPointer);
delete indexWrapper;
// Releases the index whose address is stored in indexPointer.
//
// indexPointer   - address of the index object, carried as a jlong
// isBinaryIndexJ - when true the address is treated as a faiss::IndexBinary,
//                  otherwise as a faiss::Index
void knn_jni::faiss_wrapper::Free(jlong indexPointer, jboolean isBinaryIndexJ) {
    // The handle itself carries no type information, so the caller-supplied
    // flag selects which faiss type the address is deleted as.
    if (static_cast<bool>(isBinaryIndexJ)) {
        delete reinterpret_cast<faiss::IndexBinary*>(indexPointer);
        return;
    }
    delete reinterpret_cast<faiss::Index*>(indexPointer);
}

void knn_jni::faiss_wrapper::FreeSharedIndexState(jlong shareIndexStatePointerJ) {
Expand Down
4 changes: 2 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin

}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ)
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ, jboolean isBinaryIndexJ)
{
try {
return knn_jni::faiss_wrapper::Free(indexPointerJ);
return knn_jni::faiss_wrapper::Free(indexPointerJ, isBinaryIndexJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
16 changes: 15 additions & 1 deletion jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,21 @@ TEST(FaissFreeTest, BasicAssertions) {
test_util::FaissCreateIndex(dim, method, metricType));

// Free created index --> memory check should catch failure
knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex));
knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex), JNI_FALSE);
}


TEST(FaissBinaryFreeTest, BasicAssertions) {
    // Inputs for the binary index under test
    int dim = 8;
    std::string method = "BHNSW32";

    // Build the binary index through the shared test helper
    faiss::IndexBinary *binaryIndex = test_util::FaissCreateBinaryIndex(dim, method);

    // Release it with the binary flag set. There are no explicit assertions:
    // a memory check is expected to catch a failure here.
    jlong indexAddress = reinterpret_cast<jlong>(binaryIndex);
    knn_jni::faiss_wrapper::Free(indexAddress, JNI_TRUE);
}

TEST(FaissInitLibraryTest, BasicAssertions) {
Expand Down
33 changes: 26 additions & 7 deletions src/main/java/org/opensearch/knn/common/KNNValidationUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ public static void validateFloatVectorValue(float value) {
*
* @param value float value in byte range
*/
public static void validateByteVectorValue(float value) {
public static void validateByteVectorValue(float value, final VectorDataType dataType) {
validateFloatVectorValue(value);
if (value % 1 != 0) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue()
dataType.getValue()
)

);
Expand All @@ -60,7 +60,7 @@ public static void validateByteVectorValue(float value) {
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue(),
dataType.getValue(),
Byte.MIN_VALUE,
Byte.MAX_VALUE
)
Expand All @@ -71,13 +71,32 @@ public static void validateByteVectorValue(float value) {
/**
* Validate if the given vector size matches with the dimension provided in mapping.
*
* For a binary index, the dimension is 8 times the vector size because 8 bits are packed into a single byte.
*
* @param dimension dimension of vector
* @param vectorSize size of the vector
* @param dataType vector data type
*/
public static void validateVectorDimension(int dimension, int vectorSize) {
if (dimension != vectorSize) {
String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize);
throw new IllegalArgumentException(errorMessage);
public static void validateVectorDimension(final int dimension, final int vectorSize, final VectorDataType dataType) {
    final boolean isBinary = VectorDataType.BINARY == dataType;
    // For binary data each element of the vector accounts for Byte.SIZE dimensions,
    // so the vector size is scaled before it is compared with the mapping dimension.
    final int actualDimension = isBinary ? vectorSize * Byte.SIZE : vectorSize;
    if (dimension == actualDimension) {
        return;
    }

    // Only the message differs between the binary and non-binary cases.
    final String messageTemplate = isBinary
        ? "The dimension of the binary vector must be 8 times the length of the provided vector. Expected: %d, Given: %d"
        : "Vector dimension mismatch. Expected: %d, Given: %d";
    throw new IllegalArgumentException(String.format(Locale.ROOT, messageTemplate, dimension, actualDimension));
}
}
14 changes: 14 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;

public class IndexUtil {

Expand Down Expand Up @@ -310,4 +311,17 @@ public static boolean isSharedIndexStateRequired(KNNEngine knnEngine, String mod
}
return JNIService.isSharedIndexStateRequired(indexAddr, knnEngine);
}

/**
 * Checks whether the given engine and parameters describe a binary index.
 *
 * The result is true only when the engine is Faiss and the index description
 * parameter is present and its string form starts with the Faiss binary index
 * description prefix.
 *
 * @param knnEngine knn engine associated with an index
 * @param parameters parameters associated with an index
 * @return true if it is binary index
 */
public static boolean isBinaryIndex(KNNEngine knnEngine, Map<String, Object> parameters) {
    // Engines other than Faiss are rejected before the parameters are consulted.
    if (KNNEngine.FAISS != knnEngine) {
        return false;
    }
    final Object indexDescription = parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER);
    return indexDescription != null && indexDescription.toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX);
}
}
19 changes: 12 additions & 7 deletions src/main/java/org/opensearch/knn/index/KNNMethodContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.common.ValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -48,21 +50,24 @@ public class KNNMethodContext implements ToXContentFragment, Writeable {

private static KNNMethodContext defaultInstance = null;

/**
* This is used only for testing
* @return default KNNMethodContext for testing
*/
public static synchronized KNNMethodContext getDefault() {
if (defaultInstance == null) {
defaultInstance = new KNNMethodContext(
KNNEngine.DEFAULT,
SpaceType.DEFAULT,
new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())
);
MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap());
methodComponentContext.setIndexVersion(Version.CURRENT);
defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext);
}
return defaultInstance;
}

@NonNull
private final KNNEngine knnEngine;
@NonNull
private final SpaceType spaceType;
@Setter
private SpaceType spaceType;
@NonNull
private final MethodComponentContext methodComponentContext;

Expand Down Expand Up @@ -131,7 +136,7 @@ public static KNNMethodContext parse(Object in) {
Map<String, Object> methodMap = (Map<String, Object>) in;

KNNEngine engine = KNNEngine.DEFAULT; // Get or default
SpaceType spaceType = SpaceType.DEFAULT; // Get or default
SpaceType spaceType = SpaceType.UNDEFINED; // Get or default
String name = "";
Map<String, Object> parameters = new HashMap<>();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.opensearch.knn.plugin.script.KNNScoringUtil;

/**
 * Similarity functions usable by k-NN. Most constants wrap a Lucene
 * {@link VectorSimilarityFunction} and delegate to it; {@link #HAMMING} wraps none
 * and overrides every method instead.
 */
public enum KNNVectorSimilarityFunction {
    EUCLIDEAN(VectorSimilarityFunction.EUCLIDEAN),
    DOT_PRODUCT(VectorSimilarityFunction.DOT_PRODUCT),
    COSINE(VectorSimilarityFunction.COSINE),
    MAXIMUM_INNER_PRODUCT(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT),
    // Constructed with a null delegate, so each method that would touch the delegate is overridden.
    HAMMING(null) {
        @Override
        public VectorSimilarityFunction getVectorSimilarityFunction() {
            throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space");
        }

        @Override
        public float compare(byte[] v1, byte[] v2) {
            // Maps the value returned by calculateHammingBit to 1 / (1 + value).
            final float score = 1.0f / (1 + KNNScoringUtil.calculateHammingBit(v1, v2));
            return score;
        }

        @Override
        public float compare(float[] v1, float[] v2) {
            throw new IllegalStateException("Hamming space is not supported with float vectors");
        }
    };

    /** Lucene function this constant delegates to; null for {@link #HAMMING}. */
    private final VectorSimilarityFunction vectorSimilarityFunction;

    KNNVectorSimilarityFunction(final VectorSimilarityFunction vectorSimilarityFunction) {
        this.vectorSimilarityFunction = vectorSimilarityFunction;
    }

    /**
     * Compares two byte vectors using the wrapped Lucene function.
     *
     * @param var1 first vector
     * @param var2 second vector
     * @return value returned by the wrapped function
     */
    public float compare(byte[] var1, byte[] var2) {
        return this.vectorSimilarityFunction.compare(var1, var2);
    }

    /**
     * Compares two float vectors using the wrapped Lucene function.
     *
     * @param var1 first vector
     * @param var2 second vector
     * @return value returned by the wrapped function
     */
    public float compare(float[] var1, float[] var2) {
        return this.vectorSimilarityFunction.compare(var1, var2);
    }

    /**
     * @return the wrapped Lucene similarity function
     */
    public VectorSimilarityFunction getVectorSimilarityFunction() {
        return this.vectorSimilarityFunction;
    }
}
Loading

0 comments on commit b2a1332

Please sign in to comment.