Skip to content

Commit

Permalink
Add binary format support with HNSW method in Faiss Engine
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Jul 10, 2024
1 parent 5139b16 commit 612a612
Show file tree
Hide file tree
Showing 73 changed files with 2,187 additions and 419 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783)
* Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790)
* Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781)
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
2 changes: 1 addition & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ namespace knn_jni {
jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
void Free(jlong indexPointer, jboolean isBinaryIndexJ);

// Free shared index state in memory at shareIndexStatePointerJ
void FreeSharedIndexState(jlong shareIndexStatePointerJ);
Expand Down
4 changes: 2 additions & 2 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
* Signature: (J)V
* Signature: (JZ)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free
(JNIEnv *, jclass, jlong);
(JNIEnv *, jclass, jlong, jboolean);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down
13 changes: 10 additions & 3 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,16 @@ jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUti
return results;
}

void knn_jni::faiss_wrapper::Free(jlong indexPointer) {
auto *indexWrapper = reinterpret_cast<faiss::Index*>(indexPointer);
delete indexWrapper;
void knn_jni::faiss_wrapper::Free(jlong indexPointer, jboolean isBinaryIndexJ) {
bool isBinaryIndex = static_cast<bool>(isBinaryIndexJ);
if (isBinaryIndex) {
auto *indexWrapper = reinterpret_cast<faiss::IndexBinary*>(indexPointer);
delete indexWrapper;
}
else {
auto *indexWrapper = reinterpret_cast<faiss::Index*>(indexPointer);
delete indexWrapper;
}
}

void knn_jni::faiss_wrapper::FreeSharedIndexState(jlong shareIndexStatePointerJ) {
Expand Down
4 changes: 2 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin

}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ)
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ, jboolean isBinaryIndexJ)
{
try {
return knn_jni::faiss_wrapper::Free(indexPointerJ);
return knn_jni::faiss_wrapper::Free(indexPointerJ, isBinaryIndexJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
16 changes: 15 additions & 1 deletion jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,21 @@ TEST(FaissFreeTest, BasicAssertions) {
test_util::FaissCreateIndex(dim, method, metricType));

// Free created index --> memory check should catch failure
knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex));
knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex), JNI_FALSE);
}


TEST(FaissBinaryFreeTest, BasicAssertions) {
// Define the data
int dim = 8;
std::string method = "BHNSW32";

// Create the index
faiss::IndexBinary *createdIndex(
test_util::FaissCreateBinaryIndex(dim, method));

// Free created index --> memory check should catch failure
knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex), JNI_TRUE);
}

TEST(FaissInitLibraryTest, BasicAssertions) {
Expand Down
33 changes: 26 additions & 7 deletions src/main/java/org/opensearch/knn/common/KNNValidationUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ public static void validateFloatVectorValue(float value) {
*
* @param value float value in byte range
*/
public static void validateByteVectorValue(float value) {
public static void validateByteVectorValue(float value, final VectorDataType dataType) {
validateFloatVectorValue(value);
if (value % 1 != 0) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue()
dataType.getValue()
)

);
Expand All @@ -60,7 +60,7 @@ public static void validateByteVectorValue(float value) {
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue(),
dataType.getValue(),
Byte.MIN_VALUE,
Byte.MAX_VALUE
)
Expand All @@ -71,13 +71,32 @@ public static void validateByteVectorValue(float value) {
/**
* Validate if the given vector size matches with the dimension provided in mapping.
*
* For binary index, the dimension is 8 times larger than vector size because 8 bits is packed into single byte
*
* @param dimension dimension of vector
* @param vectorSize size of the vector
* @param dataType vector data type
*/
public static void validateVectorDimension(int dimension, int vectorSize) {
if (dimension != vectorSize) {
String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize);
throw new IllegalArgumentException(errorMessage);
public static void validateVectorDimension(final int dimension, final int vectorSize, final VectorDataType dataType) {
int actualDimension = VectorDataType.BINARY == dataType ? vectorSize * Byte.SIZE : vectorSize;
if (dimension != actualDimension) {
if (VectorDataType.BINARY == dataType) {
String errorMessage = String.format(
Locale.ROOT,
"The dimension of the binary vector must be 8 times the length of the provided vector. Expected: %d, Given: %d",
dimension,
actualDimension
);
throw new IllegalArgumentException(errorMessage);
} else {
String errorMessage = String.format(
Locale.ROOT,
"Vector dimension mismatch. Expected: %d, Given: %d",
dimension,
actualDimension
);
throw new IllegalArgumentException(errorMessage);
}
}
}
}
23 changes: 17 additions & 6 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

public class IndexUtil {

Expand Down Expand Up @@ -257,14 +257,14 @@ public static ValidationException validateKnnField(
* @param spaceType Space for this particular segment
* @param knnEngine Engine used for the native library indices being loaded in
* @param indexName Name of OpenSearch index that the segment files belong to
* @param indexDescription Index description of OpenSearch index with faiss that the segment files belong to
* @param vectorDataType Vector data type for this particular segment
* @return load parameters that will be passed to the JNI.
*/
public static Map<String, Object> getParametersAtLoading(
SpaceType spaceType,
KNNEngine knnEngine,
String indexName,
String indexDescription
VectorDataType vectorDataType
) {
Map<String, Object> loadParameters = Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue()));

Expand All @@ -273,9 +273,7 @@ public static Map<String, Object> getParametersAtLoading(
if (KNNEngine.NMSLIB.equals(knnEngine)) {
loadParameters.put(HNSW_ALGO_EF_SEARCH, KNNSettings.getEfSearchParam(indexName));
}
if (KNNEngine.FAISS.equals(knnEngine)) {
loadParameters.put(INDEX_DESCRIPTION_PARAMETER, indexDescription);
}
loadParameters.put(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue());

return Collections.unmodifiableMap(loadParameters);
}
Expand Down Expand Up @@ -310,4 +308,17 @@ public static boolean isSharedIndexStateRequired(KNNEngine knnEngine, String mod
}
return JNIService.isSharedIndexStateRequired(indexAddr, knnEngine);
}

/**
* Tell if it is binary index or not
*
* @param knnEngine knn engine associated with an index
* @param parameters parameters associated with an index
* @return true if it is binary index
*/
public static boolean isBinaryIndex(KNNEngine knnEngine, Map<String, Object> parameters) {
return KNNEngine.FAISS == knnEngine
&& parameters.get(VECTOR_DATA_TYPE_FIELD) != null
&& parameters.get(VECTOR_DATA_TYPE_FIELD).toString().equals(VectorDataType.BINARY.getValue());
}
}
12 changes: 6 additions & 6 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.util.FieldInfoExtractor;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
Expand All @@ -37,6 +36,7 @@

import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.IndexUtil.getParametersAtLoading;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFilePrefix;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileSuffix;
Expand Down Expand Up @@ -98,7 +98,7 @@ public void warmup() throws IOException {
engineFileContext.getSpaceType(),
KNNEngine.getEngineNameFromPath(engineFileContext.getIndexPath()),
getIndexName(),
engineFileContext.indexDescription
engineFileContext.getVectorDataType()
),
getIndexName(),
engineFileContext.getModelId()
Expand Down Expand Up @@ -182,7 +182,7 @@ List<EngineFileContext> getEngineFileContexts(IndexReader indexReader, KNNEngine
shardPath,
spaceType,
modelId,
FieldInfoExtractor.getIndexDescription(fieldInfo)
VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()))
)
);
}
Expand All @@ -200,15 +200,15 @@ List<EngineFileContext> getEngineFileContexts(
Path shardPath,
SpaceType spaceType,
String modelId,
String indexDescription
VectorDataType vectorDataType
) {
String prefix = buildEngineFilePrefix(segmentName);
String suffix = buildEngineFileSuffix(fieldName, fileExtension);
return files.stream()
.filter(fileName -> fileName.startsWith(prefix))
.filter(fileName -> fileName.endsWith(suffix))
.map(fileName -> shardPath.resolve(fileName).toString())
.map(fileName -> new EngineFileContext(spaceType, modelId, fileName, indexDescription))
.map(fileName -> new EngineFileContext(spaceType, modelId, fileName, vectorDataType))
.collect(Collectors.toList());
}

Expand All @@ -219,6 +219,6 @@ static class EngineFileContext {
private final SpaceType spaceType;
private final String modelId;
private final String indexPath;
private final String indexDescription;
private final VectorDataType vectorDataType;
}
}
19 changes: 12 additions & 7 deletions src/main/java/org/opensearch/knn/index/KNNMethodContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.common.ValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -48,21 +50,24 @@ public class KNNMethodContext implements ToXContentFragment, Writeable {

private static KNNMethodContext defaultInstance = null;

/**
* This is used only for testing
* @return default KNNMethodContext for testing
*/
public static synchronized KNNMethodContext getDefault() {
if (defaultInstance == null) {
defaultInstance = new KNNMethodContext(
KNNEngine.DEFAULT,
SpaceType.DEFAULT,
new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())
);
MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap());
methodComponentContext.setIndexVersion(Version.CURRENT);
defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext);
}
return defaultInstance;
}

@NonNull
private final KNNEngine knnEngine;
@NonNull
private final SpaceType spaceType;
@Setter
private SpaceType spaceType;
@NonNull
private final MethodComponentContext methodComponentContext;

Expand Down Expand Up @@ -131,7 +136,7 @@ public static KNNMethodContext parse(Object in) {
Map<String, Object> methodMap = (Map<String, Object>) in;

KNNEngine engine = KNNEngine.DEFAULT; // Get or default
SpaceType spaceType = SpaceType.DEFAULT; // Get or default
SpaceType spaceType = SpaceType.UNDEFINED; // Get or default
String name = "";
Map<String, Object> parameters = new HashMap<>();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.opensearch.knn.plugin.script.KNNScoringUtil;

/**
* Wrapper class of VectorSimilarityFunction to support more function than what Lucene provides
*/
public enum KNNVectorSimilarityFunction {
EUCLIDEAN(VectorSimilarityFunction.EUCLIDEAN),
DOT_PRODUCT(VectorSimilarityFunction.DOT_PRODUCT),
COSINE(VectorSimilarityFunction.COSINE),
MAXIMUM_INNER_PRODUCT(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT),
HAMMING(null) {
@Override
public float compare(float[] v1, float[] v2) {
throw new IllegalStateException("Hamming space is not supported with float vectors");
}

@Override
public float compare(byte[] v1, byte[] v2) {
return 1.0f / (1 + KNNScoringUtil.calculateHammingBit(v1, v2));
}

@Override
public VectorSimilarityFunction getVectorSimilarityFunction() {
throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space");
}
};

private final VectorSimilarityFunction vectorSimilarityFunction;

KNNVectorSimilarityFunction(final VectorSimilarityFunction vectorSimilarityFunction) {
this.vectorSimilarityFunction = vectorSimilarityFunction;
}

public VectorSimilarityFunction getVectorSimilarityFunction() {
return vectorSimilarityFunction;
}

public float compare(float[] var1, float[] var2) {
return vectorSimilarityFunction.compare(var1, var2);
}

public float compare(byte[] var1, byte[] var2) {
return vectorSimilarityFunction.compare(var1, var2);
}
}
Loading

0 comments on commit 612a612

Please sign in to comment.