diff --git a/CHANGELOG.md b/CHANGELOG.md index 69dd6a2e3..0623e6f98 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,5 +21,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Documentation ### Maintenance ### Refactoring +* Introduce KNNVectorValues interface to iterate on different types of Vector values during indexing and search [#1897](https://github.com/opensearch-project/k-NN/pull/1897) * Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824) * Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java index 26171eece..e4860af31 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java @@ -54,7 +54,7 @@ static NativeEngineFieldVectorsWriter create(final FieldInfo fieldInfo, final throw new IllegalStateException("Unsupported Vector encoding : " + fieldInfo.getVectorEncoding()); } - NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStream infoStream) { + private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStream infoStream) { this.fieldInfo = fieldInfo; this.infoStream = infoStream; vectors = new HashMap<>(); diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 04aeb337f..d208d8179 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -115,7 +115,7 @@ public static String buildEngineFileSuffix(String fieldName, String extension) { return String.format("_%s%s", fieldName, extension); } - private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { + public static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { long totalLiveDocs; if (binaryDocValues instanceof KNN80BinaryDocValues) { totalLiveDocs = ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs(); diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java new file mode 100644 index 000000000..f38099b74 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import lombok.ToString; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; + +import java.io.IOException; + +/** + * Concrete implementation of {@link KNNVectorValues} that returns byte[] as vector where binary vector is stored and + * provides an abstraction over {@link BinaryDocValues}, {@link ByteVectorValues}, {@link KnnFieldVectorsWriter} etc. + */ +@ToString(callSuper = true) +public class KNNBinaryVectorValues extends KNNVectorValues { + KNNBinaryVectorValues(KNNVectorValuesIterator vectorValuesIterator) { + super(vectorValuesIterator); + } + + @Override + public byte[] getVector() throws IOException { + final byte[] vector = VectorValueExtractorStrategy.extractBinaryVector(vectorValuesIterator); + this.dimension = vector.length; + return vector; + } + + /** + * Binary Vector values gets stored as byte[], hence for dimension of the binary vector we have to multiply the + * byte[] size with {@link Byte#SIZE} + * @return int + */ + @Override + public int dimension() { + return super.dimension() * Byte.SIZE; + } +} diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java new file mode 100644 index 000000000..ccbbfab77 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import lombok.ToString; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; + +import java.io.IOException; + +/** + * Concrete implementation of {@link KNNVectorValues} that returns float[] as vector and provides an abstraction over + * {@link BinaryDocValues}, {@link ByteVectorValues}, {@link KnnFieldVectorsWriter} etc. + */ +@ToString(callSuper = true) +public class KNNByteVectorValues extends KNNVectorValues { + KNNByteVectorValues(KNNVectorValuesIterator vectorValuesIterator) { + super(vectorValuesIterator); + } + + @Override + public byte[] getVector() throws IOException { + final byte[] vector = VectorValueExtractorStrategy.extractByteVector(vectorValuesIterator); + this.dimension = vector.length; + return vector; + } +} diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java new file mode 100644 index 000000000..174f3a89e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.FloatVectorValues; + +import java.io.IOException; + +/** + * Concrete implementation of {@link KNNVectorValues} that returns float[] as vector and provides an abstraction over + * {@link BinaryDocValues}, {@link FloatVectorValues}, {@link KnnFieldVectorsWriter} etc. + */ +public class KNNFloatVectorValues extends KNNVectorValues { + KNNFloatVectorValues(final KNNVectorValuesIterator vectorValuesIterator) { + super(vectorValuesIterator); + } + + @Override + public float[] getVector() throws IOException { + final float[] vector = VectorValueExtractorStrategy.extractFloatVector(vectorValuesIterator); + this.dimension = vector.length; + return vector; + } +} diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java new file mode 100644 index 000000000..c4ed64bc2 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import lombok.ToString; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; + +import java.io.IOException; + +/** + * An abstract class to iterate over KNNVectors, as KNNVectors are stored as different representation like + * {@link BinaryDocValues}, {@link FloatVectorValues}, {@link ByteVectorValues}, {@link KnnFieldVectorsWriter} etc. + * @param + */ +@ToString +public abstract class KNNVectorValues { + + protected final KNNVectorValuesIterator vectorValuesIterator; + protected int dimension; + + protected KNNVectorValues(final KNNVectorValuesIterator vectorValuesIterator) { + this.vectorValuesIterator = vectorValuesIterator; + } + + /** + * Return a vector reference. If you are adding this address in a List/Map ensure that you are copying the vector first. + * This is to ensure that we keep the heap and latency in check by reducing the copies of vectors. + * + * @return T an array of byte[], float[] + * @throws IOException if we are not able to get the vector + */ + public abstract T getVector() throws IOException; + + /** + * Dimension of vector is returned. Do call getVector function first before calling this function otherwise you will get 0 value. + * @return int + */ + public int dimension() { + assert docId() != -1 && dimension != 0 : "Cannot get dimension before we retrieve a vector from KNNVectorValues"; + return dimension; + } + + /** + * Returns the total live docs for KNNVectorValues. + * @return long + */ + public long totalLiveDocs() { + return vectorValuesIterator.liveDocs(); + } + + /** + * Returns the current docId where the iterator is pointing to. + * @return int + */ + public int docId() { + return vectorValuesIterator.docId(); + } + + /** + * Advances to a specific docId. Ensure that the passed docId is greater than current docId where Iterator is + * pointing to, otherwise + * {@link IOException} will be thrown + * @return int + * @throws IOException if we are not able to move to the passed docId. + */ + public int advance(int docId) throws IOException { + return vectorValuesIterator.advance(docId); + } + + /** + * Move to nextDocId. + * @return int + * @throws IOException if we cannot move to next docId + */ + public int nextDoc() throws IOException { + return vectorValuesIterator.nextDoc(); + } + +} diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java new file mode 100644 index 000000000..5b6558f32 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.search.DocIdSetIterator; +import org.opensearch.knn.index.VectorDataType; + +import java.util.Map; + +/** + * A factory class that provides various methods to create the {@link KNNVectorValues}. + */ +public final class KNNVectorValuesFactory { + + /** + * Returns a {@link KNNVectorValues} for the given {@link DocIdSetIterator} and {@link VectorDataType} + * + * @param vectorDataType {@link VectorDataType} + * @param docIdSetIterator {@link DocIdSetIterator} + * @return {@link KNNVectorValues} of type float[] + */ + public static KNNVectorValues getVectorValues(final VectorDataType vectorDataType, final DocIdSetIterator docIdSetIterator) { + return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator)); + } + + /** + * Returns a {@link KNNVectorValues} for the given {@link DocIdSetIterator} and a Map of docId and vectors. + * + * @param vectorDataType {@link VectorDataType} + * @param docIdWithFieldSet {@link DocsWithFieldSet} + * @return {@link KNNVectorValues} of type float[] + */ + public static KNNVectorValues getVectorValues( + final VectorDataType vectorDataType, + final DocsWithFieldSet docIdWithFieldSet, + final Map vectors + ) { + return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues(docIdWithFieldSet, vectors)); + } + + @SuppressWarnings("unchecked") + private static KNNVectorValues getVectorValues( + final VectorDataType vectorDataType, + final KNNVectorValuesIterator knnVectorValuesIterator + ) { + switch (vectorDataType) { + case FLOAT: + return (KNNVectorValues) new KNNFloatVectorValues(knnVectorValuesIterator); + case BYTE: + return (KNNVectorValues) new KNNByteVectorValues(knnVectorValuesIterator); + case BINARY: + return (KNNVectorValues) new KNNBinaryVectorValues(knnVectorValuesIterator); + } + throw new IllegalArgumentException("Invalid Vector data type provided, hence cannot return VectorValues"); + } +} diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java new file mode 100644 index 000000000..4f1445c1c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java @@ -0,0 +1,188 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import lombok.NonNull; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +/** + * An abstract class that provides an iterator to iterate over KNNVectors, as KNNVectors are stored as different + * representation like {@link BinaryDocValues}, {@link FloatVectorValues}, FieldWriter etc. How to iterate using this + * iterator please refer {@link DocIdsIteratorValues} java docs. + */ +public interface KNNVectorValuesIterator { + + /** + * Returns the current docId where the iterator is pointing to. + * @return int + */ + int docId(); + + /** + * Advances to a specific docId. Ensure that the passed docId is greater than current docId where Iterator is + * pointing to, otherwise + * {@link IOException} will be thrown + * @return int + * @throws IOException if we are not able to move to the passed docId. + */ + int advance(int docId) throws IOException; + + /** + * Move to nextDocId. If no more docs are present then {@link DocIdSetIterator#NO_MORE_DOCS} will be returned. + * @return int + * @throws IOException if we cannot move to next docId + */ + int nextDoc() throws IOException; + + /** + * Return a {@link DocIdSetIterator} + * @return {@link DocIdSetIterator} + */ + DocIdSetIterator getDocIdSetIterator(); + + /** + * Total number of live doc which will the iterator will iterate upon. + * @return long: total number of live docs + */ + long liveDocs(); + + /** + * Returns the {@link VectorValueExtractorStrategy} to extract the vector from the iterator. + * @return VectorValueExtractorStrategy + */ + VectorValueExtractorStrategy getVectorExtractorStrategy(); + + /** + * A DocIdsIteratorValues provides a common iteration logic for all Values that implements + * {@link DocIdSetIterator} interface. Example: {@link BinaryDocValues}, {@link FloatVectorValues} etc. + */ + class DocIdsIteratorValues implements KNNVectorValuesIterator { + protected DocIdSetIterator docIdSetIterator; + private static final List> VALID_ITERATOR_INSTANCE = List.of( + (itr) -> itr instanceof BinaryDocValues, + (itr) -> itr instanceof FloatVectorValues, + (itr) -> itr instanceof ByteVectorValues + ); + + DocIdsIteratorValues(@NonNull final DocIdSetIterator docIdSetIterator) { + validateIteratorType(docIdSetIterator); + this.docIdSetIterator = docIdSetIterator; + } + + @Override + public int docId() { + return docIdSetIterator.docID(); + } + + @Override + public int advance(int docId) throws IOException { + return docIdSetIterator.advance(docId); + } + + @Override + public int nextDoc() throws IOException { + return docIdSetIterator.nextDoc(); + } + + @Override + public DocIdSetIterator getDocIdSetIterator() { + return docIdSetIterator; + } + + @Override + public long liveDocs() { + if (docIdSetIterator instanceof BinaryDocValues) { + return KNNCodecUtil.getTotalLiveDocsCount((BinaryDocValues) docIdSetIterator); + } else if (docIdSetIterator instanceof FloatVectorValues || docIdSetIterator instanceof ByteVectorValues) { + return docIdSetIterator.cost(); + } + throw new IllegalArgumentException( + "DocIdSetIterator present is not of valid type. Valid types are: BinaryDocValues, FloatVectorValues and ByteVectorValues" + ); + } + + @Override + public VectorValueExtractorStrategy getVectorExtractorStrategy() { + return new VectorValueExtractorStrategy.DISIVectorExtractor(); + } + + private void validateIteratorType(final DocIdSetIterator docIdSetIterator) { + VALID_ITERATOR_INSTANCE.stream() + .map(v -> v.apply(docIdSetIterator)) + .filter(Boolean::booleanValue) + .findFirst() + .orElseThrow( + () -> new IllegalArgumentException( + "DocIdSetIterator present is not of valid type. Valid types are: BinaryDocValues, FloatVectorValues and ByteVectorValues" + ) + ); + } + } + + /** + * A FieldWriterIteratorValues is mainly used when Vectors are stored in {@link KnnFieldVectorsWriter} interface. + */ + class FieldWriterIteratorValues implements KNNVectorValuesIterator { + private final DocIdSetIterator docIdSetIterator; + private final Map vectors; + + FieldWriterIteratorValues(@NonNull final DocsWithFieldSet docsWithFieldSet, @NonNull final Map vectors) { + assert docsWithFieldSet.iterator().cost() == vectors.size(); + this.vectors = vectors; + this.docIdSetIterator = docsWithFieldSet.iterator(); + } + + @Override + public int docId() { + return docIdSetIterator.docID(); + } + + @Override + public int advance(int docId) throws IOException { + return docIdSetIterator.advance(docId); + } + + @Override + public int nextDoc() throws IOException { + return docIdSetIterator.nextDoc(); + } + + /** + * Returns a Map of docId and vector. + * @return {@link Map} + */ + public T vectorsValue() { + return vectors.get(docId()); + } + + @Override + public DocIdSetIterator getDocIdSetIterator() { + return docIdSetIterator; + } + + @Override + public long liveDocs() { + return docIdSetIterator.cost(); + } + + @Override + public VectorValueExtractorStrategy getVectorExtractorStrategy() { + return new VectorValueExtractorStrategy.FieldWriterIteratorVectorExtractor(); + } + } + +} diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java new file mode 100644 index 000000000..07db4e7f6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; + +import java.io.IOException; + +/** + * Provides different strategies to extract the vectors from different {@link KNNVectorValuesIterator} + */ +interface VectorValueExtractorStrategy { + + /** + * Extract a float vector from KNNVectorValuesIterator. + * @param iterator {@link KNNVectorValuesIterator} + * @return float[] + * @throws IOException exception while retrieving the vectors + */ + static float[] extractFloatVector(final KNNVectorValuesIterator iterator) throws IOException { + return iterator.getVectorExtractorStrategy().extract(VectorDataType.FLOAT, iterator); + } + + /** + * Extract a byte vector from KNNVectorValuesIterator. + * @param iterator {@link KNNVectorValuesIterator} + * @return byte[] + * @throws IOException exception while retrieving the vectors + */ + static byte[] extractByteVector(final KNNVectorValuesIterator iterator) throws IOException { + return iterator.getVectorExtractorStrategy().extract(VectorDataType.BYTE, iterator); + } + + /** + * Extract a binary vector which is represented as byte[] from KNNVectorValuesIterator. + * @param iterator {@link KNNVectorValuesIterator} + * @return byte[] + * @throws IOException exception while retrieving the vectors + */ + static byte[] extractBinaryVector(final KNNVectorValuesIterator iterator) throws IOException { + return iterator.getVectorExtractorStrategy().extract(VectorDataType.BINARY, iterator); + } + + /** + * Extract Vector based on the vector datatype and vector values iterator. + * @param vectorDataType {@link VectorDataType} + * @param vectorValuesIterator {@link KNNVectorValuesIterator} + * @return vector + * @param could be of type float[], byte[] + * @throws IOException exception during extracting the vectors + */ + T extract(VectorDataType vectorDataType, KNNVectorValuesIterator vectorValuesIterator) throws IOException; + + /** + * Strategy to extract the vector from {@link KNNVectorValuesIterator.DocIdsIteratorValues} + */ + class DISIVectorExtractor implements VectorValueExtractorStrategy { + @Override + public T extract(final VectorDataType vectorDataType, final KNNVectorValuesIterator vectorValuesIterator) throws IOException { + final DocIdSetIterator docIdSetIterator = vectorValuesIterator.getDocIdSetIterator(); + switch (vectorDataType) { + case FLOAT: + if (docIdSetIterator instanceof BinaryDocValues) { + final BinaryDocValues values = (BinaryDocValues) docIdSetIterator; + return (T) getFloatVectorFromByteRef(values.binaryValue()); + } else if (docIdSetIterator instanceof FloatVectorValues) { + return (T) ((FloatVectorValues) docIdSetIterator).vectorValue(); + } + throw new IllegalArgumentException( + "VectorValuesIterator is not of a valid type. Valid Types are: BinaryDocValues and FloatVectorValues" + ); + case BYTE: + case BINARY: + if (docIdSetIterator instanceof BinaryDocValues) { + final BinaryDocValues values = (BinaryDocValues) docIdSetIterator; + final BytesRef bytesRef = values.binaryValue(); + return (T) ArrayUtil.copyOfSubArray(bytesRef.bytes, bytesRef.offset, bytesRef.offset + bytesRef.length); + } else if (docIdSetIterator instanceof ByteVectorValues) { + return (T) ((ByteVectorValues) docIdSetIterator).vectorValue(); + } + throw new IllegalArgumentException( + "VectorValuesIterator is not of a valid type. Valid Types are: BinaryDocValues and ByteVectorValues" + ); + } + throw new IllegalArgumentException("Valid Vector data type not passed to extract vector from DISIVectorExtractor strategy"); + } + + private float[] getFloatVectorFromByteRef(final BytesRef bytesRef) { + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(bytesRef); + return vectorSerializer.byteToFloatArray(bytesRef); + } + } + + /** + * Strategy to extract the vector from {@link KNNVectorValuesIterator.FieldWriterIteratorValues} + */ + class FieldWriterIteratorVectorExtractor implements VectorValueExtractorStrategy { + + @SuppressWarnings("unchecked") + @Override + public T extract(final VectorDataType vectorDataType, final KNNVectorValuesIterator vectorValuesIterator) throws IOException { + switch (vectorDataType) { + case FLOAT: + return (T) ((KNNVectorValuesIterator.FieldWriterIteratorValues) vectorValuesIterator).vectorsValue(); + case BYTE: + case BINARY: + return (T) ((KNNVectorValuesIterator.FieldWriterIteratorValues) vectorValuesIterator).vectorsValue(); + } + throw new IllegalArgumentException( + "Valid Vector data type not passed to extract vector from FieldWriterIteratorVectorExtractor strategy" + ); + } + } + +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java index 620559867..727cb8e6e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValuesTests.java @@ -10,7 +10,7 @@ import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.index.MergeState; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; import java.io.IOException; @@ -30,7 +30,7 @@ public void testNextDoc() throws IOException { public int get(int docID) { return expectedDoc; } - }, new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 1.0f)); + }, new TestVectorValues.ConstantVectorBinaryDocValues(10, 128, 1.0f)); DocIDMerger docIDMerger = DocIDMerger.of(ImmutableList.of(sub), false); KNN80BinaryDocValues knn80BinaryDocValues = new KNN80BinaryDocValues(docIDMerger); @@ -53,7 +53,7 @@ public void testCost() { } public void testBinaryValue() throws IOException { - BinaryDocValues binaryDocValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 1.0f); + BinaryDocValues binaryDocValues = new TestVectorValues.ConstantVectorBinaryDocValues(10, 128, 1.0f); BinaryDocValuesSub sub = new BinaryDocValuesSub(new MergeState.DocMap() { @Override public int get(int docID) { diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index aabfc2d9f..ce8fad384 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -27,6 +27,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -69,8 +70,6 @@ import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertFileInCorrectLocation; import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertLoadableByEngine; import static org.opensearch.knn.index.codec.KNNCodecTestUtil.assertValidFooter; -import static org.opensearch.knn.index.codec.KNNCodecTestUtil.getRandomVectors; -import static org.opensearch.knn.index.codec.KNNCodecTestUtil.RandomVectorDocValuesProducer; public class KNN80DocValuesConsumerTests extends KNNTestCase { @@ -155,7 +154,10 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, public void testAddKNNBinaryField_noVectors() throws IOException { // When there are no new vectors, no more graph index requests should be added - RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(0, 128); + TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( + 0, + 128 + ); Long initialGraphIndexRequests = KNNCounter.GRAPH_INDEX_REQUESTS.getCount(); Long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); Long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); @@ -224,7 +226,10 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); - RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( + docsInSegment, + dimension + ); knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); // The document should be created in the correct location @@ -277,7 +282,10 @@ public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); - RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( + docsInSegment, + dimension + ); knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); // The document should be created in the correct location @@ -338,7 +346,10 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); - RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( + docsInSegment, + dimension + ); knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); // The document should be created in the correct location @@ -401,7 +412,10 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); - RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( + docsInSegment, + dimension + ); knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); // The document should be created in the correct location @@ -428,7 +442,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio int dimension = 16; String modelId = "test-model-id"; - float[][] trainingData = getRandomVectors(200, dimension); + float[][] trainingData = TestVectorValues.getRandomVectors(200, dimension); long trainingPtr = JNIService.transferVectors(0, trainingData); Map parameters = ImmutableMap.of( @@ -497,7 +511,10 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); - RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); + TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( + docsInSegment, + dimension + ); knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); // The document should be created in the correct location diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java index 16d6b20da..29e3531cf 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java @@ -40,12 +40,12 @@ public void testCreate_ForDifferentInputs_thenSuccess() { byteWriter.addValue(1, new byte[] { 1, 2 }); } + @SuppressWarnings("unchecked") public void testAddValue_ForDifferentInputs_thenSuccess() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); - final NativeEngineFieldVectorsWriter floatWriter = new NativeEngineFieldVectorsWriter<>( - fieldInfo, - InfoStream.getDefault() - ); + Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); + final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter + .create(fieldInfo, InfoStream.getDefault()); final float[] vec1 = new float[] { 1.0f, 2.0f }; final float[] vec2 = new float[] { 2.0f, 2.0f }; floatWriter.addValue(1, vec1); @@ -53,9 +53,11 @@ public void testAddValue_ForDifferentInputs_thenSuccess() { Assert.assertEquals(vec1, floatWriter.getVectors().get(1)); Assert.assertEquals(vec2, floatWriter.getVectors().get(2)); - Mockito.verify(fieldInfo, Mockito.never()).getVectorEncoding(); + Mockito.verify(fieldInfo).getVectorEncoding(); - final NativeEngineFieldVectorsWriter byteWriter = new NativeEngineFieldVectorsWriter<>(fieldInfo, InfoStream.getDefault()); + Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); + final NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter + .create(fieldInfo, InfoStream.getDefault()); final byte[] bvec1 = new byte[] { 1, 2 }; final byte[] bvec2 = new byte[] { 2, 2 }; byteWriter.addValue(1, bvec1); @@ -63,34 +65,36 @@ public void testAddValue_ForDifferentInputs_thenSuccess() { Assert.assertEquals(bvec1, byteWriter.getVectors().get(1)); Assert.assertEquals(bvec2, byteWriter.getVectors().get(2)); - Mockito.verify(fieldInfo, Mockito.never()).getVectorEncoding(); + Mockito.verify(fieldInfo, Mockito.times(2)).getVectorEncoding(); } + @SuppressWarnings("unchecked") public void testCopyValue_whenValidInput_thenException() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); - final NativeEngineFieldVectorsWriter floatWriter = new NativeEngineFieldVectorsWriter<>( - fieldInfo, - InfoStream.getDefault() - ); + Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); + final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter + .create(fieldInfo, InfoStream.getDefault()); expectThrows(UnsupportedOperationException.class, () -> floatWriter.copyValue(new float[3])); - final NativeEngineFieldVectorsWriter byteWriter = new NativeEngineFieldVectorsWriter<>(fieldInfo, InfoStream.getDefault()); + Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); + final NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter + .create(fieldInfo, InfoStream.getDefault()); expectThrows(UnsupportedOperationException.class, () -> byteWriter.copyValue(new byte[3])); } + @SuppressWarnings("unchecked") public void testRamByteUsed_whenValidInput_thenSuccess() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); Mockito.when(fieldInfo.getVectorDimension()).thenReturn(2); - final NativeEngineFieldVectorsWriter floatWriter = new NativeEngineFieldVectorsWriter<>( - fieldInfo, - InfoStream.getDefault() - ); + final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter + .create(fieldInfo, InfoStream.getDefault()); // testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too. Assert.assertTrue(floatWriter.ramBytesUsed() > 0); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); - final NativeEngineFieldVectorsWriter byteWriter = new NativeEngineFieldVectorsWriter<>(fieldInfo, InfoStream.getDefault()); + final NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter + .create(fieldInfo, InfoStream.getDefault()); // testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too. Assert.assertTrue(byteWriter.ramBytesUsed() > 0); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index 3893cc994..2afd86a04 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -10,17 +10,11 @@ import lombok.Builder; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.SortedDocValues; -import org.apache.lucene.index.SortedNumericDocValues; -import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Sort; @@ -29,7 +23,6 @@ import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FilterDirectory; import org.apache.lucene.store.IOContext; -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.Version; import java.util.Set; @@ -37,19 +30,15 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.jni.JNIService; import java.io.IOException; import java.nio.file.Paths; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; -import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat; import static org.junit.Assert.assertTrue; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; @@ -198,126 +187,6 @@ public FieldInfo build() { } } - public static abstract class VectorDocValues extends BinaryDocValues { - - final int count; - final int dimension; - int current; - KNNVectorSerializer knnVectorSerializer; - - public VectorDocValues(int count, int dimension) { - this.count = count; - this.dimension = dimension; - this.current = -1; - this.knnVectorSerializer = KNNVectorSerializerFactory.getDefaultSerializer(); - } - - @Override - public boolean advanceExact(int target) throws IOException { - return false; - } - - @Override - public int docID() { - if (this.current > this.count) { - return BinaryDocValues.NO_MORE_DOCS; - } - return this.current; - } - - @Override - public int nextDoc() throws IOException { - return advance(current + 1); - } - - @Override - public int advance(int target) throws IOException { - current = target; - if (current >= count) { - current = NO_MORE_DOCS; - } - return current; - } - - @Override - public long cost() { - return 0; - } - } - - public static class ConstantVectorBinaryDocValues extends VectorDocValues { - - private final BytesRef value; - - public ConstantVectorBinaryDocValues(int count, int dimension, float value) { - super(count, dimension); - float[] array = new float[dimension]; - Arrays.fill(array, value); - this.value = new BytesRef(knnVectorSerializer.floatToByteArray(array)); - } - - @Override - public BytesRef binaryValue() throws IOException { - return value; - } - } - - public static class RandomVectorBinaryDocValues extends VectorDocValues { - - public RandomVectorBinaryDocValues(int count, int dimension) { - super(count, dimension); - } - - @Override - public BytesRef binaryValue() throws IOException { - return new BytesRef(knnVectorSerializer.floatToByteArray(getRandomVector(dimension))); - } - } - - public static class RandomVectorDocValuesProducer extends DocValuesProducer { - - final RandomVectorBinaryDocValues randomBinaryDocValues; - - public RandomVectorDocValuesProducer(int count, int dimension) { - this.randomBinaryDocValues = new RandomVectorBinaryDocValues(count, dimension); - } - - @Override - public NumericDocValues getNumeric(FieldInfo field) { - return null; - } - - @Override - public BinaryDocValues getBinary(FieldInfo field) throws IOException { - return randomBinaryDocValues; - } - - @Override - public SortedDocValues getSorted(FieldInfo field) { - return null; - } - - @Override - public SortedNumericDocValues getSortedNumeric(FieldInfo field) { - return null; - } - - @Override - public SortedSetDocValues getSortedSet(FieldInfo field) { - return null; - } - - @Override - public void checkIntegrity() { - - } - - @Override - public void close() throws IOException { - - } - } - public static void assertFileInCorrectLocation(SegmentWriteState state, String expectedFile) throws IOException { assertTrue(Set.of(state.directory.listAll()).contains(expectedFile)); } @@ -378,22 +247,6 @@ public static void assertBinaryIndexLoadableByEngine( JNIService.free(indexPtr, knnEngine); } - public static float[][] getRandomVectors(int count, int dimension) { - float[][] data = new float[count][dimension]; - for (int i = 0; i < count; i++) { - data[i] = getRandomVector(dimension); - } - return data; - } - - public static float[] getRandomVector(int dimension) { - float[] data = new float[dimension]; - for (int i = 0; i < dimension; i++) { - data[i] = randomFloat(); - } - return data; - } - @Builder(builderMethodName = "segmentInfoBuilder") public static SegmentInfo newSegmentInfo(final Directory directory, final String segmentName, int docsInSegment, final Codec codec) { return new SegmentInfo( diff --git a/src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java b/src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java index a2105af3a..757930dcd 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/BinaryDocValuesSubTests.java @@ -7,14 +7,14 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.MergeState; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; import java.io.IOException; public class BinaryDocValuesSubTests extends KNNTestCase { public void testNextDoc() throws IOException { - BinaryDocValues binaryDocValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 2.0f); + BinaryDocValues binaryDocValues = new TestVectorValues.ConstantVectorBinaryDocValues(10, 128, 2.0f); MergeState.DocMap docMap = new MergeState.DocMap() { @Override public int get(int docID) { @@ -28,7 +28,7 @@ public int get(int docID) { } public void testGetValues() { - BinaryDocValues binaryDocValues = new KNNCodecTestUtil.ConstantVectorBinaryDocValues(10, 128, 2.0f); + BinaryDocValues binaryDocValues = new TestVectorValues.ConstantVectorBinaryDocValues(10, 128, 2.0f); MergeState.DocMap docMap = new MergeState.DocMap() { @Override public int get(int docID) { diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactoryTests.java b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactoryTests.java new file mode 100644 index 000000000..9827cb03b --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactoryTests.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocsWithFieldSet; +import org.junit.Assert; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; + +import java.util.Map; + +public class KNNVectorValuesFactoryTests extends KNNTestCase { + private static final int COUNT = 10; + private static final int DIMENSION = 10; + + public void testGetVectorValuesFromDISI_whenValidInput_thenSuccess() { + final BinaryDocValues binaryDocValues = new TestVectorValues.RandomVectorBinaryDocValues(COUNT, DIMENSION); + final KNNVectorValues floatVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, binaryDocValues); + Assert.assertNotNull(floatVectorValues); + + final KNNVectorValues byteVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, binaryDocValues); + Assert.assertNotNull(byteVectorValues); + + final KNNVectorValues binaryVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BINARY, binaryDocValues); + Assert.assertNotNull(binaryVectorValues); + } + + public void testGetVectorValuesUsingDocWithFieldSet_whenValidInput_thenSuccess() { + final DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + docsWithFieldSet.add(0); + docsWithFieldSet.add(1); + final Map floatVectorMap = Map.of(0, new float[] { 1, 2 }, 1, new float[] { 2, 3 }); + final KNNVectorValues floatVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + docsWithFieldSet, + floatVectorMap + ); + Assert.assertNotNull(floatVectorValues); + + final Map byteVectorMap = Map.of(0, new byte[] { 4, 5 }, 1, new byte[] { 6, 7 }); + + final KNNVectorValues byteVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.BYTE, + docsWithFieldSet, + byteVectorMap + ); + Assert.assertNotNull(byteVectorValues); + + final KNNVectorValues binaryVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.BINARY, + docsWithFieldSet, + byteVectorMap + ); + Assert.assertNotNull(binaryVectorValues); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java new file mode 100644 index 000000000..f5a1351ae --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import lombok.SneakyThrows; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.search.DocIdSetIterator; +import org.junit.Assert; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class KNNVectorValuesTests extends KNNTestCase { + + @SneakyThrows + public void testFloatVectorValues_whenValidInput_thenSuccess() { + final List floatArray = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }); + final int dimension = floatArray.get(0).length; + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + floatArray + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + new CompareVectorValues().validateVectorValues(knnVectorValues, floatArray, dimension, true); + + final DocsWithFieldSet docsWithFieldSet = getDocIdSetIterator(floatArray.size()); + + final Map vectorsMap = Map.of(0, floatArray.get(0), 1, floatArray.get(1)); + final KNNVectorValues knnVectorValuesForFieldWriter = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + docsWithFieldSet, + vectorsMap + ); + new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, floatArray, dimension, false); + + final TestVectorValues.PredefinedFloatVectorBinaryDocValues preDefinedFloatVectorValues = + new TestVectorValues.PredefinedFloatVectorBinaryDocValues(floatArray); + final KNNVectorValues knnFloatVectorValuesBinaryDocValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + preDefinedFloatVectorValues + ); + new CompareVectorValues().validateVectorValues(knnFloatVectorValuesBinaryDocValues, floatArray, dimension, false); + } + + @SneakyThrows + public void testByteVectorValues_whenValidInput_thenSuccess() { + final List byteArray = List.of(new byte[] { 4, 5 }, new byte[] { 6, 7 }); + final int dimension = byteArray.get(0).length; + final TestVectorValues.PreDefinedByteVectorValues randomVectorValues = new TestVectorValues.PreDefinedByteVectorValues(byteArray); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, randomVectorValues); + new CompareVectorValues().validateVectorValues(knnVectorValues, byteArray, dimension, true); + + final DocsWithFieldSet docsWithFieldSet = getDocIdSetIterator(byteArray.size()); + final Map vectorsMap = Map.of(0, byteArray.get(0), 1, byteArray.get(1)); + final KNNVectorValues knnVectorValuesForFieldWriter = KNNVectorValuesFactory.getVectorValues( + VectorDataType.BYTE, + docsWithFieldSet, + vectorsMap + ); + new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, byteArray, dimension, false); + + final TestVectorValues.PredefinedByteVectorBinaryDocValues preDefinedByteVectorValues = + new TestVectorValues.PredefinedByteVectorBinaryDocValues(byteArray); + final KNNVectorValues knnBinaryVectorValuesBinaryDocValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.BYTE, + preDefinedByteVectorValues + ); + new CompareVectorValues().validateVectorValues(knnBinaryVectorValuesBinaryDocValues, byteArray, dimension, false); + } + + @SneakyThrows + public void testBinaryVectorValues_whenValidInput_thenSuccess() { + final List byteArray = List.of(new byte[] { 1, 5, 8 }, new byte[] { 6, 7, 9 }); + int dimension = byteArray.get(0).length * 8; + final TestVectorValues.PreDefinedBinaryVectorValues randomVectorValues = new TestVectorValues.PreDefinedBinaryVectorValues( + byteArray + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BINARY, randomVectorValues); + new CompareVectorValues().validateVectorValues(knnVectorValues, byteArray, dimension, true); + + final DocsWithFieldSet docsWithFieldSet = getDocIdSetIterator(byteArray.size()); + final Map vectorsMap = Map.of(0, byteArray.get(0), 1, byteArray.get(1)); + final KNNBinaryVectorValues knnVectorValuesForFieldWriter = (KNNBinaryVectorValues) KNNVectorValuesFactory.getVectorValues( + VectorDataType.BINARY, + docsWithFieldSet, + vectorsMap + ); + new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, byteArray, dimension, false); + + final TestVectorValues.PredefinedByteVectorBinaryDocValues preDefinedByteVectorValues = + new TestVectorValues.PredefinedByteVectorBinaryDocValues(byteArray); + final KNNVectorValues knnBinaryVectorValuesBinaryDocValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.BINARY, + preDefinedByteVectorValues + ); + new CompareVectorValues().validateVectorValues(knnBinaryVectorValuesBinaryDocValues, byteArray, dimension, false); + } + + public void testDocIdsIteratorValues_whenInvalidDisi_thenThrowException() { + Assert.assertThrows( + IllegalArgumentException.class, + () -> new KNNVectorValuesIterator.DocIdsIteratorValues(new TestVectorValues.NotBinaryDocValues()) + ); + } + + private DocsWithFieldSet getDocIdSetIterator(int numberOfDocIds) { + final DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + for (int i = 0; i < numberOfDocIds; i++) { + docsWithFieldSet.add(i); + } + return docsWithFieldSet; + } + + private class CompareVectorValues { + void validateVectorValues(KNNVectorValues vectorValues, List vectors, int dimension, boolean validateAddress) + throws IOException { + Assert.assertEquals(vectorValues.totalLiveDocs(), vectors.size()); + int docId, i = 0; + T oldActual = null; + int oldDocId = -1; + final KNNVectorValuesIterator iterator = vectorValues.vectorValuesIterator; + for (docId = iterator.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS && i < vectors.size(); docId = iterator.nextDoc()) { + T actual = vectorValues.getVector(); + T expected = vectors.get(i); + Assert.assertNotEquals(oldDocId, docId); + Assert.assertEquals(dimension, vectorValues.dimension()); + // this will check if reference is correct for the vectors. This is mainly required because for + // VectorValues of Lucene when reading vectors put the vector at same reference + if (oldActual != null && validateAddress) { + Assert.assertSame(actual, oldActual); + } + oldActual = actual; + // this will do the deep equals + Assert.assertArrayEquals(new Object[] { actual }, new Object[] { expected }); + i++; + } + } + } + +} diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java new file mode 100644 index 000000000..3bf79b004 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java @@ -0,0 +1,363 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.knn.index.vectorvalues; + +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat; + +public class TestVectorValues { + + public static class RandomVectorBinaryDocValues extends VectorDocValues { + + public RandomVectorBinaryDocValues(int count, int dimension) { + super(count, dimension); + } + + @Override + public BytesRef binaryValue() throws IOException { + return new BytesRef(knnVectorSerializer.floatToByteArray(getRandomVector(dimension))); + } + } + + public static class ConstantVectorBinaryDocValues extends VectorDocValues { + + private final BytesRef value; + + public ConstantVectorBinaryDocValues(int count, int dimension, float value) { + super(count, dimension); + float[] array = new float[dimension]; + Arrays.fill(array, value); + this.value = new BytesRef(knnVectorSerializer.floatToByteArray(array)); + } + + @Override + public BytesRef binaryValue() throws IOException { + return value; + } + } + + public static class PredefinedFloatVectorBinaryDocValues extends VectorDocValues { + private final List vectors; + + public PredefinedFloatVectorBinaryDocValues(final List vectors) { + super(vectors.size(), vectors.get(0).length); + this.vectors = vectors; + } + + @Override + public BytesRef binaryValue() throws IOException { + return new BytesRef(knnVectorSerializer.floatToByteArray(vectors.get(docID()))); + } + } + + public static class PredefinedByteVectorBinaryDocValues extends VectorDocValues { + private final List vectors; + + public PredefinedByteVectorBinaryDocValues(final List vectors) { + super(vectors.size(), vectors.get(0).length); + this.vectors = vectors; + } + + @Override + public BytesRef binaryValue() throws IOException { + return new BytesRef(vectors.get(docID())); + } + } + + public static class RandomVectorDocValuesProducer extends DocValuesProducer { + + final RandomVectorBinaryDocValues randomBinaryDocValues; + + public RandomVectorDocValuesProducer(int count, int dimension) { + this.randomBinaryDocValues = new RandomVectorBinaryDocValues(count, dimension); + } + + @Override + public NumericDocValues getNumeric(FieldInfo field) { + return null; + } + + @Override + public BinaryDocValues getBinary(FieldInfo field) throws IOException { + return randomBinaryDocValues; + } + + @Override + public SortedDocValues getSorted(FieldInfo field) { + return null; + } + + @Override + public SortedNumericDocValues getSortedNumeric(FieldInfo field) { + return null; + } + + @Override + public SortedSetDocValues getSortedSet(FieldInfo field) { + return null; + } + + @Override + public void checkIntegrity() { + + } + + @Override + public void close() throws IOException { + + } + } + + static abstract class VectorDocValues extends BinaryDocValues { + + final int count; + final int dimension; + int current; + KNNVectorSerializer knnVectorSerializer; + + public VectorDocValues(int count, int dimension) { + this.count = count; + this.dimension = dimension; + this.current = -1; + this.knnVectorSerializer = KNNVectorSerializerFactory.getDefaultSerializer(); + } + + @Override + public boolean advanceExact(int target) throws IOException { + return false; + } + + @Override + public int docID() { + if (this.current > this.count) { + return BinaryDocValues.NO_MORE_DOCS; + } + return this.current; + } + + @Override + public int nextDoc() throws IOException { + return advance(current + 1); + } + + @Override + public int advance(int target) throws IOException { + current = target; + if (current >= count) { + current = NO_MORE_DOCS; + } + return current; + } + + @Override + public long cost() { + return count; + } + } + + public static class PreDefinedFloatVectorValues extends FloatVectorValues { + final int count; + final int dimension; + final List vectors; + int current; + float[] vector; + + public PreDefinedFloatVectorValues(final List vectors) { + super(); + this.count = vectors.size(); + this.dimension = vectors.get(0).length; + this.vectors = vectors; + this.current = -1; + vector = new float[dimension]; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return count; + } + + @Override + public float[] vectorValue() throws IOException { + // since in FloatVectorValues the reference to returned vector doesn't change. This code ensure that we + // are replicating the behavior so that if someone uses this RandomFloatVectorValues they get an + // experience similar to what we get in prod. + System.arraycopy(vectors.get(docID()), 0, vector, 0, dimension); + return vector; + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + throw new UnsupportedOperationException("scorer not supported with PreDefinedFloatVectorValues"); + } + + @Override + public int docID() { + if (this.current > this.count) { + return FloatVectorValues.NO_MORE_DOCS; + } + return this.current; + } + + @Override + public int nextDoc() throws IOException { + return advance(current + 1); + } + + @Override + public int advance(int target) throws IOException { + current = target; + if (current >= count) { + current = NO_MORE_DOCS; + } + return current; + } + } + + public static class PreDefinedByteVectorValues extends ByteVectorValues { + private final int count; + private final int dimension; + private final List vectors; + private int current; + private final byte[] vector; + + public PreDefinedByteVectorValues(final List vectors) { + super(); + this.count = vectors.size(); + this.dimension = vectors.get(0).length; + this.vectors = vectors; + this.current = -1; + vector = new byte[dimension]; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return count; + } + + @Override + public byte[] vectorValue() throws IOException { + // since in FloatVectorValues the reference to returned vector doesn't change. This code ensure that we + // are replicating the behavior so that if someone uses this RandomFloatVectorValues they get an + // experience similar to what we get in prod. + System.arraycopy(vectors.get(docID()), 0, vector, 0, dimension); + return vector; + } + + @Override + public VectorScorer scorer(byte[] query) throws IOException { + throw new UnsupportedOperationException("scorer not supported with PreDefinedFloatVectorValues"); + } + + @Override + public int docID() { + if (this.current > this.count) { + return FloatVectorValues.NO_MORE_DOCS; + } + return this.current; + } + + @Override + public int nextDoc() throws IOException { + return advance(current + 1); + } + + @Override + public int advance(int target) throws IOException { + current = target; + if (current >= count) { + current = NO_MORE_DOCS; + } + return current; + } + } + + public static class PreDefinedBinaryVectorValues extends PreDefinedByteVectorValues { + + public PreDefinedBinaryVectorValues(List vectors) { + super(vectors); + } + + @Override + public int dimension() { + return super.dimension() * Byte.SIZE; + } + } + + public static class NotBinaryDocValues extends NumericDocValues { + + @Override + public long longValue() throws IOException { + return 0; + } + + @Override + public boolean advanceExact(int target) throws IOException { + return false; + } + + @Override + public int docID() { + return 0; + } + + @Override + public int nextDoc() throws IOException { + return 0; + } + + @Override + public int advance(int target) throws IOException { + return 0; + } + + @Override + public long cost() { + return 0; + } + } + + public static float[][] getRandomVectors(int count, int dimension) { + float[][] data = new float[count][dimension]; + for (int i = 0; i < count; i++) { + data[i] = getRandomVector(dimension); + } + return data; + } + + public static float[] getRandomVector(int dimension) { + float[] data = new float[dimension]; + for (int i = 0; i < dimension; i++) { + data[i] = randomFloat(); + } + return data; + } +} diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategyTests.java b/src/test/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategyTests.java new file mode 100644 index 000000000..68a49a54c --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategyTests.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.vectorvalues; + +import lombok.SneakyThrows; +import org.junit.Assert; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; + +/** + * To avoid unit test duplication, tests for exception is added here. For non exception cases tests are present in + * {@link KNNVectorValuesTests} + */ +public class VectorValueExtractorStrategyTests extends KNNTestCase { + + @SneakyThrows + public void testExtractWithDISI_whenInvalidIterator_thenException() { + final VectorValueExtractorStrategy disiStrategy = new VectorValueExtractorStrategy.DISIVectorExtractor(); + final KNNVectorValuesIterator vectorValuesIterator = Mockito.mock(KNNVectorValuesIterator.DocIdsIteratorValues.class); + Mockito.when(vectorValuesIterator.getDocIdSetIterator()).thenReturn(new TestVectorValues.NotBinaryDocValues()); + Assert.assertThrows(IllegalArgumentException.class, () -> disiStrategy.extract(VectorDataType.FLOAT, vectorValuesIterator)); + Assert.assertThrows(IllegalArgumentException.class, () -> disiStrategy.extract(VectorDataType.BINARY, vectorValuesIterator)); + Assert.assertThrows(IllegalArgumentException.class, () -> disiStrategy.extract(VectorDataType.BYTE, vectorValuesIterator)); + } +}