diff --git a/CHANGELOG.md b/CHANGELOG.md index 70f787cff..5ec910ccd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features * Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783) * Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790) +* Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 5ac17cfd1..2b9bc2c76 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -81,7 +81,7 @@ namespace knn_jni { jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); // Free the index located in memory at indexPointerJ - void Free(jlong indexPointer); + void Free(jlong indexPointer, jboolean isBinaryIndexJ); // Free shared index state in memory at shareIndexStatePointerJ void FreeSharedIndexState(jlong shareIndexStatePointerJ); diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 3d6aef45c..7cc071ff3 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -110,10 +110,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin /* * Class: org_opensearch_knn_jni_FaissService * Method: free - * Signature: (J)V + * Signature: (JZ)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free - (JNIEnv *, jclass, jlong); + (JNIEnv *, jclass, jlong, jboolean); /* * Class: org_opensearch_knn_jni_FaissService diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index c4c6e18eb..9abb2357f 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -531,9 +531,16 @@ jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUti return results; } -void knn_jni::faiss_wrapper::Free(jlong indexPointer) { - auto *indexWrapper = reinterpret_cast(indexPointer); - delete indexWrapper; +void knn_jni::faiss_wrapper::Free(jlong indexPointer, jboolean isBinaryIndexJ) { + bool isBinaryIndex = static_cast(isBinaryIndexJ); + if (isBinaryIndex) { + auto *indexWrapper = reinterpret_cast(indexPointer); + delete indexWrapper; + } + else { + auto *indexWrapper = reinterpret_cast(indexPointer); + delete indexWrapper; + } } void knn_jni::faiss_wrapper::FreeSharedIndexState(jlong shareIndexStatePointerJ) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 5f9c83ea8..6e447b034 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -179,10 +179,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ) +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ, jboolean isBinaryIndexJ) { try { - return knn_jni::faiss_wrapper::Free(indexPointerJ); + return knn_jni::faiss_wrapper::Free(indexPointerJ, isBinaryIndexJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index c6663a19a..36bcea491 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -596,7 +596,21 @@ TEST(FaissFreeTest, BasicAssertions) { test_util::FaissCreateIndex(dim, method, metricType)); // Free created index --> memory check should catch failure - knn_jni::faiss_wrapper::Free(reinterpret_cast(createdIndex)); + knn_jni::faiss_wrapper::Free(reinterpret_cast(createdIndex), JNI_FALSE); +} + + +TEST(FaissBinaryFreeTest, BasicAssertions) { + // Define the data + int dim = 8; + std::string method = "BHNSW32"; + + // Create the index + faiss::IndexBinary *createdIndex( + test_util::FaissCreateBinaryIndex(dim, method)); + + // Free created index --> memory check should catch failure + knn_jni::faiss_wrapper::Free(reinterpret_cast(createdIndex), JNI_TRUE); } TEST(FaissInitLibraryTest, BasicAssertions) { diff --git a/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java b/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java index ca8e1459a..59ca8993a 100644 --- a/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java +++ b/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java @@ -41,7 +41,7 @@ public static void validateFloatVectorValue(float value) { * * @param value float value in byte range */ - public static void validateByteVectorValue(float value) { + public static void validateByteVectorValue(float value, final VectorDataType dataType) { validateFloatVectorValue(value); if (value % 1 != 0) { throw new IllegalArgumentException( @@ -49,7 +49,7 @@ public static void validateByteVectorValue(float value) { Locale.ROOT, "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", VECTOR_DATA_TYPE_FIELD, - VectorDataType.BYTE.getValue() + dataType.getValue() ) ); @@ -60,7 +60,7 @@ public static void validateByteVectorValue(float value) { Locale.ROOT, "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", VECTOR_DATA_TYPE_FIELD, - VectorDataType.BYTE.getValue(), + dataType.getValue(), Byte.MIN_VALUE, Byte.MAX_VALUE ) @@ -71,13 +71,32 @@ public static void validateByteVectorValue(float value) { /** * Validate if the given vector size matches with the dimension provided in mapping. * + * For binary index, the dimension is 8 times larger than vector size because 8 bits is packed into single byte + * * @param dimension dimension of vector * @param vectorSize size of the vector + * @param dataType vector data type */ - public static void validateVectorDimension(int dimension, int vectorSize) { - if (dimension != vectorSize) { - String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize); - throw new IllegalArgumentException(errorMessage); + public static void validateVectorDimension(final int dimension, final int vectorSize, final VectorDataType dataType) { + int actualDimension = VectorDataType.BINARY == dataType ? vectorSize * Byte.SIZE : vectorSize; + if (dimension != actualDimension) { + if (VectorDataType.BINARY == dataType) { + String errorMessage = String.format( + Locale.ROOT, + "The dimension of the binary vector must be 8 times the length of the provided vector. Expected: %d, Given: %d", + dimension, + actualDimension + ); + throw new IllegalArgumentException(errorMessage); + } else { + String errorMessage = String.format( + Locale.ROOT, + "Vector dimension mismatch. Expected: %d, Given: %d", + dimension, + actualDimension + ); + throw new IllegalArgumentException(errorMessage); + } } } } diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index ef26e0b17..9a2468332 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -37,6 +37,7 @@ import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; public class IndexUtil { @@ -310,4 +311,17 @@ public static boolean isSharedIndexStateRequired(KNNEngine knnEngine, String mod } return JNIService.isSharedIndexStateRequired(indexAddr, knnEngine); } + + /** + * Tell if it is binary index or not + * + * @param knnEngine knn engine associated with an index + * @param parameters parameters associated with an index + * @return true if it is binary index + */ + public static boolean isBinaryIndex(KNNEngine knnEngine, Map parameters) { + return KNNEngine.FAISS == knnEngine + && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null + && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX); + } } diff --git a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java index ce48b06be..ba7a7509c 100644 --- a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java @@ -14,6 +14,8 @@ import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NonNull; +import lombok.Setter; +import org.opensearch.Version; import org.opensearch.common.ValidationException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -48,13 +50,15 @@ public class KNNMethodContext implements ToXContentFragment, Writeable { private static KNNMethodContext defaultInstance = null; + /** + * This is used only for testing + * @return default KNNMethodContext for testing + */ public static synchronized KNNMethodContext getDefault() { if (defaultInstance == null) { - defaultInstance = new KNNMethodContext( - KNNEngine.DEFAULT, - SpaceType.DEFAULT, - new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) - ); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + methodComponentContext.setIndexVersion(Version.CURRENT); + defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); } return defaultInstance; } @@ -62,7 +66,8 @@ public static synchronized KNNMethodContext getDefault() { @NonNull private final KNNEngine knnEngine; @NonNull - private final SpaceType spaceType; + @Setter + private SpaceType spaceType; @NonNull private final MethodComponentContext methodComponentContext; @@ -131,7 +136,7 @@ public static KNNMethodContext parse(Object in) { Map methodMap = (Map) in; KNNEngine engine = KNNEngine.DEFAULT; // Get or default - SpaceType spaceType = SpaceType.DEFAULT; // Get or default + SpaceType spaceType = SpaceType.UNDEFINED; // Get or default String name = ""; Map parameters = new HashMap<>(); diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java b/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java new file mode 100644 index 000000000..7eca6287c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.opensearch.knn.plugin.script.KNNScoringUtil; + +/** + * Wrapper class of VectorSimilarityFunction to support more function than what Lucene provides + */ +public enum KNNVectorSimilarityFunction { + EUCLIDEAN(VectorSimilarityFunction.EUCLIDEAN), + DOT_PRODUCT(VectorSimilarityFunction.DOT_PRODUCT), + COSINE(VectorSimilarityFunction.COSINE), + MAXIMUM_INNER_PRODUCT(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT), + HAMMING(null) { + @Override + public float compare(float[] v1, float[] v2) { + throw new IllegalStateException("Hamming space is not supported with float vectors"); + } + + @Override + public float compare(byte[] v1, byte[] v2) { + return 1.0f / (1 + KNNScoringUtil.calculateHammingBit(v1, v2)); + } + + @Override + public VectorSimilarityFunction getVectorSimilarityFunction() { + throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space"); + } + }; + + private final VectorSimilarityFunction vectorSimilarityFunction; + + KNNVectorSimilarityFunction(final VectorSimilarityFunction vectorSimilarityFunction) { + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + public VectorSimilarityFunction getVectorSimilarityFunction() { + return vectorSimilarityFunction; + } + + public float compare(float[] var1, float[] var2) { + return vectorSimilarityFunction.compare(var1, var2); + } + + public float compare(byte[] var1, byte[] var2) { + return vectorSimilarityFunction.compare(var1, var2); + } +} diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 240bfbe91..a65c4bb4c 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -12,7 +12,6 @@ package org.opensearch.knn.index; import java.util.Locale; -import org.apache.lucene.index.VectorSimilarityFunction; import java.util.HashSet; import java.util.Set; @@ -26,6 +25,19 @@ * nmslib calls the inner_product space "negdotprod". This translation should take place in the nmslib's jni layer. */ public enum SpaceType { + // This undefined space type is used to indicate that space type is not provided by user + // Later, we need to assign a default value based on data type + UNDEFINED("undefined") { + @Override + public float scoreTranslation(final float rawScore) { + throw new IllegalStateException("Unsupported method"); + } + + @Override + public void validateVectorDataType(VectorDataType vectorDataType) { + throw new IllegalStateException("Unsupported method"); + } + }, L2("l2") { @Override public float scoreTranslation(float rawScore) { @@ -33,8 +45,8 @@ public float scoreTranslation(float rawScore) { } @Override - public VectorSimilarityFunction getVectorSimilarityFunction() { - return VectorSimilarityFunction.EUCLIDEAN; + public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { + return KNNVectorSimilarityFunction.EUCLIDEAN; } @Override @@ -52,8 +64,8 @@ public float scoreTranslation(float rawScore) { } @Override - public VectorSimilarityFunction getVectorSimilarityFunction() { - return VectorSimilarityFunction.COSINE; + public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { + return KNNVectorSimilarityFunction.COSINE; } @Override @@ -104,8 +116,8 @@ public float scoreTranslation(float rawScore) { } @Override - public VectorSimilarityFunction getVectorSimilarityFunction() { - return VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { + return KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; } }, HAMMING_BIT("hammingbit") { @@ -113,9 +125,29 @@ public VectorSimilarityFunction getVectorSimilarityFunction() { public float scoreTranslation(float rawScore) { return 1 / (1 + rawScore); } + + @Override + public void validateVectorDataType(VectorDataType vectorDataType) { + if (VectorDataType.BINARY != vectorDataType) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Space type [%s] is not supported with [%s] data type", + getValue(), + vectorDataType.getValue() + ) + ); + } + } + + @Override + public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { + return KNNVectorSimilarityFunction.HAMMING; + } }; public static SpaceType DEFAULT = L2; + public static SpaceType DEFAULT_BINARY = HAMMING_BIT; private final String value; @@ -126,12 +158,12 @@ public float scoreTranslation(float rawScore) { public abstract float scoreTranslation(float rawScore); /** - * Get VectorSimilarityFunction that maps to this SpaceType + * Get KNNVectorSimilarityFunction that maps to this SpaceType * - * @return VectorSimilarityFunction + * @return KNNVectorSimilarityFunction */ - public VectorSimilarityFunction getVectorSimilarityFunction() { - throw new UnsupportedOperationException(String.format("Space [%s] does not have a vector similarity function", getValue())); + public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { + throw new UnsupportedOperationException(String.format("Space [%s] does not have a knn vector similarity function", getValue())); } /** @@ -152,6 +184,19 @@ public void validateVector(float[] vector) { // do nothing } + /** + * Validate if given vector data type is supported by this space type + * + * @param vectorDataType the given vector data type + */ + public void validateVectorDataType(VectorDataType vectorDataType) { + if (VectorDataType.FLOAT != vectorDataType && VectorDataType.BYTE != vectorDataType) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Space type [%s] is not supported with [%s] data type", getValue(), vectorDataType.getValue()) + ); + } + } + /** * Get space type name in engine * diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index 98b767f8d..8add84609 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -24,11 +24,25 @@ import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; /** - * Enum contains data_type of vectors and right now only supported for lucene engine in k-NN plugin. - * We have two vector data_types, one is float (default) and the other one is byte. + * Enum contains data_type of vectors + * Lucene supports byte and float data type + * NMSLib supports only float data type + * Faiss supports binary and float data type */ @AllArgsConstructor public enum VectorDataType { + BINARY("binary") { + + @Override + public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { + throw new IllegalStateException("Unsupported method"); + } + + @Override + public float[] getVectorFromBytesRef(BytesRef binaryValue) { + throw new IllegalStateException("Unsupported method"); + } + }, BYTE("byte") { @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 2e5c78076..349ad2f4d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -16,6 +16,10 @@ import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.transfer.VectorTransfer; +import org.opensearch.knn.index.codec.transfer.VectorTransferByte; +import org.opensearch.knn.index.codec.transfer.VectorTransferFloat; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNCodecUtil; @@ -56,6 +60,7 @@ import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; +import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; /** * This class writes the KNN docvalues to the segments @@ -106,11 +111,18 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) { return KNNEngine.getEngine(engineName); } + private VectorTransfer getVectorTransfer(FieldInfo field) { + if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) { + return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + } + return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + } + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) throws IOException { // Get values to be indexed BinaryDocValues values = valuesProducer.getBinary(field); - KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); + KNNCodecUtil.Pair pair = KNNCodecUtil.getPair(values, getVectorTransfer(field)); if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); return; @@ -227,6 +239,17 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa ); } + // Update index description of Faiss for binary data type + if (KNNEngine.FAISS == knnEngine + && VectorDataType.BINARY.getValue() + .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())) + && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + } + // Used to determine how many threads to use when indexing parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java new file mode 100644 index 000000000..2ab80f776 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import lombok.Data; +import org.opensearch.knn.index.codec.util.SerializationMode; + +import java.io.ByteArrayInputStream; + +/** + * Abstract class to transfer vector value from Java to native memory + */ +@Data +public abstract class VectorTransfer { + protected final long vectorsStreamingMemoryLimit; + protected long totalLiveDocs; + protected long vectorsPerTransfer; + protected long vectorAddress; + protected int dimension; + + public VectorTransfer(final long vectorsStreamingMemoryLimit) { + this.vectorsStreamingMemoryLimit = vectorsStreamingMemoryLimit; + this.vectorsPerTransfer = Integer.MIN_VALUE; + } + + /** + * Initialize the transfer + * + * @param totalLiveDocs total number of vectors to be transferred + */ + abstract public void init(final long totalLiveDocs); + + /** + * Transfer a single vector + * + * @param byteStream a vector in byte stream format + */ + abstract public void transfer(final ByteArrayInputStream byteStream); + + /** + * Close the transfer + */ + abstract public void close(); + + /** + * Get serialization mode of given byte stream + * + * @param byteStream byte stream of a vector + * @return serialization mode + */ + abstract public SerializationMode getSerializationMode(final ByteArrayInputStream byteStream); +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java new file mode 100644 index 000000000..fb0f9d470 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.opensearch.knn.index.codec.util.SerializationMode; +import org.opensearch.knn.jni.JNICommons; + +import java.io.ByteArrayInputStream; +import java.util.ArrayList; +import java.util.List; + +/** + * Vector transfer for byte + */ +public class VectorTransferByte extends VectorTransfer { + private List vectorList; + + public VectorTransferByte(final long vectorsStreamingMemoryLimit) { + super(vectorsStreamingMemoryLimit); + vectorList = new ArrayList<>(); + } + + @Override + public void init(final long totalLiveDocs) { + this.totalLiveDocs = totalLiveDocs; + vectorList.clear(); + } + + @Override + public void transfer(final ByteArrayInputStream byteStream) { + final byte[] vector = byteStream.readAllBytes(); + dimension = vector.length * 8; + if (vectorsPerTransfer == Integer.MIN_VALUE) { + vectorsPerTransfer = (vector.length * totalLiveDocs) / vectorsStreamingMemoryLimit; + // This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer + // Doing this will reduce 1 extra trip to JNI layer. + if (vectorsPerTransfer == 0) { + vectorsPerTransfer = totalLiveDocs; + } + } + + vectorList.add(vector); + if (vectorList.size() == vectorsPerTransfer) { + transfer(); + } + } + + @Override + public void close() { + transfer(); + } + + @Override + public SerializationMode getSerializationMode(final ByteArrayInputStream byteStream) { + return SerializationMode.COLLECTIONS_OF_BYTES; + } + + private void transfer() { + int lengthOfVector = dimension / 8; + vectorAddress = JNICommons.storeByteVectorData(vectorAddress, vectorList.toArray(new byte[][] {}), totalLiveDocs * lengthOfVector); + vectorList.clear(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java new file mode 100644 index 000000000..d5958b375 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.index.codec.util.SerializationMode; +import org.opensearch.knn.jni.JNICommons; + +import java.io.ByteArrayInputStream; +import java.util.ArrayList; +import java.util.List; + +/** + * Vector transfer for float + */ +public class VectorTransferFloat extends VectorTransfer { + private List vectorList; + + public VectorTransferFloat(final long vectorsStreamingMemoryLimit) { + super(vectorsStreamingMemoryLimit); + vectorList = new ArrayList<>(); + } + + @Override + public void init(final long totalLiveDocs) { + this.totalLiveDocs = totalLiveDocs; + vectorList.clear(); + } + + @Override + public void transfer(final ByteArrayInputStream byteStream) { + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); + final float[] vector = vectorSerializer.byteToFloatArray(byteStream); + dimension = vector.length; + + if (vectorsPerTransfer == Integer.MIN_VALUE) { + vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit; + // This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer + // Doing this will reduce 1 extra trip to JNI layer. + if (vectorsPerTransfer == 0) { + vectorsPerTransfer = totalLiveDocs; + } + } + + vectorList.add(vector); + if (vectorList.size() == vectorsPerTransfer) { + transfer(); + } + } + + @Override + public void close() { + transfer(); + } + + @Override + public SerializationMode getSerializationMode(final ByteArrayInputStream byteStream) { + return KNNVectorSerializerFactory.getSerializerModeFromStream(byteStream); + } + + private void transfer() { + vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); + vectorList.clear(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index e05962608..68b61a070 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -11,9 +11,8 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; -import org.opensearch.knn.jni.JNICommons; +import org.opensearch.knn.index.codec.transfer.VectorTransfer; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -40,55 +39,27 @@ public static final class Pair { @Setter private int dimension; public SerializationMode serializationMode; - } - public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException { - List vectorList = new ArrayList<>(); + public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final VectorTransfer vectorTransfer) throws IOException { List docIdList = new ArrayList<>(); - long vectorAddress = 0; - int dimension = 0; SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; - - long totalLiveDocs = getTotalLiveDocsCount(values); - long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes(); - long vectorsPerTransfer = Integer.MIN_VALUE; - + vectorTransfer.init(getTotalLiveDocsCount(values)); for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { BytesRef bytesref = values.binaryValue(); try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) { - serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream); - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); - final float[] vector = vectorSerializer.byteToFloatArray(byteStream); - dimension = vector.length; - - if (vectorsPerTransfer == Integer.MIN_VALUE) { - vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit; - // This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer - // Doing this will reduce 1 extra trip to JNI layer. - if (vectorsPerTransfer == 0) { - vectorsPerTransfer = totalLiveDocs; - } - } - if (vectorList.size() == vectorsPerTransfer) { - vectorAddress = JNICommons.storeVectorData( - vectorAddress, - vectorList.toArray(new float[][] {}), - totalLiveDocs * dimension - ); - // We should probably come up with a better way to reuse the vectorList memory which we have - // created. Problem here is doing like this can lead to a lot of list memory which is of no use and - // will be garbage collected later on, but it creates pressure on JVM. We should revisit this. - vectorList = new ArrayList<>(); - } - vectorList.add(vector); + serializationMode = vectorTransfer.getSerializationMode(byteStream); + vectorTransfer.transfer(byteStream); } docIdList.add(doc); } - if (vectorList.isEmpty() == false) { - vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); - } - return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorAddress, dimension, serializationMode); + vectorTransfer.close(); + return new KNNCodecUtil.Pair( + docIdList.stream().mapToInt(Integer::intValue).toArray(), + vectorTransfer.getVectorAddress(), + vectorTransfer.getDimension(), + serializationMode + ); } public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) { @@ -102,7 +73,7 @@ public static long calculateArraySize(int numVectors, int vectorLength, Serializ vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; } return vectorsSize; - } else { + } else if (serializationMode == SerializationMode.COLLECTION_OF_FLOATS) { int vectorSize = vectorLength * FLOAT_BYTE_SIZE; if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; @@ -112,6 +83,18 @@ public static long calculateArraySize(int numVectors, int vectorLength, Serializ vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; } return vectorsSize; + } else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) { + int vectorSize = vectorLength; + if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { + vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; + } + int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); + if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { + vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; + } + return vectorsSize; + } else { + throw new IllegalStateException("Unreachable code"); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java index 5c1e4ca9b..23a829dfd 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java @@ -52,11 +52,11 @@ public static KNNVectorSerializer getDefaultSerializer() { } public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayInputStream byteStream) { - final SerializationMode serializationMode = serializerModeFromStream(byteStream); + final SerializationMode serializationMode = getSerializerModeFromStream(byteStream); return getSerializerBySerializationMode(serializationMode); } - static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) { + public static SerializationMode getSerializerModeFromStream(ByteArrayInputStream byteStream) { int numberOfAvailableBytesInStream = byteStream.available(); if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) { return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS); diff --git a/src/main/java/org/opensearch/knn/index/codec/util/SerializationMode.java b/src/main/java/org/opensearch/knn/index/codec/util/SerializationMode.java index 1fb82cbfe..f3a32f53e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/SerializationMode.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/SerializationMode.java @@ -7,5 +7,6 @@ public enum SerializationMode { ARRAY, - COLLECTION_OF_FLOATS + COLLECTION_OF_FLOATS, + COLLECTIONS_OF_BYTES } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 7e697fed7..793cb0bfc 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -73,7 +73,7 @@ import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataType; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; /** @@ -125,7 +125,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { * data_type which defines the datatype of the vector values. This is an optional parameter and * this is right now only relevant for lucene engine. The default value is float. */ - private final Parameter vectorDataType = new Parameter<>( + protected final Parameter vectorDataType = new Parameter<>( VECTOR_DATA_TYPE_FIELD, false, () -> DEFAULT_VECTOR_DATA_TYPE_FIELD, @@ -233,13 +233,16 @@ public KNNVectorFieldMapper build(BuilderContext context) { // the mappings, setting the index settings will have no impact. final KNNMethodContext knnMethodContext = this.knnMethodContext.getValue(); - validateMaxDimensions(knnMethodContext); + setDefaultSpaceType(knnMethodContext, vectorDataType.getValue()); + validateSpaceType(knnMethodContext, vectorDataType.getValue()); + validateDimensions(knnMethodContext, vectorDataType.getValue()); final MultiFields multiFieldsBuilder = this.multiFieldsBuilder.build(this, context); final CopyTo copyToBuilder = copyTo.build(); final Explicit ignoreMalformed = ignoreMalformed(context); final Map metaValue = meta.getValue(); if (knnMethodContext != null) { + validateVectorDataType(knnMethodContext, vectorDataType.getValue()); knnMethodContext.getMethodComponentContext().setIndexVersion(indexCreatedVersion); final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( buildFullName(context), @@ -265,11 +268,6 @@ public KNNVectorFieldMapper build(BuilderContext context) { return new LuceneFieldMapper(createLuceneFieldMapperInput); } - // Validates and throws exception if data_type field is set in the index mapping - // using any VectorDataType (other than float, which is default) because other - // VectorDataTypes are only supported for lucene engine. - validateVectorDataTypeWithEngine(vectorDataType); - return new MethodFieldMapper( name, mappedFieldType, @@ -342,7 +340,29 @@ public KNNVectorFieldMapper build(BuilderContext context) { ); } - private KNNEngine validateMaxDimensions(final KNNMethodContext knnMethodContext) { + private void setDefaultSpaceType(final KNNMethodContext knnMethodContext, final VectorDataType vectorDataType) { + if (knnMethodContext == null) { + return; + } + + if (SpaceType.UNDEFINED == knnMethodContext.getSpaceType()) { + if (VectorDataType.BINARY == vectorDataType) { + knnMethodContext.setSpaceType(SpaceType.DEFAULT_BINARY); + } else { + knnMethodContext.setSpaceType(SpaceType.DEFAULT); + } + } + } + + private void validateSpaceType(final KNNMethodContext knnMethodContext, final VectorDataType vectorDataType) { + if (knnMethodContext == null) { + return; + } + + knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); + } + + private KNNEngine validateDimensions(final KNNMethodContext knnMethodContext, final VectorDataType dataType) { final KNNEngine knnEngine; if (knnMethodContext != null) { knnEngine = knnMethodContext.getKnnEngine(); @@ -358,6 +378,9 @@ private KNNEngine validateMaxDimensions(final KNNMethodContext knnMethodContext) ) ); } + if (VectorDataType.BINARY == dataType && dimension.getValue() % 8 != 0) { + throw new IllegalArgumentException("Dimension should be multiply of 8 for binary vector data type"); + } return knnEngine; } } @@ -578,9 +601,19 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); + spaceType.validateVectorDataType(vectorDataType); + + if (VectorDataType.BINARY == vectorDataType) { + Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); - if (VectorDataType.BYTE == vectorDataType) { - Optional bytesArrayOptional = getBytesFromContext(context, dimension); + if (bytesArrayOptional.isEmpty()) { + return; + } + final byte[] array = bytesArrayOptional.get(); + spaceType.validateVector(array); + context.doc().addAll(getFieldsForByteVector(array, fieldType)); + } else if (VectorDataType.BYTE == vectorDataType) { + Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); if (bytesArrayOptional.isEmpty()) { return; @@ -663,31 +696,30 @@ void validateIfKNNPluginEnabled() { // Returns an optional array of byte values where each value in the vector is parsed as a float and validated // if it is a finite number without any decimals and within the byte range of [-128 to 127]. - Optional getBytesFromContext(ParseContext context, int dimension) throws IOException { + Optional getBytesFromContext(ParseContext context, int dimension, VectorDataType dataType) throws IOException { context.path().add(simpleName()); ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); - float value; if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { - value = context.parser().floatValue(); - validateByteVectorValue(value); + float value = context.parser().floatValue(); + validateByteVectorValue(value, dataType); vector.add((byte) value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { - value = context.parser().floatValue(); - validateByteVectorValue(value); + float value = context.parser().floatValue(); + validateByteVectorValue(value, dataType); vector.add((byte) value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { context.path().remove(); return Optional.empty(); } - validateVectorDimension(dimension, vector.size()); + validateVectorDimension(dimension, vector.size(), dataType); byte[] array = new byte[vector.size()]; int i = 0; for (Byte f : vector) { @@ -749,7 +781,7 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth context.path().remove(); return Optional.empty(); } - validateVectorDimension(dimension, vector.size()); + validateVectorDimension(dimension, vector.size(), vectorDataType); float[] array = new float[vector.size()]; int i = 0; diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 8bd7eb6f2..03f369d0d 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -18,6 +18,7 @@ import org.apache.lucene.index.DocValuesType; import org.apache.lucene.util.BytesRef; import org.opensearch.index.mapper.ParametrizedFieldMapper; +import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.util.KNNEngine; @@ -29,11 +30,14 @@ import java.util.Locale; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; @@ -90,25 +94,60 @@ public static float clipVectorValueToFP16Range(float value) { } /** - * Validates and throws exception if data_type field is set in the index mapping - * using any VectorDataType (other than float, which is default) because other - * VectorDataTypes are only supported for lucene engine. + * Validates if the vector data type is supported with given method context * - * @param vectorDataType VectorDataType Parameter + * @param methodContext methodContext + * @param vectorDataType vector data type */ - public static void validateVectorDataTypeWithEngine(ParametrizedFieldMapper.Parameter vectorDataType) { - if (VectorDataType.FLOAT == vectorDataType.getValue()) { + public static void validateVectorDataType(KNNMethodContext methodContext, VectorDataType vectorDataType) { + if (VectorDataType.FLOAT == vectorDataType) { return; } - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "[%s] field with value [%s] is only supported for [%s] engine", - VECTOR_DATA_TYPE_FIELD, - vectorDataType.getValue().getValue(), - LUCENE_NAME - ) - ); + + if (VectorDataType.BYTE == vectorDataType) { + if (KNNEngine.LUCENE == methodContext.getKnnEngine()) { + return; + } else { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue(), + LUCENE_NAME + ) + ); + } + } + + if (VectorDataType.BINARY == vectorDataType) { + if (KNNEngine.FAISS == methodContext.getKnnEngine()) { + if (METHOD_HNSW.equals(methodContext.getMethodComponentContext().getName())) { + return; + } else { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] method", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue(), + METHOD_HNSW + ) + ); + } + } else { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue(), + FAISS_NAME + ) + ); + } + } + throw new IllegalArgumentException("This line should not be reached"); } /** @@ -131,10 +170,10 @@ public static void validateVectorDataTypeWithKnnIndexSetting( throw new IllegalArgumentException( String.format( Locale.ROOT, - "[%s] field with value [%s] is only supported for [%s] engine", + "[%s] field with value [%s] is not supported for [%s] engine", VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue().getValue(), - LUCENE_NAME + NMSLIB_NAME ) ); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 59f4867dd..b8ba688ff 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -49,7 +49,9 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { vectorDataType = input.getVectorDataType(); this.knnMethod = input.getKnnMethodContext(); - final VectorSimilarityFunction vectorSimilarityFunction = this.knnMethod.getSpaceType().getVectorSimilarityFunction(); + final VectorSimilarityFunction vectorSimilarityFunction = this.knnMethod.getSpaceType() + .getKnnVectorSimilarityFunction() + .getVectorSimilarityFunction(); final int dimension = input.getMappedFieldType().getDimension(); if (dimension > KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE)) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index d2db7fb5a..f09ac1b4c 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -17,6 +17,7 @@ import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; /** * Field mapper for method definition in mapping @@ -51,6 +52,7 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension)); this.fieldType.putAttribute(SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); + this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); KNNEngine knnEngine = knnMethodContext.getKnnEngine(); this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index c8b56436b..b108fb6f0 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -11,6 +11,7 @@ package org.opensearch.knn.index.memory; +import lombok.Getter; import org.apache.lucene.index.LeafReaderContext; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.jni.JNICommons; @@ -87,12 +88,17 @@ class IndexAllocation implements NativeMemoryAllocation { private final long memoryAddress; private final int size; private volatile boolean closed; + @Getter private final KNNEngine knnEngine; + @Getter private final String indexPath; + @Getter private final String openSearchIndexName; private final ReadWriteLock readWriteLock; private final WatcherHandle watcherHandle; private final SharedIndexState sharedIndexState; + @Getter + private final boolean isBinaryIndex; /** * Constructor @@ -114,7 +120,7 @@ class IndexAllocation implements NativeMemoryAllocation { String openSearchIndexName, WatcherHandle watcherHandle ) { - this(executorService, memoryAddress, size, knnEngine, indexPath, openSearchIndexName, watcherHandle, null); + this(executorService, memoryAddress, size, knnEngine, indexPath, openSearchIndexName, watcherHandle, null, false); } /** @@ -137,7 +143,8 @@ class IndexAllocation implements NativeMemoryAllocation { String indexPath, String openSearchIndexName, WatcherHandle watcherHandle, - SharedIndexState sharedIndexState + SharedIndexState sharedIndexState, + boolean isBinaryIndex ) { this.executor = executorService; this.closed = false; @@ -149,6 +156,7 @@ class IndexAllocation implements NativeMemoryAllocation { this.size = size; this.watcherHandle = watcherHandle; this.sharedIndexState = sharedIndexState; + this.isBinaryIndex = isBinaryIndex; } @Override @@ -171,7 +179,7 @@ private void cleanup() { // memoryAddress is sometimes initialized to 0. If this is ever the case, freeing will surely fail. if (memoryAddress != 0) { - JNIService.free(memoryAddress, knnEngine); + JNIService.free(memoryAddress, knnEngine, isBinaryIndex); } if (sharedIndexState != null) { @@ -223,33 +231,6 @@ public void writeUnlock() { public int getSizeInKB() { return size; } - - /** - * Getter for k-NN Engine associated with this index allocation. - * - * @return KNNEngine associated with index allocation - */ - public KNNEngine getKnnEngine() { - return knnEngine; - } - - /** - * Getter for the path to the file from which the index was loaded. - * - * @return indexPath to index - */ - public String getIndexPath() { - return indexPath; - } - - /** - * Getter for the OpenSearch index associated with the native index. - * - * @return OpenSearch index name - */ - public String getOpenSearchIndexName() { - return openSearchIndexName; - } } /** diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java index cb7dafdfc..3602dd3c0 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java @@ -113,7 +113,8 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde indexPath.toString(), indexEntryContext.getOpenSearchIndexName(), watcherHandle, - sharedIndexState + sharedIndexState, + IndexUtil.isBinaryIndex(knnEngine, indexEntryContext.getParameters()) ); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 4b875d9a8..1c4ef25e5 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -20,6 +20,7 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; import java.io.IOException; import java.util.Arrays; @@ -37,9 +38,13 @@ public class KNNQuery extends Query { private final String field; private final float[] queryVector; + @Getter + private final byte[] byteQueryVector; private int k; private Map methodParameters; private final String indexName; + @Getter + private final VectorDataType vectorDataType; @Setter private Query filterQuery; @@ -54,11 +59,7 @@ public KNNQuery( final String indexName, final BitSetProducer parentsFilter ) { - this.field = field; - this.queryVector = queryVector; - this.k = k; - this.indexName = indexName; - this.parentsFilter = parentsFilter; + this(field, queryVector, null, k, indexName, null, parentsFilter, VectorDataType.FLOAT); } public KNNQuery( @@ -68,13 +69,40 @@ public KNNQuery( final String indexName, final Query filterQuery, final BitSetProducer parentsFilter + ) { + this(field, queryVector, null, k, indexName, filterQuery, parentsFilter, VectorDataType.FLOAT); + } + + public KNNQuery( + final String field, + final byte[] byteQueryVector, + final int k, + final String indexName, + final Query filterQuery, + final BitSetProducer parentsFilter, + final VectorDataType vectorDataType + ) { + this(field, null, byteQueryVector, k, indexName, filterQuery, parentsFilter, vectorDataType); + } + + private KNNQuery( + final String field, + final float[] queryVector, + final byte[] byteQueryVector, + final int k, + final String indexName, + final Query filterQuery, + final BitSetProducer parentsFilter, + final VectorDataType vectorDataType ) { this.field = field; this.queryVector = queryVector; + this.byteQueryVector = byteQueryVector; this.k = k; this.indexName = indexName; this.filterQuery = filterQuery; this.parentsFilter = parentsFilter; + this.vectorDataType = vectorDataType; } /** @@ -86,10 +114,7 @@ public KNNQuery( * @param parentsFilter parent filter */ public KNNQuery(String field, float[] queryVector, String indexName, BitSetProducer parentsFilter) { - this.field = field; - this.queryVector = queryVector; - this.indexName = indexName; - this.parentsFilter = parentsFilter; + this(field, queryVector, null, 0, indexName, null, parentsFilter, VectorDataType.FLOAT); } /** @@ -191,6 +216,7 @@ private boolean equalsTo(KNNQuery other) { if (other == this) return true; return Objects.equals(field, other.field) && Arrays.equals(queryVector, other.queryVector) + && Arrays.equals(byteQueryVector, other.byteQueryVector) && Objects.equals(k, other.k) && Objects.equals(methodParameters, other.methodParameters) && Objects.equals(radius, other.radius) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 86d8031bd..80ee5e32c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -27,7 +27,6 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; -import org.opensearch.knn.index.util.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -35,6 +34,7 @@ import org.opensearch.knn.index.VectorQueryType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.parser.MethodParametersParser; +import org.opensearch.knn.index.util.EngineSpecificMethodContext; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.index.util.QueryContext; import org.opensearch.knn.indices.ModelDao; @@ -554,17 +554,25 @@ protected Query doToQuery(QueryShardContext context) { radius = knnEngine.scoreToRadialThreshold(this.minScore, spaceType); } - if (fieldDimension != vector.length) { + int vectorLength = VectorDataType.BINARY == vectorDataType ? vector.length * Byte.SIZE : vector.length; + if (fieldDimension != vectorLength) { throw new IllegalArgumentException( - String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vector.length, fieldDimension) + String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vectorLength, fieldDimension) ); } byte[] byteVector = new byte[0]; - if (VectorDataType.BYTE == vectorDataType) { + if (VectorDataType.BINARY == vectorDataType) { + byteVector = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType()); + byteVector[i] = (byte) vector[i]; + } + spaceType.validateVector(byteVector); + } else if (VectorDataType.BYTE == vectorDataType) { byteVector = new byte[vector.length]; for (int i = 0; i < vector.length; i++) { - validateByteVectorValue(vector[i]); + validateByteVectorValue(vector[i], knnVectorFieldType.getVectorDataType()); byteVector[i] = (byte) vector[i]; } spaceType.validateVector(byteVector); @@ -586,7 +594,7 @@ protected Query doToQuery(QueryShardContext context) { .indexName(indexName) .fieldName(this.fieldName) .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) - .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) + .byteVector(VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType ? byteVector : null) .vectorDataType(vectorDataType) .k(this.k) .methodParameters(this.methodParameters) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 36987c750..af7dad026 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -30,6 +31,9 @@ public class KNNQueryFactory extends BaseQueryFactory { /** + * Note. This method should be used only for test. + * Should use {@link #create(CreateQueryRequest)} instead. + * * Creates a Lucene query for a particular engine. * * @param knnEngine Engine to create the query for @@ -39,6 +43,7 @@ public class KNNQueryFactory extends BaseQueryFactory { * @param k the number of nearest neighbors to return * @return Lucene Query */ + @VisibleForTesting public static Query create( KNNEngine knnEngine, String indexName, @@ -83,6 +88,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { final Query validatedFilterQuery = validateFilterQuerySupport(filterQuery, createQueryRequest.getKnnEngine()); + log.debug( "Creating custom k-NN query for index:{}, field:{}, k:{}, filterQuery:{}, efSearch:{}", indexName, @@ -91,15 +97,31 @@ public static Query create(CreateQueryRequest createQueryRequest) { validatedFilterQuery, methodParameters ); - return KNNQuery.builder() - .field(fieldName) - .queryVector(vector) - .indexName(indexName) - .parentsFilter(parentFilter) - .k(k) - .methodParameters(methodParameters) - .filterQuery(validatedFilterQuery) - .build(); + + switch (vectorDataType) { + case BINARY: + return KNNQuery.builder() + .field(fieldName) + .byteQueryVector(byteVector) + .indexName(indexName) + .parentsFilter(parentFilter) + .k(k) + .methodParameters(methodParameters) + .filterQuery(validatedFilterQuery) + .vectorDataType(vectorDataType) + .build(); + default: + return KNNQuery.builder() + .field(fieldName) + .queryVector(vector) + .indexName(indexName) + .parentsFilter(parentFilter) + .k(k) + .methodParameters(methodParameters) + .filterQuery(validatedFilterQuery) + .vectorDataType(vectorDataType) + .build(); + } } Integer requestEfSearch = null; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index fce8e8e04..2c450ad8a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -31,11 +31,15 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; +import org.opensearch.knn.index.query.filtered.FilteredIdsKNNByteIterator; import org.opensearch.knn.index.query.filtered.FilteredIdsKNNIterator; +import org.opensearch.knn.index.query.filtered.KNNIterator; +import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNByteIterator; import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator; import org.opensearch.knn.index.util.FieldInfoExtractor; import org.opensearch.knn.index.util.KNNEngine; @@ -285,16 +289,29 @@ private Map doANNSearch(final LeafReaderContext context, final B } int[] parentIds = getParentIdsArray(context); if (knnQuery.getK() > 0) { - results = JNIService.queryIndex( - indexAllocation.getMemoryAddress(), - knnQuery.getQueryVector(), - knnQuery.getK(), - knnQuery.getMethodParameters(), - knnEngine, - filterIds, - filterType.getValue(), - parentIds - ); + if (knnQuery.getVectorDataType() == VectorDataType.BINARY) { + results = JNIService.queryBinaryIndex( + indexAllocation.getMemoryAddress(), + knnQuery.getByteQueryVector(), + knnQuery.getK(), + knnQuery.getMethodParameters(), + knnEngine, + filterIds, + filterType.getValue(), + parentIds + ); + } else { + results = JNIService.queryIndex( + indexAllocation.getMemoryAddress(), + knnQuery.getQueryVector(), + knnQuery.getK(), + knnQuery.getMethodParameters(), + knnEngine, + filterIds, + filterType.getValue(), + parentIds + ); + } } else { results = JNIService.radiusQueryIndex( indexAllocation.getMemoryAddress(), @@ -336,7 +353,7 @@ private Map doExactSearch(final LeafReaderContext leafReaderCont final HitQueue queue = new HitQueue(Math.min(this.knnQuery.getK(), cardinality), true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); - FilteredIdsKNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsBitSet); + KNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsBitSet); int docId; while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { if (iterator.score() > topDoc.score) { @@ -367,21 +384,32 @@ private Map doExactSearch(final LeafReaderContext leafReaderCont return Collections.emptyMap(); } - private FilteredIdsKNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) - throws IOException { + private KNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) throws IOException { final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); final SpaceType spaceType = getSpaceType(fieldInfo); - return knnQuery.getParentsFilter() == null - ? new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), values, spaceType) - : new NestedFilteredIdsKNNIterator( - filterIdsBitSet, - knnQuery.getQueryVector(), - values, - spaceType, - knnQuery.getParentsFilter().getBitSet(leafReaderContext) - ); + if (VectorDataType.BINARY == knnQuery.getVectorDataType()) { + return knnQuery.getParentsFilter() == null + ? new FilteredIdsKNNByteIterator(filterIdsBitSet, knnQuery.getByteQueryVector(), values, spaceType) + : new NestedFilteredIdsKNNByteIterator( + filterIdsBitSet, + knnQuery.getByteQueryVector(), + values, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } else { + return knnQuery.getParentsFilter() == null + ? new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), values, spaceType) + : new NestedFilteredIdsKNNIterator( + filterIdsBitSet, + knnQuery.getQueryVector(), + values, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } } private Scorer convertSearchResponseToScorer(final Map docsToScore) throws IOException { diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java new file mode 100644 index 000000000..815e621f6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.SpaceType; + +import java.io.ByteArrayInputStream; +import java.io.IOException; + +/** + * Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene + * https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162 + * + * The class is used in KNNWeight to score filtered KNN field by iterating filterIdsArray. + */ +public class FilteredIdsKNNByteIterator implements KNNIterator { + // Array of doc ids to iterate + protected final BitSet filterIdsBitSet; + protected final BitSetIterator bitSetIterator; + protected final byte[] queryVector; + protected final BinaryDocValues binaryDocValues; + protected final SpaceType spaceType; + protected float currentScore = Float.NEGATIVE_INFINITY; + protected int docId; + + public FilteredIdsKNNByteIterator( + final BitSet filterIdsBitSet, + final byte[] queryVector, + final BinaryDocValues binaryDocValues, + final SpaceType spaceType + ) { + this.filterIdsBitSet = filterIdsBitSet; + this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); + this.queryVector = queryVector; + this.binaryDocValues = binaryDocValues; + this.spaceType = spaceType; + this.docId = bitSetIterator.nextDoc(); + } + + /** + * Advance to the next doc and update score value with score of the next doc. + * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs + * + * @return next doc id + */ + @Override + public int nextDoc() throws IOException { + + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + return DocIdSetIterator.NO_MORE_DOCS; + } + int doc = binaryDocValues.advance(docId); + currentScore = computeScore(); + docId = bitSetIterator.nextDoc(); + return doc; + } + + @Override + public float score() { + return currentScore; + } + + protected float computeScore() throws IOException { + final BytesRef value = binaryDocValues.binaryValue(); + final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); + final byte[] vector = byteStream.readAllBytes(); + // Calculates a similarity score between the two vectors with a specified function. Higher similarity + // scores correspond to closer vectors. + return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java index a53cb8d60..fb153989a 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java @@ -23,7 +23,7 @@ * * The class is used in KNNWeight to score filtered KNN field by iterating filterIdsArray. */ -public class FilteredIdsKNNIterator { +public class FilteredIdsKNNIterator implements KNNIterator { // Array of doc ids to iterate protected final BitSet filterIdsBitSet; protected final BitSetIterator bitSetIterator; @@ -53,6 +53,7 @@ public FilteredIdsKNNIterator( * * @return next doc id */ + @Override public int nextDoc() throws IOException { if (docId == DocIdSetIterator.NO_MORE_DOCS) { @@ -64,6 +65,7 @@ public int nextDoc() throws IOException { return doc; } + @Override public float score() { return currentScore; } @@ -75,6 +77,6 @@ protected float computeScore() throws IOException { final float[] vector = vectorSerializer.byteToFloatArray(byteStream); // Calculates a similarity score between the two vectors with a specified function. Higher similarity // scores correspond to closer vectors. - return spaceType.getVectorSimilarityFunction().compare(queryVector, vector); + return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); } } diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/KNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/KNNIterator.java new file mode 100644 index 000000000..4a105975a --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/filtered/KNNIterator.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import java.io.IOException; + +public interface KNNIterator { + int nextDoc() throws IOException; + + float score(); +} diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java new file mode 100644 index 000000000..80fba1e41 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.opensearch.knn.index.SpaceType; + +import java.io.IOException; + +/** + * This iterator iterates filterIdsArray to score. However, it dedupe docs per each parent doc + * of which ID is set in parentBitSet and only return best child doc with the highest score. + */ +public class NestedFilteredIdsKNNByteIterator extends FilteredIdsKNNByteIterator { + private final BitSet parentBitSet; + + public NestedFilteredIdsKNNByteIterator( + final BitSet filterIdsArray, + final byte[] queryVector, + final BinaryDocValues values, + final SpaceType spaceType, + final BitSet parentBitSet + ) { + super(filterIdsArray, queryVector, values, spaceType); + this.parentBitSet = parentBitSet; + } + + /** + * Advance to the next best child doc per parent and update score with the best score among child docs from the parent. + * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs + * + * @return next best child doc id + */ + @Override + public int nextDoc() throws IOException { + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + return DocIdSetIterator.NO_MORE_DOCS; + } + + currentScore = Float.NEGATIVE_INFINITY; + int currentParent = parentBitSet.nextSetBit(docId); + int bestChild = -1; + + while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { + binaryDocValues.advance(docId); + float score = computeScore(); + if (score > currentScore) { + bestChild = docId; + currentScore = score; + } + docId = bitSetIterator.nextDoc(); + } + + return bestChild; + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 7cf31ba3c..711c206f5 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -55,7 +55,8 @@ /** * Implements NativeLibrary for the faiss native library */ -class Faiss extends NativeLibrary { +public class Faiss extends NativeLibrary { + public static final String FAISS_BINARY_INDEX_DESCRIPTION_PREFIX = "B"; Map> scoreTransform; // TODO: Current version is not really current version. Instead, it encodes information in the file name @@ -246,7 +247,7 @@ class Faiss extends NativeLibrary { ).addParameter(METHOD_PARAMETER_M, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "").build()) ) .build() - ).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build(), + ).addSpaces(SpaceType.UNDEFINED, SpaceType.HAMMING_BIT, SpaceType.L2, SpaceType.INNER_PRODUCT).build(), METHOD_IVF, KNNMethod.Builder.builder( MethodComponent.Builder.builder(METHOD_IVF) @@ -304,7 +305,7 @@ class Faiss extends NativeLibrary { return ((4L * centroids * dimension) / BYTES_PER_KILOBYTES) + 1; }) .build() - ).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build() + ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT).build() ); final static Faiss INSTANCE = new Faiss( diff --git a/src/main/java/org/opensearch/knn/index/util/FieldInfoExtractor.java b/src/main/java/org/opensearch/knn/index/util/FieldInfoExtractor.java index 5ad271969..c46cf1296 100644 --- a/src/main/java/org/opensearch/knn/index/util/FieldInfoExtractor.java +++ b/src/main/java/org/opensearch/knn/index/util/FieldInfoExtractor.java @@ -12,10 +12,13 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; import java.io.IOException; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; /** * Class having methods to extract a value from field info @@ -27,11 +30,18 @@ public static String getIndexDescription(FieldInfo fieldInfo) throws IOException return null; } - return (String) XContentHelper.createParser( + String indexDescription = (String) XContentHelper.createParser( NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, new BytesArray(parameters), MediaTypeRegistry.getDefaultMediaType() ).map().getOrDefault(INDEX_DESCRIPTION_PARAMETER, null); + + if (VectorDataType.BINARY.getValue() + .equals(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()))) { + indexDescription = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + indexDescription; + } + + return indexDescription; } } diff --git a/src/main/java/org/opensearch/knn/index/util/Lucene.java b/src/main/java/org/opensearch/knn/index/util/Lucene.java index d98775f94..b5bbfca75 100644 --- a/src/main/java/org/opensearch/knn/index/util/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/util/Lucene.java @@ -45,7 +45,7 @@ public class Lucene extends JVMLibrary { ) ) .build() - ).addSpaces(SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build() + ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build() ); // Map that overrides the default distance translations for Lucene, check more details in knn documentation: diff --git a/src/main/java/org/opensearch/knn/index/util/Nmslib.java b/src/main/java/org/opensearch/knn/index/util/Nmslib.java index 7b18ed11d..a068901d3 100644 --- a/src/main/java/org/opensearch/knn/index/util/Nmslib.java +++ b/src/main/java/org/opensearch/knn/index/util/Nmslib.java @@ -47,7 +47,7 @@ class Nmslib extends NativeLibrary { ) ) .build() - ).addSpaces(SpaceType.L2, SpaceType.L1, SpaceType.LINF, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build() + ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.L1, SpaceType.LINF, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build() ); final static Nmslib INSTANCE = new Nmslib(METHODS, Collections.emptyMap(), CURRENT_VERSION, EXTENSION); diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index f718ce6d5..21de90765 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -202,10 +202,29 @@ public static native KNNQueryResult[] queryBinaryIndexWithFilter( int[] parentIds ); + /** + * Query a binary index with filter + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param filterIds list of doc ids to include in the query result + * @param parentIds list of parent doc ids when the knn field is a nested field + * @return KNNQueryResult array of k neighbors + */ + public static native KNNQueryResult[] queryBinaryIndexWithFilter( + long indexPointer, + byte[] queryVector, + int k, + long[] filterIds, + int filterIdsType, + int[] parentIds + ); + /** * Free native memory pointer */ - public static native void free(long indexPointer); + public static native void free(long indexPointer, boolean isBinary); /** * Deallocate memory of the shared index state diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index ed6a169c1..cefd0af53 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -13,7 +13,7 @@ import org.apache.commons.lang.ArrayUtils; import org.opensearch.common.Nullable; -import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.util.KNNEngine; @@ -23,8 +23,6 @@ * Service to distribute requests to the proper engine jni service */ public class JNIService { - private static final String FAISS_BINARY_INDEX_PREFIX = "B"; - /** * Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer @@ -53,8 +51,7 @@ public static void createIndex( } if (KNNEngine.FAISS == knnEngine) { - if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null - && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_PREFIX)) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { FaissService.createBinaryIndex(ids, vectorsAddress, dim, indexPath, parameters); } else { FaissService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); @@ -109,8 +106,7 @@ public static long loadIndex(String indexPath, Map parameters, K } if (KNNEngine.FAISS == knnEngine) { - if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null - && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_PREFIX)) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { return FaissService.loadBinaryIndex(indexPath); } else { return FaissService.loadIndex(indexPath); @@ -260,14 +256,25 @@ public static KNNQueryResult[] queryBinaryIndex( * @param indexPointer location to be freed * @param knnEngine engine to perform free */ - public static void free(long indexPointer, KNNEngine knnEngine) { + public static void free(final long indexPointer, final KNNEngine knnEngine) { + free(indexPointer, knnEngine, false); + } + + /** + * Free native memory pointer + * + * @param indexPointer location to be freed + * @param knnEngine engine to perform free + * @param isBinaryIndex indicate if it is binary index or not + */ + public static void free(final long indexPointer, final KNNEngine knnEngine, final boolean isBinaryIndex) { if (KNNEngine.NMSLIB == knnEngine) { NmslibService.free(indexPointer); return; } if (KNNEngine.FAISS == knnEngine) { - FaissService.free(indexPointer); + FaissService.free(indexPointer, isBinaryIndex); return; } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index ebbd7fa9b..fb8ccc4ce 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -16,6 +16,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.TrainingJobRouterAction; @@ -97,6 +98,9 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr trainingField = parser.textOrNull(); } else if (KNN_METHOD.equals(fieldName) && ensureNotSet(fieldName, knnMethodContext)) { knnMethodContext = KNNMethodContext.parse(parser.map()); + if (SpaceType.UNDEFINED == knnMethodContext.getSpaceType()) { + knnMethodContext.setSpaceType(SpaceType.L2); + } } else if (DIMENSION.equals(fieldName) && ensureNotSet(fieldName, dimension)) { dimension = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false); } else if (MAX_VECTOR_COUNT_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, maximumVectorCount)) { diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 889780d7a..5f9efd3cb 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -116,7 +116,7 @@ public static float[] convertVectorToPrimitive(Object vector, VectorDataType vec for (int i = 0; i < primitiveVector.length; i++) { float value = tmp.get(i).floatValue(); if (VectorDataType.BYTE == vectorDataType) { - validateByteVectorValue(value); + validateByteVectorValue(value, vectorDataType); } primitiveVector[i] = value; } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 84e986faa..01f80b371 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -59,7 +59,7 @@ private static float[] toFloat(List inputVector, VectorDataType vectorDa for (final Number val : inputVector) { float floatValue = val.floatValue(); if (VectorDataType.BYTE == vectorDataType) { - validateByteVectorValue(floatValue); + validateByteVectorValue(floatValue, vectorDataType); } value[index++] = floatValue; } @@ -195,6 +195,21 @@ public static float calculateHammingBit(Long queryLong, Long inputLong) { return Long.bitCount(queryLong ^ inputLong); } + /** + * This method calculates hamming distance between query vector + * + * @param queryVector query vector + * @param inputVector input vector + * @return hamming distance + */ + public static float calculateHammingBit(byte[] queryVector, byte[] inputVector) { + float distance = 0; + for (int i = 0; i < inputVector.length; i++) { + distance += Integer.bitCount(queryVector[i] ^ inputVector[i]); + } + return distance; + } + /** * This method calculates L1 distance between query vector * and input vector diff --git a/src/test/java/org/opensearch/knn/common/KNNValidationUtilTests.java b/src/test/java/org/opensearch/knn/common/KNNValidationUtilTests.java new file mode 100644 index 000000000..56e462fc1 --- /dev/null +++ b/src/test/java/org/opensearch/knn/common/KNNValidationUtilTests.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; + +import static org.hamcrest.Matchers.containsString; + +public class KNNValidationUtilTests extends KNNTestCase { + public void testValidateVectorDimension_whenBinary_thenVectorSizeShouldBeEightTimesLarger() { + int vectorLength = randomInt(100); + Exception ex = expectThrows( + IllegalArgumentException.class, + () -> KNNValidationUtil.validateVectorDimension(vectorLength, vectorLength, VectorDataType.BINARY) + ); + assertThat( + ex.getMessage(), + containsString("The dimension of the binary vector must be 8 times the length of the provided vector.") + ); + + // Expect no exception + KNNValidationUtil.validateVectorDimension(vectorLength * Byte.SIZE, vectorLength, VectorDataType.BINARY); + } + + public void testValidateVectorDimension_whenNonBinary_thenVectorSizeShouldBeSameAsDimension() { + int dimension = randomInt(100); + VectorDataType vectorDataType = randomInt(1) == 0 ? VectorDataType.FLOAT : VectorDataType.BYTE; + Exception ex = expectThrows( + IllegalArgumentException.class, + () -> KNNValidationUtil.validateVectorDimension(dimension, dimension + 1, vectorDataType) + ); + assertThat(ex.getMessage(), containsString("Vector dimension mismatch")); + + // Expect no exception + KNNValidationUtil.validateVectorDimension(dimension, dimension, vectorDataType); + } +} diff --git a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java index e6c3e96ee..7287a6182 100644 --- a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java @@ -23,11 +23,13 @@ import org.opensearch.common.ValidationException; import org.opensearch.common.settings.Settings; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.jni.JNIService; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -242,4 +244,16 @@ public void testIsShareableStateContainedInIndex_whenJNIIsSharedIndexStateRequir KNNEngine knnEngine = KNNEngine.FAISS; assertTrue(IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, TEST_INDEX_ADDRESS)); } + + public void testIsBinaryIndex_whenBinary_thenTrue() { + Map binaryIndexParams = new HashMap<>(); + binaryIndexParams.put(KNNConstants.INDEX_DESCRIPTION_PARAMETER, "BHNSW"); + assertTrue(IndexUtil.isBinaryIndex(KNNEngine.FAISS, binaryIndexParams)); + } + + public void testIsBinaryIndex_whenNonBinary_thenFalse() { + Map nonBinaryIndexParams = new HashMap<>(); + nonBinaryIndexParams.put(KNNConstants.INDEX_DESCRIPTION_PARAMETER, "HNSW"); + assertFalse(IndexUtil.isBinaryIndex(KNNEngine.FAISS, nonBinaryIndexParams)); + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java index 1330e8da0..cb294bc3d 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java @@ -290,7 +290,7 @@ public void testParse_valid() throws IOException { KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); assertEquals(KNNEngine.DEFAULT, knnMethodContext.getKnnEngine()); - assertEquals(SpaceType.DEFAULT, knnMethodContext.getSpaceType()); + assertEquals(SpaceType.UNDEFINED, knnMethodContext.getSpaceType()); assertEquals(methodName, knnMethodContext.getMethodComponentContext().getName()); assertTrue(knnMethodContext.getMethodComponentContext().getParameters().isEmpty()); diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorSimilarityFunctionTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorSimilarityFunctionTests.java new file mode 100644 index 000000000..691941dc3 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/KNNVectorSimilarityFunctionTests.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import junit.framework.TestCase; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.opensearch.knn.plugin.script.KNNScoringUtil; + +import java.util.Set; + +import static org.apache.lucene.tests.util.LuceneTestCase.expectThrows; +import static org.opensearch.knn.index.KNNVectorSimilarityFunction.COSINE; +import static org.opensearch.knn.index.KNNVectorSimilarityFunction.DOT_PRODUCT; +import static org.opensearch.knn.index.KNNVectorSimilarityFunction.EUCLIDEAN; +import static org.opensearch.knn.index.KNNVectorSimilarityFunction.HAMMING; +import static org.opensearch.knn.index.KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +public class KNNVectorSimilarityFunctionTests extends TestCase { + private static final Set FUNCTION_SET_BACKED_BY_LUCENE = Set.of( + EUCLIDEAN, + DOT_PRODUCT, + COSINE, + MAXIMUM_INNER_PRODUCT + ); + + public void testFunctions_whenBackedByLucene_thenSameAsLucene() { + float[] f1 = new float[] { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f }; + float[] f2 = new float[] { 6.5f, 7.5f, 8.5f, 09.5f, 10.5f }; + byte[] b1 = new byte[] { 1, 2, 3 }; + byte[] b2 = new byte[] { 4, 5, 6 }; + for (KNNVectorSimilarityFunction function : KNNVectorSimilarityFunction.values()) { + if (FUNCTION_SET_BACKED_BY_LUCENE.contains(function) == false) { + continue; + } + assertEquals(VectorSimilarityFunction.valueOf(function.name()), function.getVectorSimilarityFunction()); + assertEquals(function.getVectorSimilarityFunction().compare(f1, f2), function.compare(f1, f2)); + assertEquals(function.getVectorSimilarityFunction().compare(b1, b2), function.compare(b1, b2)); + } + } + + public void testFunctions_whenHamming_thenFloatVectorIsNotSupported() { + float[] f1 = new float[] { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f }; + float[] f2 = new float[] { 6.5f, 7.5f, 8.5f, 09.5f, 10.5f }; + + Exception ex = expectThrows(IllegalStateException.class, () -> HAMMING.compare(f1, f2)); + assertTrue(ex.getMessage().contains("not supported")); + } + + public void testFunctions_whenHamming_thenReturnCorrectScore() { + byte[] b1 = new byte[] { 1, 2, 3 }; + byte[] b2 = new byte[] { 4, 5, 6 }; + assertEquals(1.0f / (1 + KNNScoringUtil.calculateHammingBit(b1, b2)), HAMMING.compare(b1, b2)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index e05b90360..261034e71 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -10,7 +10,6 @@ import lombok.SneakyThrows; import org.apache.commons.lang.math.RandomUtils; import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.VectorUtil; import org.junit.After; import org.opensearch.client.Response; @@ -64,14 +63,14 @@ public class LuceneEngineIT extends KNNRestTestCase { private static final float[][] TEST_QUERY_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; - private static final Map> VECTOR_SIMILARITY_TO_SCORE = ImmutableMap.of( - VectorSimilarityFunction.EUCLIDEAN, + private static final Map> VECTOR_SIMILARITY_TO_SCORE = ImmutableMap.of( + KNNVectorSimilarityFunction.EUCLIDEAN, (similarity) -> 1 / (1 + similarity), - VectorSimilarityFunction.DOT_PRODUCT, + KNNVectorSimilarityFunction.DOT_PRODUCT, (similarity) -> (1 + similarity) / 2, - VectorSimilarityFunction.COSINE, + KNNVectorSimilarityFunction.COSINE, (similarity) -> (1 + similarity) / 2, - VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, + KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, (similarity) -> similarity <= 0 ? 1 / (1 - similarity) : similarity + 1 ); private static final String DIMENSION_FIELD_NAME = "dimension"; @@ -520,7 +519,7 @@ private void validateQueries(SpaceType spaceType, String fieldName, Map> expected = Map.of( + SpaceType.UNDEFINED, + Collections.emptySet(), + SpaceType.L2, + Set.of(VectorDataType.FLOAT, VectorDataType.BYTE), + SpaceType.COSINESIMIL, + Set.of(VectorDataType.FLOAT, VectorDataType.BYTE), + SpaceType.L1, + Set.of(VectorDataType.FLOAT, VectorDataType.BYTE), + SpaceType.LINF, + Set.of(VectorDataType.FLOAT, VectorDataType.BYTE), + SpaceType.INNER_PRODUCT, + Set.of(VectorDataType.FLOAT, VectorDataType.BYTE), + SpaceType.HAMMING_BIT, + Set.of(VectorDataType.BINARY) + ); + + for (SpaceType spaceType : SpaceType.values()) { + for (VectorDataType vectorDataType : VectorDataType.values()) { + if (expected.get(spaceType).isEmpty()) { + Exception ex = expectThrows(IllegalStateException.class, () -> spaceType.validateVectorDataType(vectorDataType)); + assertTrue(ex.getMessage().contains("Unsupported method")); + continue; + } + + if (expected.get(spaceType).contains(vectorDataType)) { + spaceType.validateVectorDataType(vectorDataType); + } else { + Exception ex = expectThrows(IllegalArgumentException.class, () -> spaceType.validateVectorDataType(vectorDataType)); + assertTrue(ex.getMessage().contains("is not supported")); + } + } + } + } } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index e55a3be42..cc8b86572 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -34,6 +34,7 @@ import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; @@ -282,10 +283,10 @@ public void testByteVectorDataTypeWithLegacyFieldMapperKnnIndexSetting() { .contains( String.format( Locale.ROOT, - "[%s] field with value [%s] is only supported for [%s] engine", + "[%s] field with value [%s] is not supported for [%s] engine", VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue(), - LUCENE_NAME + NMSLIB_NAME ) ) ); diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 19270717d..8e6b6d7f7 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -13,8 +13,10 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.apache.lucene.util.BytesRef; import org.junit.Assert; import org.opensearch.knn.KNNTestCase; @@ -106,4 +108,17 @@ private void createKNNByteVectorDocument(Directory directory) throws IOException writer.commit(); writer.close(); } + + public void testCreateKnnVectorFieldType_whenBinary_thenException() { + Exception ex = expectThrows( + IllegalStateException.class, + () -> VectorDataType.BINARY.createKnnVectorFieldType(1, VectorSimilarityFunction.EUCLIDEAN) + ); + assertTrue(ex.getMessage().contains("Unsupported method")); + } + + public void testGetVectorFromBytesRef_whenBinary_thenException() { + Exception ex = expectThrows(IllegalStateException.class, () -> VectorDataType.BINARY.getVectorFromBytesRef(new BytesRef())); + assertTrue(ex.getMessage().contains("Unsupported method")); + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index f7c9f3eb8..4e3231894 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -26,6 +26,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -62,7 +63,9 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; +import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertBinaryIndexLoadableByEngine; import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertFileInCorrectLocation; import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertLoadableByEngine; import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertValidFooter; @@ -146,7 +149,8 @@ public void testAddKNNBinaryField_noVectors() throws IOException { Long initialMergeSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue(); Long initialMergeDocs = KNNGraphValue.MERGE_TOTAL_DOCS.getValue(); KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, null); - knn80DocValuesConsumer.addKNNBinaryField(null, randomVectorDocValuesProducer, true, true); + FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").build(); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfo, randomVectorDocValuesProducer, true, true); assertEquals(initialGraphIndexRequests, KNNCounter.GRAPH_INDEX_REQUESTS.getCount()); assertEquals(initialRefreshOperations, KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(initialMergeOperations, KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); @@ -329,6 +333,69 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } + public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException { + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.FAISS; + SpaceType spaceType = SpaceType.HAMMING_BIT; + VectorDataType dataType = VectorDataType.BINARY; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() + .directory(directory) + .segmentName(segmentName) + .docsInSegment(docsInSegment) + .codec(codec) + .build(); + + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + spaceType, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); + + String parameterString = XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext)).toString(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.KNN_ENGINE, knnEngine.getName()) + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .addAttribute(VECTOR_DATA_TYPE_FIELD, dataType.getValue()) + .addAttribute(KNNConstants.PARAMETERS, parameterString) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); + long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by faiss + assertBinaryIndexLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + + // The graph creation statistics should be updated + assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); + assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + } + public void testAddKNNBinaryField_fromModel_faiss() throws IOException, ExecutionException, InterruptedException { // Generate a trained faiss model KNNEngine knnEngine = KNNEngine.FAISS; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index fe8200375..d978e7210 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -49,6 +49,7 @@ import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat; import static org.junit.Assert.assertTrue; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.test.OpenSearchTestCase.randomByteArrayOfLength; @@ -343,6 +344,27 @@ public static void assertLoadableByEngine( JNIService.free(indexPtr, knnEngine); } + public static void assertBinaryIndexLoadableByEngine( + SegmentWriteState state, + String fileName, + KNNEngine knnEngine, + SpaceType spaceType, + int dimension + ) { + String filePath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), fileName) + .toString(); + long indexPtr = JNIService.loadIndex( + filePath, + Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue(), INDEX_DESCRIPTION_PARAMETER, "BHNSW32")), + knnEngine + ); + int k = 2; + byte[] queryVector = new byte[dimension]; + KNNQueryResult[] results = JNIService.queryBinaryIndex(indexPtr, queryVector, k, null, knnEngine, null, 0, null); + assertTrue(results.length > 0); + JNIService.free(indexPtr, knnEngine); + } + public static float[][] getRandomVectors(int count, int dimension) { float[][] data = new float[count][dimension]; for (int i = 0; i < count; i++) { diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java new file mode 100644 index 000000000..abcd89a0e --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.opensearch.knn.index.codec.util.SerializationMode; +import org.opensearch.knn.jni.JNICommons; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.Random; + +import static org.junit.Assert.assertNotEquals; + +public class VectorTransferByteTests extends TestCase { + @SneakyThrows + public void testTransfer_whenCalled_thenAdded() { + final ByteArrayInputStream bais1 = getByteArrayOfVectors(20); + final ByteArrayInputStream bais2 = getByteArrayOfVectors(20); + VectorTransferByte vectorTransfer = new VectorTransferByte(1000); + try { + vectorTransfer.init(2); + + vectorTransfer.transfer(bais1); + // flush is not called + assertEquals(0, vectorTransfer.getVectorAddress()); + + vectorTransfer.transfer(bais2); + // flush should be called + assertNotEquals(0, vectorTransfer.getVectorAddress()); + } finally { + if (vectorTransfer.getVectorAddress() != 0) { + JNICommons.freeVectorData(vectorTransfer.getVectorAddress()); + } + } + } + + @SneakyThrows + public void testSerializationMode_whenCalled_thenReturn() { + final ByteArrayInputStream bais = getByteArrayOfVectors(20); + VectorTransferByte vectorTransfer = new VectorTransferByte(1000); + + // Verify + assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, vectorTransfer.getSerializationMode(bais)); + } + + private ByteArrayInputStream getByteArrayOfVectors(int vectorLength) throws IOException { + byte[] vector = new byte[vectorLength]; + new Random().nextBytes(vector); + return new ByteArrayInputStream(vector); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java new file mode 100644 index 000000000..1de513a0b --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.jni.JNICommons; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Random; +import java.util.stream.IntStream; + +import static org.junit.Assert.assertNotEquals; + +public class VectorTransferFloatTests extends TestCase { + @SneakyThrows + public void testTransfer_whenCalled_thenAdded() { + final ByteArrayInputStream bais1 = getByteArrayOfVectors(20); + final ByteArrayInputStream bais2 = getByteArrayOfVectors(20); + VectorTransferFloat vectorTransfer = new VectorTransferFloat(1000); + try { + vectorTransfer.init(2); + + vectorTransfer.transfer(bais1); + // flush is not called + assertEquals(0, vectorTransfer.getVectorAddress()); + + vectorTransfer.transfer(bais2); + // flush should be called + assertNotEquals(0, vectorTransfer.getVectorAddress()); + } finally { + if (vectorTransfer.getVectorAddress() != 0) { + JNICommons.freeVectorData(vectorTransfer.getVectorAddress()); + } + } + } + + @SneakyThrows + public void testSerializationMode_whenCalled_thenReturn() { + final ByteArrayInputStream bais = getByteArrayOfVectors(20); + VectorTransferFloat vectorTransfer = new VectorTransferFloat(1000); + + // Verify + assertEquals(KNNVectorSerializerFactory.getSerializerModeFromStream(bais), vectorTransfer.getSerializationMode(bais)); + } + + private ByteArrayInputStream getByteArrayOfVectors(int vectorLength) throws IOException { + float[] vector = new float[vectorLength]; + IntStream.range(0, vectorLength).forEach(index -> vector[index] = new Random().nextFloat()); + + final ByteArrayOutputStream bas = new ByteArrayOutputStream(); + final DataOutputStream ds = new DataOutputStream(bas); + for (float f : vector) { + ds.writeFloat(f); + } + final byte[] vectorAsCollectionOfFloats = bas.toByteArray(); + return new ByteArrayInputStream(vectorAsCollectionOfFloats); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java new file mode 100644 index 000000000..04c1c038f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.util; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.codec.transfer.VectorTransfer; + +import java.io.ByteArrayInputStream; +import java.util.Arrays; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class KNNCodecUtilTests extends TestCase { + @SneakyThrows + public void testGetPair_whenCalled_thenReturn() { + long liveDocCount = 1l; + int[] docId = { 2 }; + long vectorAddress = 3l; + int dimension = 4; + BytesRef bytesRef = new BytesRef(); + + BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); + when(binaryDocValues.cost()).thenReturn(liveDocCount); + when(binaryDocValues.nextDoc()).thenReturn(docId[0], NO_MORE_DOCS); + when(binaryDocValues.binaryValue()).thenReturn(bytesRef); + + VectorTransfer vectorTransfer = mock(VectorTransfer.class); + when(vectorTransfer.getSerializationMode(any(ByteArrayInputStream.class))).thenReturn(SerializationMode.COLLECTIONS_OF_BYTES); + when(vectorTransfer.getVectorAddress()).thenReturn(vectorAddress); + when(vectorTransfer.getDimension()).thenReturn(dimension); + + // Run + KNNCodecUtil.Pair pair = KNNCodecUtil.getPair(binaryDocValues, vectorTransfer); + + // Verify + verify(vectorTransfer).init(liveDocCount); + verify(vectorTransfer).getSerializationMode(any(ByteArrayInputStream.class)); + verify(vectorTransfer).transfer(any(ByteArrayInputStream.class)); + verify(vectorTransfer).close(); + + assertTrue(Arrays.equals(docId, pair.docs)); + assertEquals(vectorAddress, pair.getVectorAddress()); + assertEquals(dimension, pair.getDimension()); + assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode); + } +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index a9b65878f..0ba2b97bc 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -13,20 +13,20 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.BytesRef; import org.mockito.Mockito; -import org.opensearch.common.Explicit; -import org.opensearch.index.mapper.FieldMapper; -import org.opensearch.index.mapper.ParseContext; -import org.opensearch.knn.KNNTestCase; import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.Explicit; import org.opensearch.common.ValidationException; import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Settings; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.ContentPath; +import org.opensearch.index.mapper.FieldMapper; import org.opensearch.index.mapper.Mapper; import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.ParseContext; +import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.MethodComponentContext; @@ -52,6 +52,9 @@ import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.Version.CURRENT; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; @@ -68,10 +71,9 @@ import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.Version.CURRENT; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; +import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; @@ -364,10 +366,6 @@ public void testTypeParser_parse_invalidVectorDataType() { String fieldName = "test-field-name-vec"; String indexName = "test-index-name-vec"; String vectorDataType = "invalid"; - String supportedTypes = String.join( - ",", - Arrays.stream((VectorDataType.values())).map(VectorDataType::getValue).collect(Collectors.toCollection(HashSet::new)) - ); Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); @@ -402,7 +400,7 @@ public void testTypeParser_parse_invalidVectorDataType() { Locale.ROOT, "Invalid value provided for [%s] field. Supported values are [%s]", VECTOR_DATA_TYPE_FIELD, - supportedTypes + SUPPORTED_VECTOR_DATA_TYPES ), ex.getMessage() ); @@ -817,7 +815,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { when(parseContext.path()).thenReturn(contentPath); LuceneFieldMapper luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); - doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION); + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) + .getBytesFromContext(parseContext, TEST_DIMENSION, VectorDataType.BYTE); doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); @@ -859,7 +858,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { inputBuilder.hasDocValues(false); luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); - doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION); + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) + .getBytesFromContext(parseContext, TEST_DIMENSION, VectorDataType.BYTE); doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); @@ -918,6 +918,90 @@ public void testClipVectorValuetoFP16Range_succeed() { assertEquals(-65504.0f, clipVectorValueToFP16Range(-1000000.89f), 0.0f); } + public void testBuilder_whenBinaryFaissHNSW_thenValid() { + testBuilderWithBinaryDataType(KNNEngine.FAISS, SpaceType.UNDEFINED, METHOD_HNSW, 8, null); + } + + public void testBuilder_whenBinaryWithInvalidDimension_thenException() { + testBuilderWithBinaryDataType(KNNEngine.FAISS, SpaceType.UNDEFINED, METHOD_HNSW, 4, "should be multiply of 8"); + } + + public void testBuilder_whenBinaryFaissHNSWWithInvalidSpaceType_thenException() { + for (SpaceType spaceType : SpaceType.values()) { + if (SpaceType.UNDEFINED == spaceType || SpaceType.HAMMING_BIT == spaceType) { + continue; + } + testBuilderWithBinaryDataType(KNNEngine.FAISS, spaceType, METHOD_HNSW, 8, "is not supported"); + } + } + + public void testBuilder_whenBinaryNonFaiss_thenException() { + testBuilderWithBinaryDataType(KNNEngine.LUCENE, SpaceType.UNDEFINED, METHOD_HNSW, 8, "is only supported for"); + testBuilderWithBinaryDataType(KNNEngine.NMSLIB, SpaceType.UNDEFINED, METHOD_HNSW, 8, "is only supported for"); + } + + private void testBuilderWithBinaryDataType( + KNNEngine knnEngine, + SpaceType spaceType, + String method, + int dimension, + String expectedErrMsg + ) { + ModelDao modelDao = mock(ModelDao.class); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + + // Setup settings + Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + + builder.knnMethodContext.setValue( + new KNNMethodContext(knnEngine, spaceType, new MethodComponentContext(method, Collections.emptyMap())) + ); + builder.vectorDataType.setValue(VectorDataType.BINARY); + builder.dimension.setValue(dimension); + + Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); + if (expectedErrMsg == null) { + KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); + assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + if (SpaceType.UNDEFINED == spaceType) { + assertEquals(SpaceType.HAMMING_BIT, knnVectorFieldMapper.fieldType().spaceType); + } + } else { + Exception ex = expectThrows(Exception.class, () -> builder.build(builderContext)); + assertTrue(ex.getMessage(), ex.getMessage().contains(expectedErrMsg)); + } + } + + public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { + // Check legacy is picked up if model context and method context are not set + ModelDao modelDao = mock(ModelDao.class); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + builder.vectorDataType.setValue(VectorDataType.BINARY); + builder.dimension.setValue(8); + + // Setup settings + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, false).build(); + + Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); + KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); + assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper); + } + + public void testBuilder_whenBinaryWithLegacyKNNEnabled_thenException() { + // Check legacy is picked up if model context and method context are not set + ModelDao modelDao = mock(ModelDao.class); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + builder.vectorDataType.setValue(VectorDataType.BINARY); + builder.dimension.setValue(8); + + // Setup settings + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); + + Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); + Exception ex = expectThrows(Exception.class, () -> builder.build(builderContext)); + assertTrue(ex.getMessage(), ex.getMessage().contains("is not supported for")); + } + private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder( VectorDataType vectorDataType ) { diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index ff47dcd69..c7c945bfa 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -13,16 +13,20 @@ import org.apache.lucene.document.StoredField; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; import java.io.ByteArrayInputStream; import java.util.Arrays; +import java.util.Collections; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -114,4 +118,52 @@ public void testGetExpectedDimensionsFailure() { ); assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage()); } + + public void testValidateVectorDataType_whenBinaryFaissHNSW_thenValid() { + validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BINARY, null); + } + + public void testValidateVectorDataType_whenBinaryNonFaiss_thenException() { + validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_HNSW, VectorDataType.BINARY, "only supported"); + validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_HNSW, VectorDataType.BINARY, "only supported"); + } + + public void testValidateVectorDataType_whenBinaryFaissIVF_thenException() { + validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_IVF, VectorDataType.BINARY, "only supported"); + } + + public void testValidateVectorDataType_whenByteLucene_thenValid() { + validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, null); + validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_IVF, VectorDataType.BYTE, null); + } + + public void testValidateVectorDataType_whenByteNonLucene_thenException() { + validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BYTE, "only supported"); + validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_IVF, VectorDataType.BYTE, "only supported"); + } + + public void testValidateVectorDataType_whenFloat_thenValid() { + validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, null); + validateValidateVectorDataType(KNNEngine.LUCENE, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, null); + validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, null); + } + + private void validateValidateVectorDataType( + final KNNEngine knnEngine, + final String methodName, + final VectorDataType vectorDataType, + final String expectedErrMsg + ) { + MethodComponentContext methodComponentContext = new MethodComponentContext(methodName, Collections.emptyMap()); + KNNMethodContext methodContext = new KNNMethodContext(knnEngine, SpaceType.UNDEFINED, methodComponentContext); + if (expectedErrMsg == null) { + KNNVectorFieldMapperUtil.validateVectorDataType(methodContext, vectorDataType); + } else { + Exception ex = expectThrows( + IllegalArgumentException.class, + () -> KNNVectorFieldMapperUtil.validateVectorDataType(methodContext, vectorDataType) + ); + assertTrue(ex.getMessage().contains(expectedErrMsg)); + } + } } diff --git a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java new file mode 100644 index 000000000..ef2a2768e --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import junit.framework.TestCase; +import org.opensearch.index.mapper.FieldMapper; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; + +import java.util.Collections; + +public class MethodFieldMapperTests extends TestCase { + public void testMethodFieldMapper_whenVectorDataTypeIsGiven_thenSetItInFieldType() { + KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( + "testField", + Collections.emptyMap(), + 1, + VectorDataType.BINARY, + SpaceType.HAMMING_BIT + ); + MethodFieldMapper mappers = new MethodFieldMapper( + "simpleName", + mappedFieldType, + null, + new FieldMapper.CopyTo.Builder().build(), + KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED, + true, + true, + KNNMethodContext.getDefault() + ); + assertEquals(VectorDataType.BINARY, mappers.fieldType().vectorDataType); + } +} diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index 7573a4394..a422cfc4a 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -12,6 +12,7 @@ package org.opensearch.knn.index.memory; import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; @@ -91,6 +92,67 @@ public void testIndexAllocation_close() throws InterruptedException { executorService.shutdown(); } + @SneakyThrows + public void testClose_whenBinaryFiass_thenSuccess() { + Path dir = createTempDir(); + KNNEngine knnEngine = KNNEngine.FAISS; + String indexName = "test1" + knnEngine.getExtension(); + String path = dir.resolve(indexName).toAbsolutePath().toString(); + int numVectors = 10; + int dimension = 8; + int dataLength = dimension / 8; + int[] ids = new int[numVectors]; + byte[][] vectors = new byte[numVectors][dataLength]; + for (int i = 0; i < numVectors; i++) { + ids[i] = i; + vectors[i][0] = 1; + } + Map parameters = ImmutableMap.of( + KNNConstants.SPACE_TYPE, + SpaceType.HAMMING_BIT.getValue(), + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + "BHNSW32" + ); + long vectorMemoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors * dataLength); + JNIService.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); + + // Load index into memory + long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); + + @SuppressWarnings("unchecked") + WatcherHandle watcherHandle = (WatcherHandle) mock(WatcherHandle.class); + doNothing().when(watcherHandle).stop(); + + ExecutorService executorService = Executors.newSingleThreadExecutor(); + NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( + executorService, + memoryAddress, + IndexUtil.getFileSizeInKB(path), + knnEngine, + path, + "test", + watcherHandle, + null, + true + ); + + indexAllocation.close(); + + Thread.sleep(1000 * 2); + indexAllocation.writeLock(); + assertTrue(indexAllocation.isClosed()); + indexAllocation.writeUnlock(); + + indexAllocation.close(); + + Thread.sleep(1000 * 2); + indexAllocation.writeLock(); + assertTrue(indexAllocation.isClosed()); + indexAllocation.writeUnlock(); + + executorService.shutdown(); + } + public void testIndexAllocation_getMemoryAddress() { long memoryAddress = 12; NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 876303523..c9216ee0b 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -80,6 +80,63 @@ public void testIndexLoadStrategy_load() throws IOException { assertTrue(results.length > 0); } + public void testLoad_whenFaissBinary_thenSuccess() throws IOException { + Path dir = createTempDir(); + KNNEngine knnEngine = KNNEngine.FAISS; + String indexName = "test1" + knnEngine.getExtension(); + String path = dir.resolve(indexName).toAbsolutePath().toString(); + int numVectors = 10; + int dimension = 8; + int dataLength = dimension / 8; + int[] ids = new int[numVectors]; + byte[][] vectors = new byte[numVectors][dataLength]; + for (int i = 0; i < numVectors; i++) { + ids[i] = i; + vectors[i][0] = 1; + } + Map parameters = ImmutableMap.of( + KNNConstants.SPACE_TYPE, + SpaceType.HAMMING_BIT.getValue(), + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + "BHNSW32" + ); + long memoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors); + JNIService.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); + + // Setup mock resource manager + ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class); + doReturn(null).when(resourceWatcherService).add(any()); + NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); + + NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( + path, + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), + parameters, + "test" + ); + + // Load + NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance() + .load(indexEntryContext); + + // Verify + assertTrue(indexAllocation.isBinaryIndex()); + + // Confirm that the file was loaded by querying + byte[] query = { 1 }; + KNNQueryResult[] results = JNIService.queryBinaryIndex( + indexAllocation.getMemoryAddress(), + query, + 2, + null, + knnEngine, + null, + 0, + null + ); + assertTrue(results.length > 0); + } + @SuppressWarnings("unchecked") public void testTrainingLoadStrategy_load() { // Mock the vector reader so that on read, it waits 2 seconds, transfers vectors to the consumer, and then calls diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 0492297cd..06e370026 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -817,6 +817,7 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -873,6 +874,7 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -1173,6 +1175,7 @@ public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); @@ -1232,4 +1235,36 @@ public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() { KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertEquals(1 / MIN_SCORE - 1, query.getRadius(), 0); } + + public void testDoToQuery_whenBinary_thenValid() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + byte[] expectedQueryVector = { 1, 2, 3, 4 }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(32); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); + when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.HAMMING_BIT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + assertArrayEquals(expectedQueryVector, query.getByteQueryVector()); + assertNull(query.getQueryVector()); + } + + public void testDoToQuery_whenBinaryWithInvalidDimension_thenException() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(8); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); + when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.HAMMING_BIT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + Exception ex = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + assertTrue(ex.getMessage(), ex.getMessage().contains("invalid dimension")); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 56e81d237..02b64cba5 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -216,6 +216,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() { .queryVector(testQueryVector) .k(testK) .methodParameters(methodParameters) + .vectorDataType(VectorDataType.FLOAT) .build(); // When @@ -226,6 +227,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() { .vector(testQueryVector) .k(testK) .methodParameters(methodParameters) + .vectorDataType(VectorDataType.FLOAT) .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); @@ -259,6 +261,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_nullEfSearch_thenSucc .fieldName(testFieldName) .vector(testQueryVector) .k(testK) + .vectorDataType(VectorDataType.FLOAT) .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); @@ -294,6 +297,7 @@ public void testCreate_whenNestedVectorFiledAndNonNestedFilterField_thenReturnTo .fieldName(testFieldName) .vector(testQueryVector) .k(testK) + .vectorDataType(VectorDataType.FLOAT) .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); @@ -322,6 +326,7 @@ public void testCreate_whenNestedVectorAndFilterField_thenReturnSameFilterQuery( .fieldName(testFieldName) .vector(testQueryVector) .k(testK) + .vectorDataType(VectorDataType.FLOAT) .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); @@ -343,6 +348,7 @@ public void testCreate_whenFaissWithParentFilter_thenSuccess() { .fieldName(testFieldName) .vector(testQueryVector) .k(testK) + .vectorDataType(VectorDataType.FLOAT) .context(mockQueryShardContext) .build(); final Query query = KNNQueryFactory.create(createQueryRequest); @@ -379,4 +385,27 @@ private void validateDiversifyingQueryWithParentFilter(final VectorDataType type assertEquals(expectedQueryClass, query.getClass()); } } + + public void testCreate_whenBinary_thenSuccess() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.FAISS) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .byteVector(testByteQueryVector) + .vectorDataType(VectorDataType.BINARY) + .k(testK) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + Query query = KNNQueryFactory.create(createQueryRequest); + assertTrue(query instanceof KNNQuery); + assertNotNull(((KNNQuery) query).getByteQueryVector()); + assertNull(((KNNQuery) query).getQueryVector()); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index dc3454368..7da80f2fe 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -29,6 +29,7 @@ import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.Version; +import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; import org.mockito.MockedStatic; @@ -38,6 +39,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -70,6 +72,7 @@ import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.when; import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; @@ -82,6 +85,7 @@ public class KNNWeightTests extends KNNTestCase { private static final String FIELD_NAME = "target_field"; private static final float[] QUERY_VECTOR = new float[] { 1.8f, 2.4f }; + private static final byte[] BYTE_QUERY_VECTOR = new byte[] { 1, 2 }; private static final String SEGMENT_NAME = "0"; private static final int K = 5; private static final Set SEGMENT_FILES_NMSLIB = Set.of("_0.cfe", "_0_2011_target_field.hnswc"); @@ -93,6 +97,7 @@ public class KNNWeightTests extends KNNTestCase { private static final Map DOC_ID_TO_SCORES = Map.of(10, 0.4f, 101, 0.05f, 100, 0.8f, 50, 0.52f); private static final Map FILTERED_DOC_ID_TO_SCORES = Map.of(101, 0.05f, 100, 0.8f, 50, 0.52f); private static final Map EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.12048191f); + private static final Map BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.5f); private static final Query FILTER_QUERY = new TermQuery(new Term("foo", "fooValue")); @@ -118,7 +123,6 @@ public static void setUpClass() throws Exception { knnSettingsMockedStatic.when(KNNSettings::state).thenReturn(knnSettings); knnSettingsMockedStatic.when(KNNSettings::isKNNPluginEnabled).thenReturn(true); - jniServiceMockedStatic = mockStatic(JNIService.class); nativeMemoryCacheManagerMockedStatic = mockStatic(NativeMemoryCacheManager.class); final NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class); @@ -136,6 +140,12 @@ public static void setUpClass() throws Exception { @Before public void setupBeforeTest() { knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(0); + jniServiceMockedStatic = mockStatic(JNIService.class); + } + + @After + public void tearDownAfterTest() { + jniServiceMockedStatic.close(); } @SneakyThrows @@ -365,45 +375,177 @@ public void testEmptyQueryResults() { } @SneakyThrows - public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { + public void testScorer_whenNoFilterBinary_thenSuccess() { + validateScorer_whenNoFilter_thenSuccess(true); + } + + @SneakyThrows + public void testScorer_whenNoFilter_thenSuccess() { + validateScorer_whenNoFilter_thenSuccess(false); + } + + private void validateScorer_whenNoFilter_thenSuccess(final boolean isBinary) throws IOException { // Given int k = 3; - final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; - FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); - for (int docId : filterDocIds) { - filterBitSet.set(docId); - } jniServiceMockedStatic.when( - () -> JNIService.queryIndex( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + jniServiceMockedStatic.when( + () -> JNIService.queryBinaryIndex( anyLong(), - eq(QUERY_VECTOR), + eq(BYTE_QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), - eq(filterBitSet.getBits()), + any(), anyInt(), any() ) ).thenReturn(getFilteredKNNQueryResults()); + final SegmentReader reader = mockSegmentReader(); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); - final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = isBinary + ? KNNQuery.builder() + .field(FIELD_NAME) + .byteQueryVector(BYTE_QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.BINARY) + .build() + : KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .build(); + + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + + // When + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then + assertNotNull(knnScorer); + if (isBinary) { + jniServiceMockedStatic.verify( + () -> JNIService.queryBinaryIndex( + anyLong(), + eq(BYTE_QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + any() + ), + times(1) + ); + } else { + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + } + } + + @SneakyThrows + public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { + validateANNWithFilterQuery_whenDoingANN_thenSuccess(false); + } + + @SneakyThrows + public void testANNWithFilterQuery_whenDoingANNBinary_thenSuccess() { + validateANNWithFilterQuery_whenDoingANN_thenSuccess(true); + } + + public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean isBinary) throws IOException { + // Given + int k = 3; + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); + for (int docId : filterDocIds) { + filterBitSet.set(docId); + } + if (isBinary) { + jniServiceMockedStatic.when( + () -> JNIService.queryBinaryIndex( + anyLong(), + eq(BYTE_QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(filterBitSet.getBits()), + anyInt(), + any() + ) + ).thenReturn(getFilteredKNNQueryResults()); + } else { + jniServiceMockedStatic.when( + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(filterBitSet.getBits()), + anyInt(), + any() + ) + ).thenReturn(getFilteredKNNQueryResults()); + } + final Bits liveDocsBits = mock(Bits.class); - when(reader.maxDoc()).thenReturn(filterDocIds.length); - when(reader.getLiveDocs()).thenReturn(liveDocsBits); for (int filterDocId : filterDocIds) { when(liveDocsBits.get(filterDocId)).thenReturn(true); } when(liveDocsBits.length()).thenReturn(1000); + + final SegmentReader reader = mockSegmentReader(); + when(reader.maxDoc()).thenReturn(filterDocIds.length); + when(reader.getLiveDocs()).thenReturn(liveDocsBits); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = KNNQuery.builder() - .field(FIELD_NAME) - .queryVector(QUERY_VECTOR) - .k(k) - .indexName(INDEX_NAME) - .filterQuery(FILTER_QUERY) - .methodParameters(HNSW_METHOD_PARAMETERS) - .build(); + final KNNQuery query = isBinary + ? KNNQuery.builder() + .field(FIELD_NAME) + .byteQueryVector(BYTE_QUERY_VECTOR) + .vectorDataType(VectorDataType.BINARY) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build() + : KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); @@ -414,35 +556,13 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { final float boost = (float) randomDoubleBetween(0, 10, true); final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); - final FSDirectory directory = mock(FSDirectory.class); - when(reader.directory()).thenReturn(directory); - final SegmentInfo segmentInfo = new SegmentInfo( - directory, - Version.LATEST, - Version.LATEST, - SEGMENT_NAME, - 100, - true, - false, - KNNCodecVersion.current().getDefaultCodecDelegate(), - Map.of(), - new byte[StringHelper.ID_LENGTH], - Map.of(), - Sort.RELEVANCE - ); - segmentInfo.setFiles(SEGMENT_FILES_FAISS); - final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); - when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); - - final Path path = mock(Path.class); - when(directory.getDirectory()).thenReturn(path); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); final Map attributesMap = ImmutableMap.of( KNN_ENGINE, KNNEngine.FAISS.getName(), - PARAMETERS, - String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + SPACE_TYPE, + isBinary ? SpaceType.HAMMING_BIT.getValue() : SpaceType.L2.getValue() ); when(reader.getFieldInfos()).thenReturn(fieldInfos); @@ -458,18 +578,26 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { assertNotNull(docIdSetIterator); assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); - jniServiceMockedStatic.verify( - () -> JNIService.queryIndex( - anyLong(), - eq(QUERY_VECTOR), - eq(k), - eq(HNSW_METHOD_PARAMETERS), - any(), - eq(filterBitSet.getBits()), - anyInt(), - any() - ) - ); + if (isBinary) { + jniServiceMockedStatic.verify( + () -> JNIService.queryBinaryIndex( + anyLong(), + eq(BYTE_QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + any() + ), + times(1) + ); + } else { + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + } final List actualDocIds = new ArrayList<>(); final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); @@ -481,15 +609,56 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); } + private SegmentReader mockSegmentReader() { + Path path = mock(Path.class); + + FSDirectory directory = mock(FSDirectory.class); + when(directory.getDirectory()).thenReturn(path); + + SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + true, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(SEGMENT_FILES_FAISS); + SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + + SegmentReader reader = mock(SegmentReader.class); + when(reader.directory()).thenReturn(directory); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + return reader; + } + @SneakyThrows public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { + validateANNWithFilterQuery_whenExactSearch_thenSuccess(false); + } + + @SneakyThrows + public void testANNWithFilterQuery_whenExactSearchBinary_thenSuccess() { + validateANNWithFilterQuery_whenExactSearch_thenSuccess(true); + } + + public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean isBinary) throws IOException { float[] vector = new float[] { 0.1f, 0.3f }; + byte[] byteVector = new byte[] { 1, 3 }; int filterDocId = 0; final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null); + final KNNQuery query = isBinary + ? new KNNQuery(FIELD_NAME, BYTE_QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null, VectorDataType.BINARY) + : new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); @@ -506,9 +675,7 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, - SpaceType.L2.name(), - PARAMETERS, - String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + isBinary ? SpaceType.HAMMING_BIT.getValue() : SpaceType.L2.getValue() ); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); @@ -516,23 +683,36 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { when(reader.getFieldInfos()).thenReturn(fieldInfos); when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); when(fieldInfo.attributes()).thenReturn(attributesMap); - when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.name()); + if (isBinary) { + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.HAMMING_BIT.getValue()); + } else { + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.getValue()); + } when(fieldInfo.getName()).thenReturn(FIELD_NAME); when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); when(binaryDocValues.advance(filterDocId)).thenReturn(filterDocId); BytesRef vectorByteRef = new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector)); - when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef); + + if (isBinary) { + when(binaryDocValues.binaryValue()).thenReturn(new BytesRef(byteVector)); + } else { + when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef); + } final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); assertNotNull(knnScorer); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); - assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + assertEquals(1, docIdSetIterator.cost()); final List actualDocIds = new ArrayList<>(); for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { actualDocIds.add(docId); - assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f); + if (isBinary) { + assertEquals(BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f); + } else { + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f); + } } assertEquals(docIdSetIterator.cost(), actualDocIds.size()); assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); @@ -730,7 +910,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { // Verify final List expectedScores = vectors.stream() - .map(vector -> SpaceType.L2.getVectorSimilarityFunction().compare(QUERY_VECTOR, vector)) + .map(vector -> SpaceType.L2.getKnnVectorSimilarityFunction().compare(QUERY_VECTOR, vector)) .collect(Collectors.toList()); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertEquals(1, docIdSetIterator.nextDoc()); diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java new file mode 100644 index 000000000..aabbb1f9c --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; +import org.opensearch.knn.index.SpaceType; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class FilteredIdsKNNByteIteratorTests extends TestCase { + @SneakyThrows + public void testNextDoc_whenCalled_IterateAllDocs() { + final SpaceType spaceType = SpaceType.HAMMING_BIT; + final byte[] queryVector = { 1, 2, 3 }; + final int[] filterIds = { 1, 2, 3 }; + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + BinaryDocValues values = mock(BinaryDocValues.class); + final List byteRefs = dataVectors.stream().map(vector -> new BytesRef(vector)).collect(Collectors.toList()); + when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + + FixedBitSet filterBitSet = new FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + FilteredIdsKNNByteIterator iterator = new FilteredIdsKNNByteIterator(filterBitSet, queryVector, values, spaceType); + for (int i = 0; i < filterIds.length; i++) { + assertEquals(filterIds[i], iterator.nextDoc()); + assertEquals(expectedScores.get(i), (Float) iterator.score()); + } + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java index dce703050..cf8582a05 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java @@ -33,7 +33,7 @@ public void testNextDoc_whenCalled_IterateAllDocs() { new float[] { 17.0f, 18.0f, 19.0f } ); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) .collect(Collectors.toList()); BinaryDocValues values = mock(BinaryDocValues.class); diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java new file mode 100644 index 000000000..01a4eb2b1 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; +import org.opensearch.knn.index.SpaceType; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class NestedFilteredIdsKNNByteIteratorTests extends TestCase { + @SneakyThrows + public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { + final SpaceType spaceType = SpaceType.HAMMING_BIT; + final byte[] queryVector = { 1, 2, 3 }; + final int[] filterIds = { 0, 2, 3 }; + // Parent id for 0 -> 1 + // Parent id for 2, 3 -> 4 + // In bit representation, it is 10010. In long, it is 18. + final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + BinaryDocValues values = mock(BinaryDocValues.class); + final List byteRefs = dataVectors.stream().map(vector -> new BytesRef(vector)).collect(Collectors.toList()); + when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + + FixedBitSet filterBitSet = new FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + NestedFilteredIdsKNNByteIterator iterator = new NestedFilteredIdsKNNByteIterator( + filterBitSet, + queryVector, + values, + spaceType, + parentBitSet + ); + assertEquals(filterIds[0], iterator.nextDoc()); + assertEquals(expectedScores.get(0), iterator.score()); + assertEquals(filterIds[2], iterator.nextDoc()); + assertEquals(expectedScores.get(2), iterator.score()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java index d732376ef..508b0d3d6 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java @@ -38,7 +38,7 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { new float[] { 14.0f, 15.0f, 16.0f } ); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) .collect(Collectors.toList()); BinaryDocValues values = mock(BinaryDocValues.class); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index e17ee5077..5c7276354 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -53,6 +53,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -264,6 +265,10 @@ public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException public void testCreateIndex_nmslib_valid() throws IOException { for (SpaceType spaceType : KNNEngine.NMSLIB.getMethod(KNNConstants.METHOD_HNSW).getSpaces()) { + if (SpaceType.UNDEFINED == spaceType) { + continue; + } + Path tmpFile = createTempFile(); JNIService.createIndex( @@ -591,6 +596,7 @@ public void testTrain_whenConfigurationIsIVFSQFP16_thenSucceed() { .startObject() .field(NAME, METHOD_IVF) .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.DEFAULT) .startObject(PARAMETERS) .field(METHOD_PARAMETER_NLIST, ivfNlistParam) .startObject(METHOD_ENCODER_PARAMETER) @@ -801,6 +807,10 @@ public void testQueryIndex_nmslib_valid() throws IOException { int k = 50; for (SpaceType spaceType : KNNEngine.NMSLIB.getMethod(KNNConstants.METHOD_HNSW).getSpaces()) { + if (SpaceType.UNDEFINED == spaceType) { + continue; + } + Path tmpFile = createTempFile(); JNIService.createIndex( @@ -1098,6 +1108,7 @@ public void testTrain_whenConfigurationIsIVFFlat_thenSucceed() throws IOExceptio .startObject() .field(NAME, METHOD_IVF) .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.DEFAULT) .startObject(PARAMETERS) .field(METHOD_PARAMETER_NLIST, ivfNlistParam) .endObject() @@ -1121,6 +1132,7 @@ public void testTrain_whenConfigurationIsIVFPQ_thenSucceed() throws IOException .startObject() .field(NAME, METHOD_IVF) .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.DEFAULT.getValue()) .startObject(PARAMETERS) .field(METHOD_PARAMETER_NLIST, ivfNlistParam) .startObject(METHOD_ENCODER_PARAMETER) @@ -1149,6 +1161,7 @@ public void testTrain_whenConfigurationIsHNSWPQ_thenSucceed() throws IOException .startObject() .field(NAME, METHOD_HNSW) .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.DEFAULT.getValue()) .startObject(PARAMETERS) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_PQ) diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 22110accd..a8d37b6c5 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -267,6 +267,12 @@ public void testZeroVectorFailsCosineSimilarityOptimized() throws IOException { dataset.close(); } + public void testCalculateHammingBit_whenByte_thenSuccess() { + byte[] v1 = { 1, 16, -128 }; // 0000 0001, 0001 0000, 1000 0000 + byte[] v2 = { 2, 17, -1 }; // 0000 0010, 0001 0001, 1111 1111 + assertEquals(10, KNNScoringUtil.calculateHammingBit(v1, v2), 0.001f); + } + class TestKNNScriptDocValues { private KNNVectorScriptDocValues scriptDocValues; private Directory directory; diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 5a83891d9..724baae14 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -597,11 +597,12 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception { .toString(); for (SpaceType spaceType : SpaceType.values()) { - if (spaceType != SpaceType.HAMMING_BIT) { - final float[] queryVector = randomVector(dimensions); - final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); - createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector, true); + if (SpaceType.UNDEFINED == spaceType || SpaceType.HAMMING_BIT == spaceType) { + continue; } + final float[] queryVector = randomVector(dimensions); + final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); + createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector, true); } }