From 5685992f6b3266b34eac2c76576a59318461eea2 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Tue, 25 Jun 2024 13:16:26 -0700 Subject: [PATCH] Add binary format support with IVF method in Faiss Engine Signed-off-by: Junqiu Lei --- CHANGELOG.md | 1 + jni/include/faiss_wrapper.h | 13 + .../org_opensearch_knn_jni_FaissService.h | 16 + jni/src/faiss_wrapper.cpp | 133 ++++++++ .../org_opensearch_knn_jni_FaissService.cpp | 28 ++ .../opensearch/knn/common/KNNConstants.java | 1 + .../org/opensearch/knn/index/IndexUtil.java | 167 ++++++---- .../opensearch/knn/index/VectorDataType.java | 2 + .../KNN80Codec/KNN80DocValuesConsumer.java | 83 ++--- .../index/mapper/KNNVectorFieldMapper.java | 12 +- .../knn/index/mapper/ModelFieldMapper.java | 8 +- .../index/memory/NativeMemoryAllocation.java | 12 +- .../memory/NativeMemoryEntryContext.java | 15 +- .../memory/NativeMemoryLoadStrategy.java | 13 +- .../knn/index/query/KNNQueryBuilder.java | 1 + .../opensearch/knn/index/query/KNNWeight.java | 12 +- .../org/opensearch/knn/index/util/Faiss.java | 2 +- .../org/opensearch/knn/indices/ModelDao.java | 1 + .../opensearch/knn/indices/ModelMetadata.java | 155 +++++----- .../org/opensearch/knn/jni/FaissService.java | 29 ++ .../org/opensearch/knn/jni/JNICommons.java | 13 + .../org/opensearch/knn/jni/JNIService.java | 12 +- .../plugin/rest/RestTrainModelHandler.java | 12 +- .../TrainingJobRouterTransportAction.java | 17 +- .../transport/TrainingModelRequest.java | 28 +- .../TrainingModelTransportAction.java | 6 +- .../training/ByteTrainingDataConsumer.java | 81 +++++ .../training/FloatTrainingDataConsumer.java | 67 ++++ .../knn/training/TrainingDataConsumer.java | 46 ++- .../opensearch/knn/training/TrainingJob.java | 19 +- .../opensearch/knn/training/VectorReader.java | 20 +- .../opensearch/knn/KNNSingleNodeTestCase.java | 14 +- .../org/opensearch/knn/index/FaissIT.java | 286 ++++++++++++++++-- .../opensearch/knn/index/IndexUtilTests.java | 64 +++- .../index/KNNCreateIndexFromModelTests.java | 3 +- .../KNN80DocValuesConsumerTests.java | 32 +- .../knn/index/codec/KNNCodecTestCase.java | 10 +- .../mapper/KNNVectorFieldMapperTests.java | 18 +- .../memory/NativeMemoryAllocationTests.java | 18 +- .../memory/NativeMemoryCacheManagerTests.java | 4 +- .../memory/NativeMemoryEntryContextTests.java | 22 +- .../memory/NativeMemoryLoadStrategyTests.java | 9 +- .../knn/index/query/KNNQueryBuilderTests.java | 3 + .../knn/index/query/KNNWeightTests.java | 4 + .../knn/indices/ModelCacheTests.java | 38 ++- .../opensearch/knn/indices/ModelDaoTests.java | 43 ++- .../knn/indices/ModelMetadataTests.java | 141 ++++++--- .../opensearch/knn/indices/ModelTests.java | 110 ++++++- .../plugin/action/RestKNNStatsHandlerIT.java | 15 +- .../transport/GetModelResponseTests.java | 8 +- ...oveModelFromCacheTransportActionTests.java | 4 +- ...TrainingJobRouterTransportActionTests.java | 100 +++++- .../transport/TrainingModelRequestTests.java | 48 ++- .../TrainingModelTransportActionTests.java | 6 +- ...ateModelGraveyardTransportActionTests.java | 4 +- .../UpdateModelMetadataRequestTests.java | 10 +- ...dateModelMetadataTransportActionTests.java | 4 +- ...va => FloatTrainingDataConsumerTests.java} | 8 +- .../knn/training/TrainingJobTests.java | 28 +- .../knn/training/VectorReaderTests.java | 101 ++++--- .../org/opensearch/knn/KNNRestTestCase.java | 35 +++ 61 files changed, 1762 insertions(+), 453 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java create mode 100644 src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java rename src/test/java/org/opensearch/knn/training/{TrainingDataConsumerTests.java => FloatTrainingDataConsumerTests.java} (88%) diff --git a/CHANGELOG.md b/CHANGELOG.md index bd5b9ff72..5637adda0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783) * Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790) * Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781) +* Add binary format support with IVF method in Faiss Engine [#1784](https://github.com/opensearch-project/k-NN/pull/1784) ### Enhancements ### Bug Fixes * Fixing the arithmetic to find the number of vectors to stream from java to jni layer.[#1804](https://github.com/opensearch-project/k-NN/pull/1804) diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 2b9bc2c76..5ad0dedc4 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -29,6 +29,12 @@ namespace knn_jni { jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ); + // Create an index with ids and vectors. Instead of creating a new index, this function creates the index + // based off of the template index passed in. The index is serialized to indexPathJ. + void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, + jobject parametersJ); + // Load an index from indexPathJ into memory. // // Return a pointer to the loaded index @@ -96,6 +102,13 @@ namespace knn_jni { jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, jlong trainVectorsPointerJ); + // Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with + // the vector of floats located at trainVectorsPointerJ. + // + // Return the serialized representation + jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, + jlong trainVectorsPointerJ); + /* * Perform a range search with filter against the index located in memory at indexPointerJ. * diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 7cc071ff3..025fb12e8 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -43,6 +43,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: createBinaryIndexFromTemplate + * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V + */ + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: loadIndex @@ -139,6 +147,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_initLibrary JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex (JNIEnv *, jclass, jobject, jint, jlong); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: trainBinaryIndex + * Signature: (Ljava/util/Map;IJ)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex + (JNIEnv *, jclass, jobject, jint, jlong); + /* * Class: org_opensearch_knn_jni_FaissService * Method: transferVectors diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 9abb2357f..92393245e 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -70,6 +70,9 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, // Train an index with data provided void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); +// Train a binary index with data provided +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x); + // Converts the int FilterIds to Faiss ids type array. void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); @@ -223,6 +226,76 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * faiss::write_index(&idMap, indexPathCpp.c_str()); } +void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, + jbyteArray templateIndexJ, jobject parametersJ) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + if (templateIndexJ == nullptr) { + throw std::runtime_error("Template index cannot be null"); + } + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Read data set + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + if (dim % 8 != 0) { + throw std::runtime_error("Dimensions should be multiply of 8"); + } + int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Get vector of bytes from jbytearray + int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); + jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); + + faiss::VectorIOReader vectorIoReader; + for (int i = 0; i < indexBytesCount; i++) { + vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); + } + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Create faiss index + std::unique_ptr indexWriter; + indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0)); + + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get()); + idMap.add_with_ids(numVectors, reinterpret_cast(inputVectors->data()), idVector.data()); + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; + // Write the index to disk + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + faiss::write_index_binary(&idMap, indexPathCpp.c_str()); +} + jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { if (indexPathJ == nullptr) { throw std::runtime_error("Index path cannot be null"); @@ -624,6 +697,57 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti return ret; } +jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, + jint dimensionJ, jlong trainVectorsPointerJ) { + // First, we need to build the index + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + + std::unique_ptr indexWriter; + indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str())); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Train index if needed + auto *trainingVectorsPointerCpp = reinterpret_cast*>(trainVectorsPointerJ); + int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ; + if(!indexWriter->is_trained) { + InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data()); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Now that indexWriter is trained, we just load the bytes into an array and return + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index_binary(indexWriter.get(), &vectorIoWriter); + + // Wrap in smart pointer + std::unique_ptr jbytesBuffer; + jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]); + int c = 0; + for (auto b : vectorIoWriter.data) { + jbytesBuffer[c++] = (jbyte) b; + } + + jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size()); + jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get()); + return ret; +} + faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) { if (spaceType == knn_jni::L2) { return faiss::METRIC_L2; @@ -682,6 +806,15 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { } } +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) { + if (auto * indexIvf = dynamic_cast(index)) { + indexIvf->make_direct_map(); + } + if (!index->is_trained) { + index->train(n, reinterpret_cast(x)); + } +} + std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap) { int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr); int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 6e447b034..2394e2951 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -90,6 +90,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT } } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate(JNIEnv * env, jclass cls, + jintArray idsJ, + jlong vectorsAddressJ, + jint dimJ, + jstring indexPathJ, + jbyteArray templateIndexJ, + jobject parametersJ) +{ + try { + knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ) { try { @@ -220,6 +235,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex return nullptr; } +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex(JNIEnv * env, jclass cls, + jobject parametersJ, + jint dimensionJ, + jlong trainVectorsPointerJ) +{ + try { + return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls, jlong vectorsPointerJ, jobjectArray vectorsJ) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index a85852027..77aae79c3 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -67,6 +67,7 @@ public class KNNConstants { public static final String SEARCH_SIZE_PARAMETER = "search_size"; public static final String VECTOR_DATA_TYPE_FIELD = "data_type"; + public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD; public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; public static final String RADIAL_SEARCH_KEY = "radial_search"; diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 5815b343e..2e3f56d96 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -48,28 +48,10 @@ public class IndexUtil { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS = Version.V_2_16_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE = Version.V_2_16_0; // public so neural search can access it public static final Map minimalRequiredVersionMap = initializeMinimalRequiredVersionMap(); - private static Map initializeMinimalRequiredVersionMap() { - final Map versionMap = new HashMap<>() { - { - put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); - put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); - put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); - put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); - put(KNNConstants.METHOD_PARAMETER, MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS); - } - }; - - for (final MethodParameter methodParameter : MethodParameter.values()) { - if (methodParameter.getVersion() != null) { - versionMap.put(methodParameter.getName(), methodParameter.getVersion()); - } - } - return Collections.unmodifiableMap(versionMap); - } - /** * Determines the size of a file on disk in kilobytes * @@ -88,37 +70,6 @@ public static int getFileSizeInKB(String filePath) { return Math.toIntExact((file.length() / BYTES_PER_KILOBYTES) + 1L); // Add one so that integer division rounds up } - /** - * This method retrieves the field mapping by a given field path from the index metadata. - * - * @param properties Index metadata mapping properties. - * @param fieldPath The field path string that make up the path to the field mapping. e.g. "a.b.field" or "field". - * The field path is applied and checked in OpenSearch, so it is guaranteed to be valid. - * - * @return The field mapping object if found, or null if the field is not found in the index metadata. - */ - private static Object getFieldMapping(final Map properties, final String fieldPath) { - String[] fieldPaths = fieldPath.split("\\."); - Object currentFieldMapping = properties; - - // Iterate through the field path list to retrieve the field mapping. - for (String path : fieldPaths) { - currentFieldMapping = ((Map) currentFieldMapping).get(path); - if (currentFieldMapping == null) { - return null; - } - - if (currentFieldMapping instanceof Map) { - Object possibleProperties = ((Map) currentFieldMapping).get("properties"); - if (possibleProperties instanceof Map) { - currentFieldMapping = possibleProperties; - } - } - } - - return currentFieldMapping; - } - /** * Validate that a field is a k-NN vector field and has the expected dimension * @@ -135,7 +86,8 @@ public static ValidationException validateKnnField( IndexMetadata indexMetadata, String field, int expectedDimension, - ModelDao modelDao + ModelDao modelDao, + VectorDataType expectedVectorDataType ) { // Index metadata should not be null if (indexMetadata == null) { @@ -190,6 +142,29 @@ public static ValidationException validateKnnField( return exception; } + if (expectedVectorDataType != null) { + if (VectorDataType.BYTE == expectedVectorDataType) { + exception.addValidationError( + String.format(Locale.ROOT, "vector data type \"%s\" is not supported for training.", expectedVectorDataType.getValue()) + ); + return exception; + } + VectorDataType trainIndexDataType = getVectorDataTypeFromFieldMapping(fieldMap); + + if (trainIndexDataType != expectedVectorDataType) { + exception.addValidationError( + String.format( + Locale.ROOT, + "Field \"%s\" has data type %s, which is different from data type used in the training request: %s", + field, + trainIndexDataType.getValue(), + expectedVectorDataType.getValue() + ) + ); + return exception; + } + } + // Return if dimension does not need to be checked if (expectedDimension < 0) { return null; @@ -321,4 +296,94 @@ public static boolean isBinaryIndex(KNNEngine knnEngine, Map par && parameters.get(VECTOR_DATA_TYPE_FIELD) != null && parameters.get(VECTOR_DATA_TYPE_FIELD).toString().equals(VectorDataType.BINARY.getValue()); } + + /** + * Tell if it is binary index or not + * + * @param vectorDataType vector data type + * @return true if it is binary index + */ + public static boolean isBinaryIndex(VectorDataType vectorDataType) { + return VectorDataType.BINARY == vectorDataType; + } + + /** + * Update vector data type into parameters + * + * @param parameters parameters associated with an index + * @param vectorDataType vector data type + */ + public static void updateVectorDataTypeToParameters(Map parameters, VectorDataType vectorDataType) { + if (VectorDataType.BINARY == vectorDataType) { + parameters.put(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()); + } + } + + /** + * This method retrieves the field mapping by a given field path from the index metadata. + * + * @param properties Index metadata mapping properties. + * @param fieldPath The field path string that make up the path to the field mapping. e.g. "a.b.field" or "field". + * The field path is applied and checked in OpenSearch, so it is guaranteed to be valid. + * + * @return The field mapping object if found, or null if the field is not found in the index metadata. + */ + private static Object getFieldMapping(final Map properties, final String fieldPath) { + String[] fieldPaths = fieldPath.split("\\."); + Object currentFieldMapping = properties; + + // Iterate through the field path list to retrieve the field mapping. + for (String path : fieldPaths) { + currentFieldMapping = ((Map) currentFieldMapping).get(path); + if (currentFieldMapping == null) { + return null; + } + + if (currentFieldMapping instanceof Map) { + Object possibleProperties = ((Map) currentFieldMapping).get("properties"); + if (possibleProperties instanceof Map) { + currentFieldMapping = possibleProperties; + } + } + } + + return currentFieldMapping; + } + + /** + * This method is used to get the vector data type from field mapping + * @param fieldMap field mapping + * @return vector data type + */ + private static VectorDataType getVectorDataTypeFromFieldMapping(Map fieldMap) { + if (fieldMap.containsKey(VECTOR_DATA_TYPE_FIELD)) { + return VectorDataType.get((String) fieldMap.get(VECTOR_DATA_TYPE_FIELD)); + } + return VectorDataType.DEFAULT; + } + + /** + * Initialize the minimal required version map + * + * @return minimal required version map + */ + private static Map initializeMinimalRequiredVersionMap() { + final Map versionMap = new HashMap<>() { + { + put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); + put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); + put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); + put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); + put(KNNConstants.METHOD_PARAMETER, MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS); + put(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE); + } + }; + + for (final MethodParameter methodParameter : MethodParameter.values()) { + if (methodParameter.getVersion() != null) { + versionMap.put(methodParameter.getName(), methodParameter.getVersion()); + } + } + return Collections.unmodifiableMap(versionMap); + } } diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index 8add84609..4f79b232f 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -131,4 +131,6 @@ public static VectorDataType get(String vectorDataType) { ); } } + + public static VectorDataType DEFAULT = FLOAT; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 50c1c9271..ea5cb5e3b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -5,7 +5,6 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import com.google.common.collect.ImmutableMap; import lombok.NonNull; import lombok.extern.log4j.Log4j2; import org.apache.lucene.store.ChecksumIndexInput; @@ -15,6 +14,7 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.transfer.VectorTransfer; @@ -111,30 +111,10 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) { return KNNEngine.getEngine(engineName); } - private VectorTransfer getVectorTransfer(FieldInfo field) { - if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) { - return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); - } - return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); - } - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) throws IOException { // Get values to be indexed BinaryDocValues values = valuesProducer.getBinary(field); - KNNCodecUtil.Pair pair = KNNCodecUtil.getPair(values, getVectorTransfer(field)); - if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); - return; - } - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); - if (isMerge) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); - } - // Increment counter for number of graph index requests - KNNCounter.GRAPH_INDEX_REQUESTS.increment(); final KNNEngine knnEngine = getKNNEngine(field); final String engineFileName = buildEngineFileName( state.segmentInfo.name, @@ -146,30 +126,53 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName ).toString(); + + // Determine if we are creating an index from a model or from scratch NativeIndexCreator indexCreator; - // Create library index either from model or from scratch - if (field.attributes().containsKey(MODEL_ID)) { - String modelId = field.attributes().get(MODEL_ID); + KNNCodecUtil.Pair pair; + Map fieldAttributes = field.attributes(); + + if (fieldAttributes.containsKey(MODEL_ID)) { + String modelId = fieldAttributes.get(MODEL_ID); Model model = ModelCache.getInstance().get(modelId); if (model.getModelBlob() == null) { throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); + VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType(); + pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); + indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); } else { + // get vector data type from field attributes or provide default value + VectorDataType vectorDataType = VectorDataType.get( + fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + ); + pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); } + // Skip index creation if no vectors or docs in segment + if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { + logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); + return; + } + + long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); + if (isMerge) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); recordMergeStats(pair.docs.length, arraySize); } + // Increment counter for number of graph index requests + KNNCounter.GRAPH_INDEX_REQUESTS.increment(); + if (isRefresh) { recordRefreshStats(); } - // This is a bit of a hack. We have to create an output here and then immediately close it to ensure that - // engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will - // not be marked as added to the directory. + // Ensure engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper state.directory.createOutput(engineFileName, state.context).close(); indexCreator.createIndex(); writeFooter(indexPath, engineFileName); @@ -188,18 +191,19 @@ private void recordRefreshStats() { KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); } - private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { - Map parameters = ImmutableMap.of( - KNNConstants.INDEX_THREAD_QTY, - KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) - ); + private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { + Map parameters = new HashMap<>(); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); + AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, - model, + model.getModelBlob(), parameters, knnEngine ); @@ -242,13 +246,13 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa // Update index description of Faiss for binary data type if (KNNEngine.FAISS == knnEngine && VectorDataType.BINARY.getValue() - .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())) + .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue())) && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { parameters.put( KNNConstants.INDEX_DESCRIPTION_PARAMETER, FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() ); - parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()); + IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); } // Used to determine how many threads to use when indexing @@ -354,4 +358,11 @@ private boolean isChecksumValid(long value) { // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647 return (value & CRC32_CHECKSUM_SANITY) != 0; } + + private VectorTransfer getVectorTransfer(VectorDataType vectorDataType) { + if (VectorDataType.BINARY == vectorDataType) { + return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + } + return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 7f7c83f3e..7c7f446dd 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -603,7 +603,8 @@ protected void parseCreateField(ParseContext context) throws IOException { context, fieldType().getDimension(), fieldType().getSpaceType(), - getMethodComponentContext(fieldType().getKnnMethodContext()) + getMethodComponentContext(fieldType().getKnnMethodContext()), + fieldType().getVectorDataType() ); } @@ -646,8 +647,13 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fields; } - protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext) - throws IOException { + protected void parseCreateField( + ParseContext context, + int dimension, + SpaceType spaceType, + MethodComponentContext methodComponentContext, + VectorDataType vectorDataType + ) throws IOException { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 554871279..adaaef28e 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -62,6 +62,12 @@ protected void parseCreateField(ParseContext context) throws IOException { ); } - parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext()); + parseCreateField( + context, + modelMetadata.getDimension(), + modelMetadata.getSpaceType(), + modelMetadata.getMethodComponentContext(), + modelMetadata.getVectorDataType() + ); } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index b108fb6f0..0b0f1e615 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -13,6 +13,8 @@ import lombok.Getter; import org.apache.lucene.index.LeafReaderContext; +import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; @@ -248,6 +250,7 @@ class TrainingDataAllocation implements NativeMemoryAllocation { private int readCount; private Semaphore readSemaphore; private Semaphore writeSemaphore; + private VectorDataType vectorDataType; /** * Constructor @@ -256,7 +259,7 @@ class TrainingDataAllocation implements NativeMemoryAllocation { * @param memoryAddress pointer in memory to the training data allocation * @param size amount memory needed for allocation in kilobytes */ - TrainingDataAllocation(ExecutorService executor, long memoryAddress, int size) { + public TrainingDataAllocation(ExecutorService executor, long memoryAddress, int size, VectorDataType vectorDataType) { this.executor = executor; this.closed = false; this.memoryAddress = memoryAddress; @@ -265,6 +268,7 @@ class TrainingDataAllocation implements NativeMemoryAllocation { this.readCount = 0; this.readSemaphore = new Semaphore(1); this.writeSemaphore = new Semaphore(1); + this.vectorDataType = vectorDataType; } @Override @@ -295,7 +299,11 @@ private void cleanup() { closed = true; if (this.memoryAddress != 0) { - JNICommons.freeVectorData(this.memoryAddress); + if (IndexUtil.isBinaryIndex(vectorDataType)) { + JNICommons.freeByteVectorData(this.memoryAddress); + } else { + JNICommons.freeVectorData(this.memoryAddress); + } } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java index 7f14a2341..b5ddff1e2 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java @@ -14,6 +14,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.index.VectorDataType; import java.io.IOException; import java.util.Map; @@ -169,6 +170,7 @@ public static class TrainingDataEntryContext extends NativeMemoryEntryContext trainingDataAllocation.writeUnlock(), ex -> { // Close unsafe will assume that the caller passes control of the writelock to it. It // will then handle releasing the write lock once the close operations finish. diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index a10e04788..5baaf59cd 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -515,6 +515,7 @@ protected Query doToQuery(QueryShardContext context) { knnEngine = modelMetadata.getKnnEngine(); spaceType = modelMetadata.getSpaceType(); methodComponentContext = modelMetadata.getMethodComponentContext(); + vectorDataType = modelMetadata.getVectorDataType(); } else if (knnMethodContext != null) { // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index c08997c26..235e66411 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -213,6 +213,7 @@ private Map doANNSearch(final LeafReaderContext context, final B KNNEngine knnEngine; SpaceType spaceType; + VectorDataType vectorDataType; // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's // metadata. @@ -225,11 +226,15 @@ private Map doANNSearch(final LeafReaderContext context, final B knnEngine = modelMetadata.getKnnEngine(); spaceType = modelMetadata.getSpaceType(); + vectorDataType = modelMetadata.getVectorDataType(); } else { String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); knnEngine = KNNEngine.getEngine(engineName); String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); spaceType = SpaceType.getSpace(spaceTypeName); + vectorDataType = VectorDataType.get( + fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()) + ); } /* @@ -261,12 +266,7 @@ private Map doANNSearch(final LeafReaderContext context, final B new NativeMemoryEntryContext.IndexEntryContext( indexPath.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading( - spaceType, - knnEngine, - knnQuery.getIndexName(), - VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())) - ), + getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName(), vectorDataType), knnQuery.getIndexName(), modelId ), diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 711c206f5..4e39c1af1 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -305,7 +305,7 @@ public class Faiss extends NativeLibrary { return ((4L * centroids * dimension) / BYTES_PER_KILOBYTES) + 1; }) .build() - ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT).build() + ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT, SpaceType.HAMMING_BIT).build() ); final static Faiss INSTANCE = new Faiss( diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 0bc6c5edb..37edcd3ae 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -292,6 +292,7 @@ private void putInternal(Model model, ActionListener listener, Do put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription()); put(KNNConstants.MODEL_ERROR, modelMetadata.getError()); put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment()); + put(KNNConstants.VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType()); MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); if (!methodComponentContext.getName().isEmpty()) { diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index f3a5506cd..6bfb3aaf2 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -26,6 +26,7 @@ import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; @@ -48,6 +49,7 @@ public class ModelMetadata implements Writeable, ToXContentObject { final private String timestamp; final private String description; final private String trainingNodeAssignment; + final private VectorDataType vectorDataType; private MethodComponentContext methodComponentContext; private String error; @@ -81,6 +83,12 @@ public ModelMetadata(StreamInput in) throws IOException { } else { this.methodComponentContext = MethodComponentContext.EMPTY; } + + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) { + this.vectorDataType = VectorDataType.get(in.readString()); + } else { + this.vectorDataType = VectorDataType.DEFAULT; + } } /** @@ -95,6 +103,7 @@ public ModelMetadata(StreamInput in) throws IOException { * @param error error message associated with model * @param trainingNodeAssignment node assignment for the model * @param methodComponentContext method component context associated with model + * @param vectorDataType vector data type of the model */ public ModelMetadata( KNNEngine knnEngine, @@ -105,7 +114,8 @@ public ModelMetadata( String description, String error, String trainingNodeAssignment, - MethodComponentContext methodComponentContext + MethodComponentContext methodComponentContext, + VectorDataType vectorDataType ) { this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null"); this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null"); @@ -128,6 +138,7 @@ public ModelMetadata( this.error = Objects.requireNonNull(error, "error must not be null"); this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null"); this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null"); + this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null"); } /** @@ -211,6 +222,10 @@ public MethodComponentContext getMethodComponentContext() { return methodComponentContext; } + public VectorDataType getVectorDataType() { + return vectorDataType; + } + /** * setter for model's state * @@ -241,7 +256,8 @@ public String toString() { description, error, trainingNodeAssignment, - methodComponentContext.toClusterStateString() + methodComponentContext.toClusterStateString(), + vectorDataType.getValue() ); } @@ -259,6 +275,7 @@ public boolean equals(Object obj) { equalsBuilder.append(getTimestamp(), other.getTimestamp()); equalsBuilder.append(getDescription(), other.getDescription()); equalsBuilder.append(getError(), other.getError()); + equalsBuilder.append(getVectorDataType(), other.getVectorDataType()); return equalsBuilder.isEquals(); } @@ -273,6 +290,7 @@ public int hashCode() { .append(getDescription()) .append(getError()) .append(getMethodComponentContext()) + .append(getVectorDataType()) .toHashCode(); } @@ -284,81 +302,60 @@ public int hashCode() { */ public static ModelMetadata fromString(String modelMetadataString) { String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1); + int length = modelMetadataArray.length; - // Training node assignment was added as a field in Version 2.12.0 - // Because models can be created on older versions and the cluster can be upgraded after, - // we need to accept model metadata arrays both with and without the training node assignment. - if (modelMetadataArray.length == 7) { - log.debug( - "Model metadata array does not contain training node assignment or method component context. Assuming empty string node assignment and empty method component context." - ); - KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); - SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); - int dimension = Integer.parseInt(modelMetadataArray[2]); - ModelState modelState = ModelState.getModelState(modelMetadataArray[3]); - String timestamp = modelMetadataArray[4]; - String description = modelMetadataArray[5]; - String error = modelMetadataArray[6]; - return new ModelMetadata( - knnEngine, - spaceType, - dimension, - modelState, - timestamp, - description, - error, - "", - MethodComponentContext.EMPTY - ); - } else if (modelMetadataArray.length == 8) { - log.debug("Model metadata contains training node assignment. Assuming empty method component context."); - KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); - SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); - int dimension = Integer.parseInt(modelMetadataArray[2]); - ModelState modelState = ModelState.getModelState(modelMetadataArray[3]); - String timestamp = modelMetadataArray[4]; - String description = modelMetadataArray[5]; - String error = modelMetadataArray[6]; - String trainingNodeAssignment = modelMetadataArray[7]; - return new ModelMetadata( - knnEngine, - spaceType, - dimension, - modelState, - timestamp, - description, - error, - trainingNodeAssignment, - MethodComponentContext.EMPTY - ); - } else if (modelMetadataArray.length == 9) { - log.debug("Model metadata contains training node assignment and method context"); - KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); - SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); - int dimension = Integer.parseInt(modelMetadataArray[2]); - ModelState modelState = ModelState.getModelState(modelMetadataArray[3]); - String timestamp = modelMetadataArray[4]; - String description = modelMetadataArray[5]; - String error = modelMetadataArray[6]; - String trainingNodeAssignment = modelMetadataArray[7]; - MethodComponentContext methodComponentContext = MethodComponentContext.fromClusterStateString(modelMetadataArray[8]); - return new ModelMetadata( - knnEngine, - spaceType, - dimension, - modelState, - timestamp, - description, - error, - trainingNodeAssignment, - methodComponentContext - ); - } else { + if (length < 7 || length > 10) { throw new IllegalArgumentException( "Illegal format for model metadata. Must be of the form " - + "\",,,,,,\" or \",,,,,,,\" or \",,,,,,,,\"." + + "\",,,,,,\" or " + + "\",,,,,,,\" or " + + "\",,,,,,,,\" or " + + "\",,,,,,,,,\"." ); } + + KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); + SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); + int dimension = Integer.parseInt(modelMetadataArray[2]); + ModelState modelState = ModelState.getModelState(modelMetadataArray[3]); + String timestamp = modelMetadataArray[4]; + String description = modelMetadataArray[5]; + String error = modelMetadataArray[6]; + String trainingNodeAssignment = length > 7 ? modelMetadataArray[7] : ""; + MethodComponentContext methodComponentContext = length > 8 + ? MethodComponentContext.fromClusterStateString(modelMetadataArray[8]) + : MethodComponentContext.EMPTY; + VectorDataType vectorDataType = length > 9 ? VectorDataType.get(modelMetadataArray[9]) : VectorDataType.DEFAULT; + + log.debug(getLogMessage(length)); + + return new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + trainingNodeAssignment, + methodComponentContext, + vectorDataType + ); + } + + private static String getLogMessage(int length) { + switch (length) { + case 7: + return "Model metadata array does not contain training node assignment or method component context. Assuming empty string node assignment and empty method component context."; + case 8: + return "Model metadata contains training node assignment. Assuming empty method component context."; + case 9: + return "Model metadata contains training node assignment and method context."; + case 10: + return "Model metadata contains training node assignment, method context and vector data type."; + default: + throw new IllegalArgumentException("Unexpected metadata array length: " + length); + } } private static String objectToString(Object value) { @@ -387,6 +384,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR); Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT); Object methodComponentContext = modelSourceMap.get(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT); + Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD); if (trainingNodeAssignment == null) { trainingNodeAssignment = ""; @@ -407,6 +405,10 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m methodComponentContext = MethodComponentContext.EMPTY; } + if (vectorDataType == null) { + vectorDataType = VectorDataType.DEFAULT.getValue(); + } + ModelMetadata modelMetadata = new ModelMetadata( KNNEngine.getEngine(objectToString(engine)), SpaceType.getSpace(objectToString(space)), @@ -416,7 +418,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m objectToString(description), objectToString(error), objectToString(trainingNodeAssignment), - (MethodComponentContext) methodComponentContext + (MethodComponentContext) methodComponentContext, + VectorDataType.get(objectToString(vectorDataType)) ); return modelMetadata; } @@ -436,6 +439,9 @@ public void writeTo(StreamOutput out) throws IOException { if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) { getMethodComponentContext().writeTo(out); } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) { + out.writeString(vectorDataType.getValue()); + } } @Override @@ -456,6 +462,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws getMethodComponentContext().toXContent(builder, params); builder.endObject(); } + if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) { + builder.field(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + } return builder; } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 21de90765..1f23f6fcd 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -96,6 +96,25 @@ public static native void createIndexFromTemplate( Map parameters ); + /** + * Create a binary index for the native library with a provided template index + * + * @param ids array of ids mapping to the data passed in + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed + * @param indexPath path to save index file to + * @param templateIndex empty template index + * @param parameters additional build time parameters + */ + public static native void createBinaryIndexFromTemplate( + int[] ids, + long vectorsAddress, + int dim, + String indexPath, + byte[] templateIndex, + Map parameters + ); + /** * Load an index into memory * @@ -249,6 +268,16 @@ public static native KNNQueryResult[] queryBinaryIndexWithFilter( */ public static native byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** + * Train an empty binary index + * + * @param indexParameters parameters used to build index + * @param dimension dimension for the index + * @param trainVectorsPointer pointer to where training vectors are stored in native memory + * @return bytes array of trained template index + */ + public static native byte[] trainBinaryIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** *

* The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java index d0111b115..31a8f43cc 100644 --- a/src/main/java/org/opensearch/knn/jni/JNICommons.java +++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java @@ -78,4 +78,17 @@ public class JNICommons { * @param memoryAddress address to be freed. */ public static native void freeVectorData(long memoryAddress); + + /** + * Free up the memory allocated for the byte data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} + * + *

+ * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can + * lead to errors. + *

+ * + * @param memoryAddress address to be freed. + */ + public static native void freeByteVectorData(long memoryAddress); } diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index cefd0af53..2a8d3ea8f 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -83,8 +83,13 @@ public static void createIndexFromTemplate( KNNEngine knnEngine ) { if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); - return; + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; + } else { + FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; + } } throw new IllegalArgumentException( @@ -308,6 +313,9 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE */ public static byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer, KNNEngine knnEngine) { if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, indexParameters)) { + return FaissService.trainBinaryIndex(indexParameters, dimension, trainVectorsPointer); + } return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index fb8ccc4ce..e0b94ec76 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -17,6 +17,7 @@ import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.TrainingJobRouterAction; @@ -40,6 +41,7 @@ import static org.opensearch.knn.common.KNNConstants.SEARCH_SIZE_PARAMETER; import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; /** * Rest Handler for model training api endpoint. @@ -83,6 +85,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr String trainingIndex = (String) DEFAULT_NOT_SET_OBJECT_VALUE; String trainingField = (String) DEFAULT_NOT_SET_OBJECT_VALUE; String description = (String) DEFAULT_NOT_SET_OBJECT_VALUE; + VectorDataType vectorDataType = (VectorDataType) DEFAULT_NOT_SET_OBJECT_VALUE; int dimension = DEFAULT_NOT_SET_INT_VALUE; int maximumVectorCount = DEFAULT_NOT_SET_INT_VALUE; @@ -110,6 +113,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr } else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) { description = parser.textOrNull(); ModelUtil.blockCommasInModelDescription(description); + } else if (VECTOR_DATA_TYPE_FIELD.equals(fieldName) && ensureNotSet(fieldName, vectorDataType)) { + vectorDataType = VectorDataType.get(parser.text()); } else { throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter."); } @@ -126,6 +131,10 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr description = ""; } + if (vectorDataType == DEFAULT_NOT_SET_OBJECT_VALUE) { + vectorDataType = VectorDataType.DEFAULT; + } + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( modelId, knnMethodContext, @@ -133,7 +142,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr trainingIndex, trainingField, preferredNodeId, - description + description, + vectorDataType ); if (maximumVectorCount != DEFAULT_NOT_SET_INT_VALUE) { diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index 9b2d3a9de..78f3769c5 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -22,6 +22,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.ValidationException; import org.opensearch.common.inject.Inject; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequestOptions; @@ -133,7 +134,9 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques trainingVectors = trainingModelRequest.getMaximumVectorCount(); } - listener.onResponse(estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension())); + listener.onResponse( + estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension(), trainingModelRequest.getVectorDataType()) + ); }, listener::onFailure)); } @@ -144,8 +147,14 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques * @param dimension dimension of vectors * @return size estimate */ - public static int estimateVectorSetSizeInKB(long vectorCount, int dimension) { - // Ensure we do not overflow the int on estimate - return Math.toIntExact(((Float.BYTES * dimension * vectorCount) / BYTES_PER_KILOBYTES) + 1L); + public static int estimateVectorSetSizeInKB(long vectorCount, int dimension, VectorDataType vectorDataType) { + switch (vectorDataType) { + case BINARY: + return Math.toIntExact(((Byte.BYTES * (dimension / 8) * vectorCount) / BYTES_PER_KILOBYTES) + 1L); + case BYTE: + return Math.toIntExact(((Byte.BYTES * dimension * vectorCount) / BYTES_PER_KILOBYTES) + 1L); + default: + return Math.toIntExact(((Float.BYTES * dimension * vectorCount) / BYTES_PER_KILOBYTES) + 1L); + } } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 5f3913ac5..16a1a103a 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -21,6 +21,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.training.VectorSpaceInfo; @@ -41,6 +42,7 @@ public class TrainingModelRequest extends ActionRequest { private final String trainingField; private final String preferredNodeId; private final String description; + private final VectorDataType vectorDataType; private int maximumVectorCount; private int searchSize; @@ -65,7 +67,8 @@ public TrainingModelRequest( String trainingIndex, String trainingField, String preferredNodeId, - String description + String description, + VectorDataType vectorDataType ) { super(); this.modelId = modelId; @@ -75,6 +78,7 @@ public TrainingModelRequest( this.trainingField = trainingField; this.preferredNodeId = preferredNodeId; this.description = description; + this.vectorDataType = vectorDataType; // Set these as defaults initially. If call wants to override them, they can use the setters. this.maximumVectorCount = Integer.MAX_VALUE; // By default, get all vectors in the index @@ -103,6 +107,11 @@ public TrainingModelRequest(StreamInput in) throws IOException { this.maximumVectorCount = in.readInt(); this.searchSize = in.readInt(); this.trainingDataSizeInKB = in.readInt(); + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) { + this.vectorDataType = VectorDataType.get(in.readString()); + } else { + this.vectorDataType = VectorDataType.DEFAULT; + } } /** @@ -213,6 +222,10 @@ public int getSearchSize() { return searchSize; } + public VectorDataType getVectorDataType() { + return vectorDataType; + } + /** * Setter for search size. * @@ -314,7 +327,13 @@ public ActionRequestValidationException validate() { } // Validate the training field - ValidationException fieldValidation = IndexUtil.validateKnnField(indexMetadata, this.trainingField, this.dimension, modelDao); + ValidationException fieldValidation = IndexUtil.validateKnnField( + indexMetadata, + this.trainingField, + this.dimension, + modelDao, + this.vectorDataType + ); if (fieldValidation != null) { exception = exception == null ? new ActionRequestValidationException() : exception; exception.addValidationErrors(fieldValidation.validationErrors()); @@ -336,5 +355,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(this.maximumVectorCount); out.writeInt(this.searchSize); out.writeInt(this.trainingDataSizeInKB); + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) { + out.writeString(this.vectorDataType.getValue()); + } else { + out.writeString(VectorDataType.DEFAULT.getValue()); + } } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index 33b420e2c..a9eca609d 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -51,7 +51,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(), clusterService, request.getMaximumVectorCount(), - request.getSearchSize() + request.getSearchSize(), + request.getVectorDataType() ); // Allocation representing size model will occupy in memory during training @@ -68,7 +69,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener modelAnonymousEntryContext, request.getDimension(), request.getDescription(), - clusterService.localNode().getEphemeralId() + clusterService.localNode().getEphemeralId(), + request.getVectorDataType() ); KNNCounter.TRAINING_REQUESTS.increment(); diff --git a/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java b/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java new file mode 100644 index 000000000..70cfb4f4c --- /dev/null +++ b/src/main/java/org/opensearch/knn/training/ByteTrainingDataConsumer.java @@ -0,0 +1,81 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.training; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.knn.jni.JNICommons; +import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.search.SearchHit; + +import java.util.ArrayList; +import java.util.List; + +/** + * Transfers byte vectors from JVM to native memory. + */ +public class ByteTrainingDataConsumer extends TrainingDataConsumer { + private static final Logger logger = LogManager.getLogger(TrainingDataConsumer.class); + + /** + * Constructor + * + * @param trainingDataAllocation NativeMemoryAllocation that contains information about native memory allocation. + */ + public ByteTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) { + super(trainingDataAllocation); + } + + @Override + public void accept(List byteVectors) { + long memoryAddress = trainingDataAllocation.getMemoryAddress(); + memoryAddress = JNICommons.storeByteVectorData(memoryAddress, byteVectors.toArray(new byte[0][0]), byteVectors.size()); + trainingDataAllocation.setMemoryAddress(memoryAddress); + } + + @Override + public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) { + SearchHit[] hits = searchResponse.getHits().getHits(); + List vectors = new ArrayList<>(); + String[] fieldPath = fieldName.split("\\."); + int nullVectorCount = 0; + + for (int vector = 0; vector < vectorsToAdd; vector++) { + Object fieldValue = extractFieldValue(hits[vector], fieldPath); + if (fieldValue == null) { + nullVectorCount++; + continue; + } + + byte[] byteArray; + if (!(fieldValue instanceof List)) { + continue; + } + List fieldList = (List) fieldValue; + byteArray = new byte[fieldList.size()]; + for (int i = 0; i < fieldList.size(); i++) { + byteArray[i] = fieldList.get(i).byteValue(); + } + + vectors.add(byteArray); + } + + if (nullVectorCount > 0) { + logger.warn("Found {} documents with null byte vectors in field {}", nullVectorCount, fieldName); + } + + setTotalVectorsCountAdded(getTotalVectorsCountAdded() + vectors.size()); + + accept(vectors); + } +} diff --git a/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java b/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java new file mode 100644 index 000000000..d742a9184 --- /dev/null +++ b/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.training; + +import org.apache.commons.lang.ArrayUtils; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.search.SearchHit; + +import java.util.ArrayList; +import java.util.List; + +/** + * Transfers float vectors from JVM to native memory. + */ +public class FloatTrainingDataConsumer extends TrainingDataConsumer { + + /** + * Constructor + * + * @param trainingDataAllocation NativeMemoryAllocation that contains information about native memory allocation. + */ + public FloatTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) { + super(trainingDataAllocation); + } + + @Override + public void accept(List floats) { + trainingDataAllocation.setMemoryAddress( + JNIService.transferVectors( + trainingDataAllocation.getMemoryAddress(), + floats.stream().map(v -> ArrayUtils.toPrimitive((Float[]) v)).toArray(float[][]::new) + ) + ); + } + + @Override + public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) { + SearchHit[] hits = searchResponse.getHits().getHits(); + List vectors = new ArrayList<>(); + String[] fieldPath = fieldName.split("\\."); + + for (int vector = 0; vector < vectorsToAdd; vector++) { + Object fieldValue = extractFieldValue(hits[vector], fieldPath); + if (!(fieldValue instanceof List)) { + continue; + } + + List fieldList = (List) fieldValue; + vectors.add(fieldList.stream().map(Number::floatValue).toArray(Float[]::new)); + } + + setTotalVectorsCountAdded(getTotalVectorsCountAdded() + vectors.size()); + + accept(vectors); + } +} diff --git a/src/main/java/org/opensearch/knn/training/TrainingDataConsumer.java b/src/main/java/org/opensearch/knn/training/TrainingDataConsumer.java index 6732bd3f4..9d0683fdc 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingDataConsumer.java +++ b/src/main/java/org/opensearch/knn/training/TrainingDataConsumer.java @@ -11,19 +11,25 @@ package org.opensearch.knn.training; -import org.apache.commons.lang.ArrayUtils; -import org.opensearch.knn.jni.JNIService; +import lombok.Getter; +import lombok.Setter; +import org.opensearch.action.search.SearchResponse; import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.search.SearchHit; import java.util.List; -import java.util.function.Consumer; +import java.util.Map; /** - * Transfers vectors from JVM to native memory. + * TrainingDataConsumer is an abstract class that defines the interface for consuming training data. + * It is used to process training data and add it to the training data allocation. */ -public class TrainingDataConsumer implements Consumer> { +public abstract class TrainingDataConsumer { - private final NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation; + @Setter + @Getter + private int totalVectorsCountAdded = 0; + protected final NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation; /** * Constructor @@ -34,13 +40,25 @@ public TrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation traini this.trainingDataAllocation = trainingDataAllocation; } - @Override - public void accept(List floats) { - trainingDataAllocation.setMemoryAddress( - JNIService.transferVectors( - trainingDataAllocation.getMemoryAddress(), - floats.stream().map(ArrayUtils::toPrimitive).toArray(float[][]::new) - ) - ); + protected abstract void accept(List vectors); + + public abstract void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName); + + /** + * Traverses the hit to the desired field and extracts its value. + * + * @param hit The search hit to extract the field value from + * @param fieldPath The path to the desired field + * @return The extracted field value, or null if the field does not exist + */ + protected Object extractFieldValue(SearchHit hit, String[] fieldPath) { + Map currentMap = hit.getSourceAsMap(); + for (int pathPart = 0; pathPart < fieldPath.length - 1; pathPart++) { + currentMap = (Map) currentMap.get(fieldPath[pathPart]); + if (currentMap == null) { + return null; + } + } + return currentMap.get(fieldPath[fieldPath.length - 1]); } } diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index aa2786c0a..928396289 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -16,7 +16,9 @@ import org.apache.logging.log4j.Logger; import org.opensearch.common.UUIDs; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -32,6 +34,8 @@ import java.util.Map; import java.util.Objects; +import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + /** * Encapsulates all information required to generate and train a model. */ @@ -66,7 +70,8 @@ public TrainingJob( NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext, int dimension, String description, - String nodeAssignment + String nodeAssignment, + VectorDataType vectorDataType ) { // Generate random base64 string if one is not provided this.modelId = StringUtils.isNotBlank(modelId) ? modelId : UUIDs.randomBase64UUID(); @@ -84,7 +89,8 @@ public TrainingJob( description, "", nodeAssignment, - knnMethodContext.getMethodComponentContext() + knnMethodContext.getMethodComponentContext(), + vectorDataType ), null, this.modelId @@ -182,6 +188,15 @@ public void run() { KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); + if (VectorDataType.BINARY == model.getModelMetadata().getVectorDataType()) { + trainParameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + trainParameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + } + + IndexUtil.updateVectorDataTypeToParameters(trainParameters, model.getModelMetadata().getVectorDataType()); + byte[] modelBlob = JNIService.trainIndex( trainParameters, model.getModelMetadata().getDimension(), diff --git a/src/main/java/org/opensearch/knn/training/VectorReader.java b/src/main/java/org/opensearch/knn/training/VectorReader.java index aeebae129..f1fd744fd 100644 --- a/src/main/java/org/opensearch/knn/training/VectorReader.java +++ b/src/main/java/org/opensearch/knn/training/VectorReader.java @@ -30,7 +30,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.function.Consumer; public class VectorReader { @@ -59,13 +58,13 @@ public VectorReader(Client client) { * @param vectorConsumer consumer used to do something with the collected vectors after each search * @param listener ActionListener that should be called once all search operations complete */ - public void read( + public void read( ClusterService clusterService, String indexName, String fieldName, int maxVectorCount, int searchSize, - Consumer> vectorConsumer, + TrainingDataConsumer vectorConsumer, ActionListener listener ) { @@ -89,7 +88,7 @@ public void read( throw validationException; } - ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null); + ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null, null); if (fieldValidationException != null) { validationException = validationException == null ? new ValidationException() : validationException; validationException.addValidationErrors(validationException.validationErrors()); @@ -136,14 +135,14 @@ private SearchScrollRequestBuilder createSearchScrollRequestBuilder() { return searchScrollRequestBuilder; } - private static class VectorReaderListener implements ActionListener { + private static class VectorReaderListener implements ActionListener { final Client client; final String fieldName; final int maxVectorCount; int collectedVectorCount; final ActionListener listener; - final Consumer> vectorConsumer; + final TrainingDataConsumer vectorConsumer; SearchScrollRequestBuilder searchScrollRequestBuilder; /** @@ -162,7 +161,7 @@ public VectorReaderListener( int maxVectorCount, int collectedVectorCount, ActionListener listener, - Consumer> vectorConsumer, + TrainingDataConsumer vectorConsumer, SearchScrollRequestBuilder searchScrollRequestBuilder ) { this.client = client; @@ -181,12 +180,9 @@ public void onResponse(SearchResponse searchResponse) { // Either add the entire set of returned hits, or maxVectorCount - collectedVectorCount hits SearchHit[] hits = searchResponse.getHits().getHits(); int vectorsToAdd = Integer.min(maxVectorCount - collectedVectorCount, hits.length); - List trainingData = extractVectorsFromHits(searchResponse, vectorsToAdd); - this.collectedVectorCount += trainingData.size(); - - // Do something with the vectors - vectorConsumer.accept(trainingData); + vectorConsumer.processTrainingVectors(searchResponse, vectorsToAdd, fieldName); + this.collectedVectorCount = vectorConsumer.getTotalVectorsCountAdded(); if (vectorsToAdd <= 0 || this.collectedVectorCount >= maxVectorCount) { // Clear scroll context diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index f9c0161d6..06431bf07 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -46,8 +46,17 @@ import java.util.concurrent.ExecutionException; import static org.mockito.Mockito.when; -import static org.opensearch.knn.common.KNNConstants.*; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; +import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; +import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; public class KNNSingleNodeTestCase extends OpenSearchSingleNodeTestCase { @Override @@ -201,7 +210,8 @@ protected void writeModelToModelSystemIndex(Model model) throws IOException, Exe .field(MODEL_STATE, modelMetadata.getState().getName()) .field(MODEL_TIMESTAMP, modelMetadata.getTimestamp().toString()) .field(MODEL_DESCRIPTION, modelMetadata.getDescription()) - .field(MODEL_ERROR, modelMetadata.getError()); + .field(MODEL_ERROR, modelMetadata.getError()) + .field(VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType().getValue()); if (model.getModelBlob() != null) { builder.field(MODEL_BLOB_PARAMETER, Base64.getEncoder().encodeToString(model.getModelBlob())); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index b9116b0b1..0ac632c52 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -44,6 +44,7 @@ import java.util.TreeMap; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; @@ -55,18 +56,25 @@ import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; +import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; public class FaissIT extends KNNRestTestCase { private static final String DOC_ID_1 = "doc1"; @@ -107,13 +115,13 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .endObject() .endObject() @@ -166,13 +174,13 @@ public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWF .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .endObject() .endObject() @@ -226,13 +234,13 @@ public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMe .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .endObject() .endObject() @@ -296,8 +304,8 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN .field(NAME, METHOD_HNSW) .field(KNN_ENGINE, FAISS_NAME) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_PQ) @@ -424,8 +432,8 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { .field(NAME, METHOD_HNSW) .field(KNN_ENGINE, FAISS_NAME) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_PQ) @@ -531,13 +539,13 @@ public void testHNSWSQFP16_whenIndexedAndQueried_thenSucceed() { .startObject(fieldName) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_SQ) @@ -644,13 +652,13 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { .startObject(fieldName) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_SQ) @@ -744,13 +752,13 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then .startObject(fieldName) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_SQ) @@ -997,7 +1005,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( .field(NAME, METHOD_HNSW) .field(KNN_ENGINE, FAISS_NAME) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_PQ) .startObject(PARAMETERS) @@ -1204,7 +1212,7 @@ public void testDocUpdate() throws IOException { .startObject(fieldName) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) @@ -1240,7 +1248,7 @@ public void testDocDeletion() throws IOException { .startObject(fieldName) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) @@ -1418,7 +1426,7 @@ public void testFiltering_whenUsingFaissExactSearchWithIP_thenMatchExpectedScore .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", 2) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) @@ -1593,6 +1601,230 @@ public void testIVF_InvalidPQM_thenFail() { ); } + @SneakyThrows + public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() { + String modelId = "test-model-ivf-binary"; + int dimension = 8; + + String trainingIndexName = "train-index-ivf-binary"; + String trainingFieldName = "train-field-ivf-binary"; + + String trainIndexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(trainingFieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .field("data_type", VectorDataType.BINARY.getValue()) + .startObject(KNN_METHOD) + .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.HAMMING_BIT.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, 24) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, 128) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .toString(); + + createKnnIndex(trainingIndexName, trainIndexMapping); + + int trainingDataCount = 40; + bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); + + XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TRAIN_INDEX_PARAMETER, trainingIndexName) + .field(TRAIN_FIELD_PARAMETER, trainingFieldName) + .field(DIMENSION, dimension) + .field(MODEL_DESCRIPTION, "My model description") + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()) + .field( + KNN_METHOD, + Map.of( + NAME, + METHOD_IVF, + KNN_ENGINE, + FAISS_NAME, + METHOD_PARAMETER_SPACE_TYPE, + SpaceType.HAMMING_BIT.getValue(), + PARAMETERS, + Map.of(METHOD_PARAMETER_NLIST, 1, METHOD_PARAMETER_NPROBES, 1) + ) + ) + .endObject(); + + trainModel(modelId, trainModelXContentBuilder); + + // Make sure training succeeds after 30 seconds + assertTrainingSucceeds(modelId, 30, 1000); + + // Create knn index from model + String fieldName = "test-field-name-ivf-binary"; + String indexName = "test-index-name-ivf-binary"; + String indexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field(MODEL_ID, modelId) + .endObject() + .endObject() + .endObject() + .toString(); + + createKnnIndex(indexName, getKNNDefaultIndexSettings(), indexMapping); + Integer[] vector1 = { 11 }; + Integer[] vector2 = { 22 }; + Integer[] vector3 = { 33 }; + Integer[] vector4 = { 44 }; + addKnnDoc(indexName, "1", fieldName, vector1); + addKnnDoc(indexName, "2", fieldName, vector2); + addKnnDoc(indexName, "3", fieldName, vector3); + addKnnDoc(indexName, "4", fieldName, vector4); + + Integer[] queryVector = { 15 }; + int k = 2; + + XContentBuilder queryBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(fieldName) + .field("vector", queryVector) + .field("k", k) + .endObject() + .endObject() + .endObject() + .endObject(); + Response searchResponse = searchKNNIndex(indexName, queryBuilder, k); + List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); + assertEquals(k, results.size()); + + deleteKNNIndex(indexName); + Thread.sleep(45 * 1000); + deleteModel(modelId); + deleteKNNIndex(trainingIndexName); + validateGraphEviction(); + } + + @SneakyThrows + public void testIVF_whenBinaryFormat_whenIVFPQ_thenSuccess() { + String modelId = "test-model-ivfpq-binary"; + int dimension = 8; + + String trainingIndexName = "train-index-ivfpq-binary"; + String trainingFieldName = "train-field-ivfpq-binary"; + + String trainIndexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(trainingFieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .field("data_type", VectorDataType.BINARY.getValue()) + .startObject(KNN_METHOD) + .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.HAMMING_BIT.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, 24) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, 128) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .toString(); + + createKnnIndex(trainingIndexName, trainIndexMapping); + + int trainingDataCount = 50; + bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); + + XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TRAIN_INDEX_PARAMETER, trainingIndexName) + .field(TRAIN_FIELD_PARAMETER, trainingFieldName) + .field(DIMENSION, dimension) + .field(MODEL_DESCRIPTION, "My model description") + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()) + .startObject(KNN_METHOD) + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.HAMMING_BIT.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NPROBES, 1) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 8) + .field(ENCODER_PARAMETER_PQ_M, 8) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + trainModel(modelId, trainModelXContentBuilder); + + // Make sure training succeeds after 30 seconds + assertTrainingSucceeds(modelId, 30, 1000); + + // Create knn index from model + String fieldName = "test-field-name-ivfpq-binary"; + String indexName = "test-index-name-ivfpq-binary"; + + String indexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field(MODEL_ID, modelId) + .endObject() + .endObject() + .endObject() + .toString(); + + createKnnIndex(indexName, getKNNDefaultIndexSettings(), indexMapping); + Integer[] vector1 = { 11 }; + Integer[] vector2 = { 22 }; + Integer[] vector3 = { 33 }; + Integer[] vector4 = { 44 }; + addKnnDoc(indexName, "1", fieldName, vector1); + addKnnDoc(indexName, "2", fieldName, vector2); + addKnnDoc(indexName, "3", fieldName, vector3); + addKnnDoc(indexName, "4", fieldName, vector4); + + Integer[] queryVector = { 15 }; + int k = 2; + + XContentBuilder queryBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(fieldName) + .field("vector", queryVector) + .field("k", k) + .endObject() + .endObject() + .endObject() + .endObject(); + Response searchResponse = searchKNNIndex(indexName, queryBuilder, k); + List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); + assertEquals(k, results.size()); + + deleteKNNIndex(indexName); + Thread.sleep(45 * 1000); + deleteModel(modelId); + deleteKNNIndex(trainingIndexName); + validateGraphEviction(); + } + protected void setupKNNIndexForFilterQuery() throws Exception { // Create Mappings XContentBuilder builder = XContentFactory.jsonBuilder() @@ -1601,7 +1833,7 @@ protected void setupKNNIndexForFilterQuery() throws Exception { .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", 3) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) diff --git a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java index d500fc342..1b00ecfaa 100644 --- a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java @@ -117,7 +117,7 @@ public void testValidateKnnField_NestedField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); assertNull(e); } @@ -138,7 +138,7 @@ public void testValidateKnnField_NonNestedField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); assertNull(e); } @@ -158,7 +158,7 @@ public void testValidateKnnField_NonKnnField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); assert Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" is not of type knn_vector.;"); } @@ -182,7 +182,7 @@ public void testValidateKnnField_WrongFieldPath() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" does not exist.;")); } @@ -206,7 +206,7 @@ public void testValidateKnnField_EmptyField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); System.out.println(Objects.requireNonNull(e).getMessage()); @@ -223,7 +223,7 @@ public void testValidateKnnField_EmptyIndexMetadata() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;")); } @@ -259,4 +259,56 @@ public void testIsBinaryIndex_whenNonBinary_thenFalse() { nonBinaryIndexParams.put(VECTOR_DATA_TYPE_FIELD, "byte"); assertFalse(IndexUtil.isBinaryIndex(KNNEngine.FAISS, nonBinaryIndexParams)); } + + public void testValidateKnnField_whenTrainModelUseDifferentVectorDataTypeFromTrainIndex_thenThrowException() { + Map fieldValues = Map.of("type", "knn_vector", "dimension", 8, "data_type", "float"); + Map top_level_field = Map.of("top_level_field", fieldValues); + Map properties = Map.of("properties", top_level_field); + String field = "top_level_field"; + int dimension = 8; + + MappingMetadata mappingMetadata = mock(MappingMetadata.class); + when(mappingMetadata.getSourceAsMap()).thenReturn(properties); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + ModelDao modelDao = mock(ModelDao.class); + + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BINARY); + System.out.println(Objects.requireNonNull(e).getMessage()); + + assert Objects.requireNonNull(e) + .getMessage() + .matches( + "Validation Failed: 1: Field \"" + + field + + "\" has data type float, which is different from data type used in the training request: binary;" + ); + } + + public void testValidateKnnField_whenPassByteVectorDataType_thenThrowException() { + Map fieldValues = Map.of("type", "knn_vector", "dimension", 8, "data_type", "byte"); + Map top_level_field = Map.of("top_level_field", fieldValues); + Map properties = Map.of("properties", top_level_field); + String field = "top_level_field"; + int dimension = 8; + + MappingMetadata mappingMetadata = mock(MappingMetadata.class); + when(mappingMetadata.getSourceAsMap()).thenReturn(properties); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + ModelDao modelDao = mock(ModelDao.class); + + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BYTE); + System.out.println(Objects.requireNonNull(e).getMessage()); + + assert Objects.requireNonNull(e) + .getMessage() + .matches("Validation Failed: 1: vector data type \"" + VectorDataType.BYTE.getValue() + "\" is not supported for training.;"); + } + + public void testUpdateVectorDataTypeToParameters_whenVectorDataTypeIsBinary() { + Map indexParams = new HashMap<>(); + IndexUtil.updateVectorDataTypeToParameters(indexParams, VectorDataType.BINARY); + assertEquals(VectorDataType.BINARY.getValue(), indexParams.get(VECTOR_DATA_TYPE_FIELD)); + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index 11a8bdb15..e9b78e7ec 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -63,7 +63,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException "", "", "test-node", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Model model = new Model(modelMetadata, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index 847cad04e..7fdd7df3a 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -125,8 +125,21 @@ public void testAddBinaryField_withoutKNN() throws IOException { DocValuesConsumer delegate = mock(DocValuesConsumer.class); doNothing().when(delegate).addBinaryField(fieldInfo, docValuesProducer); + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + + SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() + .directory(directory) + .segmentName(segmentName) + .docsInSegment(docsInSegment) + .codec(codec) + .build(); + + FieldInfos fieldInfos = mock(FieldInfos.class); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + final boolean[] called = { false }; - KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, state) { @Override public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) { @@ -148,7 +161,19 @@ public void testAddKNNBinaryField_noVectors() throws IOException { Long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); Long initialMergeSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue(); Long initialMergeDocs = KNNGraphValue.MERGE_TOTAL_DOCS.getValue(); - KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, null); + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + + SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() + .directory(directory) + .segmentName(segmentName) + .docsInSegment(docsInSegment) + .codec(codec) + .build(); + + FieldInfos fieldInfos = mock(FieldInfos.class); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").build(); knn80DocValuesConsumer.addKNNBinaryField(fieldInfo, randomVectorDocValuesProducer, true, true); assertEquals(initialGraphIndexRequests, KNNCounter.GRAPH_INDEX_REQUESTS.getCount()); @@ -424,7 +449,8 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio "Empty description", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBytes, modelId diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index b82bc85e0..66fe9770d 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -20,15 +20,16 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.query.KNNQueryFactory; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.query.KNNQuery; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.KNNWeight; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.VectorField; import org.apache.lucene.codecs.Codec; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; @@ -213,7 +214,8 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Model mockModel = new Model(modelMetadata1, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index c3ddcf185..0e3fc6fb4 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -175,7 +175,8 @@ public void testBuilder_build_fromModel() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); builder.modelId.setValue(modelId); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); @@ -676,7 +677,8 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); when(mockModelDao.getMetadata(modelId)).thenReturn(mockModelMetadata); @@ -747,7 +749,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.FLOAT ); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnVectorField @@ -791,7 +794,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.FLOAT ); // Document should have 1 field: one for KnnVectorField @@ -826,7 +830,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.BYTE ); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnByteVectorField @@ -869,7 +874,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.BYTE ); // Document should have 1 field: one for KnnByteVectorField diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index ab73c3946..a8f93787c 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -315,7 +315,8 @@ public void testTrainingDataAllocation_close() throws InterruptedException { NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( executorService, memoryAddress, - 0 + 0, + VectorDataType.FLOAT ); trainingDataAllocation.close(); @@ -341,7 +342,8 @@ public void testTrainingDataAllocation_getMemoryAddress() { NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( null, memoryAddress, - 0 + 0, + VectorDataType.FLOAT ); assertEquals(memoryAddress, trainingDataAllocation.getMemoryAddress()); @@ -354,7 +356,8 @@ public void testTrainingDataAllocation_readLock() throws InterruptedException { NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( null, 0, - 0 + 0, + VectorDataType.FLOAT ); int initialValue = 10; @@ -387,7 +390,8 @@ public void testTrainingDataAllocation_writeLock() throws InterruptedException { NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( null, 0, - 0 + 0, + VectorDataType.FLOAT ); int initialValue = 10; @@ -422,7 +426,8 @@ public void testTrainingDataAllocation_getSize() { NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( null, 0, - size + size, + VectorDataType.FLOAT ); assertEquals(size, trainingDataAllocation.getSizeInKB()); @@ -434,7 +439,8 @@ public void testTrainingDataAllocation_setMemoryAddress() { NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( null, pointer, - 0 + 0, + VectorDataType.FLOAT ); assertEquals(pointer, trainingDataAllocation.getMemoryAddress()); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java index 718df0b1f..85eaf3322 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java @@ -16,6 +16,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; @@ -186,7 +187,8 @@ public void testGetTrainingSize() throws ExecutionException { NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( null, 0, - allocationEntryWeight + allocationEntryWeight, + VectorDataType.FLOAT ); NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock( diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java index 495f20347..f87a069a2 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java @@ -15,6 +15,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.BufferedOutputStream; @@ -122,13 +123,15 @@ public void testTrainingDataEntryContext_load() { trainingLoadStrategy, null, 0, - 0 + 0, + VectorDataType.DEFAULT ); NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( null, 0, - 0 + 0, + VectorDataType.DEFAULT ); when(trainingLoadStrategy.load(trainingDataEntryContext)).thenReturn(trainingDataAllocation); @@ -145,7 +148,8 @@ public void testTrainingDataEntryContext_getTrainIndexName() { null, null, 0, - 0 + 0, + VectorDataType.DEFAULT ); assertEquals(trainIndexName, trainingDataEntryContext.getTrainIndexName()); @@ -160,7 +164,8 @@ public void testTrainingDataEntryContext_getTrainFieldName() { null, null, 0, - 0 + 0, + VectorDataType.DEFAULT ); assertEquals(trainFieldName, trainingDataEntryContext.getTrainFieldName()); @@ -175,7 +180,8 @@ public void testTrainingDataEntryContext_getMaxVectorCount() { null, null, maxVectorCount, - 0 + 0, + VectorDataType.DEFAULT ); assertEquals(maxVectorCount, trainingDataEntryContext.getMaxVectorCount()); @@ -190,7 +196,8 @@ public void testTrainingDataEntryContext_getSearchSize() { null, null, 0, - searchSize + searchSize, + VectorDataType.DEFAULT ); assertEquals(searchSize, trainingDataEntryContext.getSearchSize()); @@ -205,7 +212,8 @@ public void testTrainingDataEntryContext_getIndicesService() { null, clusterService, 0, - 0 + 0, + VectorDataType.DEFAULT ); assertEquals(clusterService, trainingDataEntryContext.getClusterService()); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 43ad7e968..b277629e6 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -22,7 +22,7 @@ import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.training.TrainingDataConsumer; +import org.opensearch.knn.training.FloatTrainingDataConsumer; import org.opensearch.knn.training.VectorReader; import org.opensearch.watcher.ResourceWatcherService; @@ -150,12 +150,12 @@ public void testTrainingLoadStrategy_load() { logger.info("J0"); doAnswer(invocationOnMock -> { logger.info("J1"); - TrainingDataConsumer trainingDataConsumer = (TrainingDataConsumer) invocationOnMock.getArguments()[5]; + FloatTrainingDataConsumer floatTrainingDataConsumer = (FloatTrainingDataConsumer) invocationOnMock.getArguments()[5]; ActionListener listener = (ActionListener) invocationOnMock.getArguments()[6]; Thread thread = new Thread(() -> { try { Thread.sleep(2000); - trainingDataConsumer.accept(vectors); // Transfer some floats + floatTrainingDataConsumer.accept(vectors); // Transfer some floats listener.onResponse(null); } catch (InterruptedException e) { listener.onFailure(null); @@ -176,7 +176,8 @@ public void testTrainingLoadStrategy_load() { NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(), null, 0, - 0 + 0, + VectorDataType.FLOAT ); // Load the allocation. Initially, the memory address should be 0. However, after the readlock is obtained, diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 06e370026..738ee890d 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -903,6 +903,7 @@ public void testDoToQuery_FromModel() { when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.DEFAULT); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -940,6 +941,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.DEFAULT); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -975,6 +977,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.DEFAULT); when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 9769bcc15..a7ee6993b 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -38,6 +38,7 @@ import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNNCodecVersion; @@ -62,6 +63,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static java.util.Collections.emptyMap; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; @@ -199,6 +201,8 @@ public void testQueryScoreForFaissWithModel() { when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(spaceType); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.DEFAULT); + when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata); KNNWeight.initialize(modelDao); diff --git a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java index 3a5255cd3..e0111204d 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.time.ZoneOffset; @@ -45,7 +46,8 @@ public void testGet_normal() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), "hello".getBytes(), modelId @@ -82,7 +84,8 @@ public void testGet_modelDoesNotFitInCache() throws ExecutionException, Interrup "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[BYTES_PER_KILOBYTES + 1], modelId @@ -140,7 +143,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[size1], modelId1 @@ -156,7 +160,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[size2], modelId2 @@ -200,7 +205,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[size1], modelId1 @@ -216,8 +222,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY - + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[size2], modelId2 @@ -266,7 +272,8 @@ public void testRebuild_normal() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), "hello".getBytes(), modelId @@ -312,7 +319,8 @@ public void testRebuild_afterSettingUpdate() throws ExecutionException, Interrup "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[modelSize], modelId @@ -381,7 +389,8 @@ public void testContains() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[modelSize1], modelId1 @@ -423,7 +432,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[modelSize1], modelId1 @@ -441,7 +451,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[modelSize2], modelId2 @@ -487,7 +498,8 @@ public void testModelCacheEvictionDueToSize() throws ExecutionException, Interru "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[BYTES_PER_KILOBYTES * 2], modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 75c523332..b18a7259e 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -35,6 +35,7 @@ import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; @@ -139,7 +140,8 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, modelId @@ -159,7 +161,8 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, modelId @@ -187,7 +190,8 @@ public void testPut_withId() throws InterruptedException, IOException { "", "", "", - new MethodComponentContext("test", Collections.emptyMap()) + new MethodComponentContext("test", Collections.emptyMap()), + VectorDataType.DEFAULT ), modelBlob, modelId @@ -248,7 +252,8 @@ public void testPut_withoutModel() throws InterruptedException, IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, modelId @@ -310,7 +315,8 @@ public void testPut_invalid_badState() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, "any-id" @@ -347,7 +353,8 @@ public void testUpdate() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), null, modelId @@ -386,7 +393,8 @@ public void testUpdate() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, modelId @@ -437,7 +445,8 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, modelId @@ -456,7 +465,8 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), null, modelId @@ -493,7 +503,8 @@ public void testGetMetadata() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); Model model = new Model(modelMetadata, modelBlob, modelId); @@ -570,7 +581,8 @@ public void testDelete() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, modelId @@ -604,7 +616,8 @@ public void testDelete() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, modelId1 @@ -672,7 +685,8 @@ public void testDeleteModelInTrainingWithStepListeners() throws IOException, Exe "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, modelId @@ -714,7 +728,8 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java index 74715671f..c23b7e2fd 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; @@ -45,7 +46,8 @@ public void testStreams() throws IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -67,7 +69,8 @@ public void testGetKnnEngine() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); assertEquals(knnEngine, modelMetadata.getKnnEngine()); @@ -84,7 +87,8 @@ public void testGetSpaceType() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); assertEquals(spaceType, modelMetadata.getSpaceType()); @@ -101,7 +105,8 @@ public void testGetDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); assertEquals(dimension, modelMetadata.getDimension()); @@ -118,7 +123,8 @@ public void testGetState() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); assertEquals(modelState, modelMetadata.getState()); @@ -135,7 +141,8 @@ public void testGetTimestamp() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); assertEquals(timeValue, modelMetadata.getTimestamp()); @@ -152,7 +159,8 @@ public void testDescription() { description, "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); assertEquals(description, modelMetadata.getDescription()); @@ -169,12 +177,31 @@ public void testGetError() { "", error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); assertEquals(error, modelMetadata.getError()); } + public void testGetVectorDataType() { + VectorDataType vectorDataType = VectorDataType.BINARY; + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 12, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "", + "", + MethodComponentContext.EMPTY, + vectorDataType + ); + + assertEquals(vectorDataType, modelMetadata.getVectorDataType()); + } + public void testSetState() { ModelState modelState = ModelState.FAILED; ModelMetadata modelMetadata = new ModelMetadata( @@ -186,7 +213,8 @@ public void testSetState() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); assertEquals(modelState, modelMetadata.getState()); @@ -207,7 +235,8 @@ public void testSetError() { "", error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); assertEquals(error, modelMetadata.getError()); @@ -244,7 +273,9 @@ public void testToString() { + "," + nodeAssignment + "," - + methodComponentContext.toClusterStateString(); + + methodComponentContext.toClusterStateString() + + "," + + VectorDataType.DEFAULT.getValue(); ModelMetadata modelMetadata = new ModelMetadata( knnEngine, @@ -255,7 +286,8 @@ public void testToString() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); assertEquals(expected, modelMetadata.toString()); @@ -275,7 +307,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata2 = new ModelMetadata( KNNEngine.FAISS, @@ -286,7 +319,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata3 = new ModelMetadata( @@ -298,7 +332,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata4 = new ModelMetadata( KNNEngine.FAISS, @@ -309,7 +344,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata5 = new ModelMetadata( KNNEngine.FAISS, @@ -320,7 +356,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata6 = new ModelMetadata( KNNEngine.FAISS, @@ -331,7 +368,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata7 = new ModelMetadata( KNNEngine.FAISS, @@ -342,7 +380,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, @@ -353,7 +392,8 @@ public void testEquals() { "diff descript", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -364,7 +404,8 @@ public void testEquals() { "", "diff error", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata10 = new ModelMetadata( @@ -376,7 +417,8 @@ public void testEquals() { "", "", "", - new MethodComponentContext("test", Collections.emptyMap()) + new MethodComponentContext("test", Collections.emptyMap()), + VectorDataType.DEFAULT ); assertEquals(modelMetadata1, modelMetadata1); @@ -406,7 +448,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata2 = new ModelMetadata( KNNEngine.FAISS, @@ -417,7 +460,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata3 = new ModelMetadata( @@ -429,7 +473,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata4 = new ModelMetadata( KNNEngine.FAISS, @@ -440,7 +485,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata5 = new ModelMetadata( KNNEngine.FAISS, @@ -451,7 +497,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata6 = new ModelMetadata( KNNEngine.FAISS, @@ -462,7 +509,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata7 = new ModelMetadata( KNNEngine.FAISS, @@ -473,7 +521,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, @@ -484,7 +533,8 @@ public void testHashCode() { "diff descript", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -495,7 +545,8 @@ public void testHashCode() { "", "diff error", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata modelMetadata10 = new ModelMetadata( @@ -507,7 +558,8 @@ public void testHashCode() { "", "", "", - new MethodComponentContext("test", Collections.emptyMap()) + new MethodComponentContext("test", Collections.emptyMap()), + VectorDataType.DEFAULT ); assertEquals(modelMetadata1.hashCode(), modelMetadata1.hashCode()); @@ -550,7 +602,9 @@ public void testFromString() { + "," + nodeAssignment + "," - + methodComponentContext.toClusterStateString(); + + methodComponentContext.toClusterStateString() + + "," + + VectorDataType.DEFAULT.getValue(); String stringRep2 = knnEngine.getName() + "," @@ -564,7 +618,9 @@ public void testFromString() { + "," + description + "," - + error; + + error + + "," + + VectorDataType.DEFAULT.getValue(); ModelMetadata expected1 = new ModelMetadata( knnEngine, @@ -575,7 +631,8 @@ public void testFromString() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata expected2 = new ModelMetadata( @@ -587,7 +644,8 @@ public void testFromString() { description, error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); ModelMetadata fromString1 = ModelMetadata.fromString(stringRep1); @@ -620,8 +678,8 @@ public void testFromResponseMap() throws IOException { description, error, nodeAssignment, - methodComponentContext - + methodComponentContext, + VectorDataType.DEFAULT ); ModelMetadata expected2 = new ModelMetadata( knnEngine, @@ -632,7 +690,8 @@ public void testFromResponseMap() throws IOException { description, error, "", - emptyMethodComponentContext + emptyMethodComponentContext, + VectorDataType.DEFAULT ); Map metadataAsMap = new HashMap<>(); metadataAsMap.put(KNNConstants.KNN_ENGINE, knnEngine.getName()); @@ -643,6 +702,7 @@ public void testFromResponseMap() throws IOException { metadataAsMap.put(KNNConstants.MODEL_DESCRIPTION, description); metadataAsMap.put(KNNConstants.MODEL_ERROR, error); metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, nodeAssignment); + metadataAsMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); builder = methodComponentContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); @@ -678,7 +738,8 @@ public void testBlockCommasInDescription() { description, error, nodeAssignment, - methodComponentContext + methodComponentContext, + VectorDataType.DEFAULT ) ); assertEquals("Model description cannot contain any commas: ','", e.getMessage()); diff --git a/src/test/java/org/opensearch/knn/indices/ModelTests.java b/src/test/java/org/opensearch/knn/indices/ModelTests.java index 13579acad..59bfe035f 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.time.ZoneOffset; @@ -41,7 +42,8 @@ public void testInvalidConstructor() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), null, "test-model" @@ -62,7 +64,8 @@ public void testInvalidDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[16], "test-model" @@ -80,7 +83,8 @@ public void testInvalidDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[16], "test-model" @@ -98,7 +102,8 @@ public void testInvalidDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[16], "test-model" @@ -117,7 +122,8 @@ public void testGetModelMetadata() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); Model model = new Model(modelMetadata, new byte[16], "test-model"); assertEquals(modelMetadata, model.getModelMetadata()); @@ -135,7 +141,8 @@ public void testGetModelBlob() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, "test-model" @@ -155,7 +162,8 @@ public void testGetLength() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[size], "test-model" @@ -172,7 +180,8 @@ public void testGetLength() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), null, "test-model" @@ -192,7 +201,8 @@ public void testSetModelBlob() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), blob1, "test-model" @@ -209,17 +219,50 @@ public void testEquals() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); Model model1 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT + ), new byte[16], "test-model-1" ); Model model2 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT + ), new byte[16], "test-model-1" ); Model model3 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT + ), new byte[16], "test-model-2" ); @@ -234,17 +277,50 @@ public void testHashCode() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); Model model1 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT + ), new byte[16], "test-model-1" ); Model model2 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT + ), new byte[16], "test-model-1" ); Model model3 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT + ), new byte[16], "test-model-2" ); @@ -274,7 +350,8 @@ public void testModelFromSourceMap() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); Map modelAsMap = new HashMap<>(); modelAsMap.put(KNNConstants.MODEL_ID, modelID); @@ -287,6 +364,7 @@ public void testModelFromSourceMap() { modelAsMap.put(KNNConstants.MODEL_ERROR, error); modelAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, nodeAssignment); modelAsMap.put(KNNConstants.MODEL_BLOB_PARAMETER, "aGVsbG8="); + modelAsMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()); byte[] blob1 = "hello".getBytes(); Model expected = new Model(metadata, blob1, modelID); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index d9949aaf2..5d16fe59d 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -36,7 +36,20 @@ import static org.opensearch.knn.TestUtils.KNN_VECTOR; import static org.opensearch.knn.TestUtils.PROPERTIES; import static org.opensearch.knn.TestUtils.VECTOR_TYPE; -import static org.opensearch.knn.common.KNNConstants.*; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.NAME; /** * Integration tests to check the correctness of RestKNNStatsHandler diff --git a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java index a6985e72a..d0dcf7429 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java @@ -20,6 +20,7 @@ import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelMetadata; @@ -43,7 +44,8 @@ private ModelMetadata getModelMetadata(ModelState state) { "test model", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); } @@ -68,7 +70,7 @@ public void testXContent() throws IOException { Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}}}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}},\"data_type\":\"float\"}"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); @@ -84,7 +86,7 @@ public void testXContentWithNoModelBlob() throws IOException { Model model = new Model(getModelMetadata(ModelState.FAILED), null, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}}}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}},\"data_type\":\"float\"}"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java index a2da83dad..381831fc7 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; @@ -78,7 +79,8 @@ public void testNodeOperation_modelInCache() throws ExecutionException, Interrup "description", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), new byte[128], modelId diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 56c50aca1..63c770a26 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -24,6 +24,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.transport.TransportService; @@ -307,7 +308,8 @@ public void testTrainingIndexSize() { trainingIndexName, "training-field", null, - "description" + "description", + VectorDataType.DEFAULT ); // Mock client to return the right number of docs @@ -339,6 +341,102 @@ public void testTrainingIndexSize() { transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener); } + public void testTrainIndexSize_whenDataTypeIsBinary() { + String trainingIndexName = "training-index"; + int dimension = 8; + int vectorCount = 1000000; + int expectedSize = Byte.BYTES * (dimension / 8) * vectorCount / BYTES_PER_KILOBYTES + 1; // 977 KB + + // Setup the request + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( + null, + KNNMethodContext.getDefault(), + dimension, + trainingIndexName, + "training-field", + null, + "description", + VectorDataType.BINARY + ); + + // Mock client to return the right number of docs + TotalHits totalHits = new TotalHits(vectorCount, TotalHits.Relation.EQUAL_TO); + SearchHits searchHits = new SearchHits(new SearchHit[2], totalHits, 1.0f); + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(searchHits); + Client client = mock(Client.class); + + doAnswer(invocationOnMock -> { + ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + // Setup the action + ClusterService clusterService = mock(ClusterService.class); + TransportService transportService = mock(TransportService.class); + TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); + + ActionListener listener = ActionListener.wrap( + size -> assertEquals(expectedSize, size.intValue()), + e -> fail(e.getMessage()) + ); + + transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener); + } + + public void testTrainIndexSize_whenDataTypeIsByte() { + String trainingIndexName = "training-index"; + int dimension = 8; + int vectorCount = 1000000; + int expectedSize = Byte.BYTES * dimension * vectorCount / BYTES_PER_KILOBYTES + 1; // 7813 KB + + // Setup the request + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( + null, + KNNMethodContext.getDefault(), + dimension, + trainingIndexName, + "training-field", + null, + "description", + VectorDataType.BYTE + ); + + // Mock client to return the right number of docs + TotalHits totalHits = new TotalHits(vectorCount, TotalHits.Relation.EQUAL_TO); + SearchHits searchHits = new SearchHits(new SearchHit[2], totalHits, 1.0f); + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(searchHits); + Client client = mock(Client.class); + + doAnswer(invocationOnMock -> { + ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + // Setup the action + ClusterService clusterService = mock(ClusterService.class); + TransportService transportService = mock(TransportService.class); + TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); + + ActionListener listener = ActionListener.wrap( + size -> assertEquals(expectedSize, size.intValue()), + e -> fail(e.getMessage()) + ); + + transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener); + } + private Map generateDiscoveryNodes(List dataNodeIds) { Map nodes = new HashMap<>(); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index b39c48635..9434a6e41 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -25,6 +25,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -61,7 +62,8 @@ public void testStreams() throws IOException { trainingIndex, trainingField, preferredNode, - description + description, + VectorDataType.DEFAULT ); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -74,6 +76,7 @@ public void testStreams() throws IOException { assertEquals(original1.getTrainingIndex(), copy1.getTrainingIndex()); assertEquals(original1.getTrainingField(), copy1.getTrainingField()); assertEquals(original1.getPreferredNodeId(), copy1.getPreferredNodeId()); + assertEquals(original1.getVectorDataType(), copy1.getVectorDataType()); // Also, check when preferred node and model id and description are null TrainingModelRequest original2 = new TrainingModelRequest( @@ -83,7 +86,8 @@ public void testStreams() throws IOException { trainingIndex, trainingField, null, - null + null, + VectorDataType.DEFAULT ); streamOutput = new BytesStreamOutput(); @@ -96,6 +100,7 @@ public void testStreams() throws IOException { assertEquals(original2.getTrainingIndex(), copy2.getTrainingIndex()); assertEquals(original2.getTrainingField(), copy2.getTrainingField()); assertEquals(original2.getPreferredNodeId(), copy2.getPreferredNodeId()); + assertEquals(original2.getVectorDataType(), copy2.getVectorDataType()); } public void testGetters() { @@ -117,7 +122,8 @@ public void testGetters() { trainingIndex, trainingField, preferredNode, - description + description, + VectorDataType.DEFAULT ); trainingModelRequest.setMaximumVectorCount(maxVectorCount); @@ -156,7 +162,8 @@ public void testValidation_invalid_modelIdAlreadyExists() { trainingIndex, trainingField, null, - null + null, + VectorDataType.DEFAULT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -170,7 +177,8 @@ public void testValidation_invalid_modelIdAlreadyExists() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); @@ -211,7 +219,8 @@ public void testValidation_blocked_modelId() { trainingIndex, trainingField, null, - null + null, + VectorDataType.DEFAULT ); // Mock the model dao to return true to recognize that the modelId is in graveyard @@ -257,7 +266,8 @@ public void testValidation_invalid_invalidMethodContext() { trainingIndex, trainingField, null, - null + null, + VectorDataType.DEFAULT ); // Mock the model dao to return null so that no exception is produced @@ -300,7 +310,8 @@ public void testValidation_invalid_trainingIndexDoesNotExist() { trainingIndex, trainingField, null, - null + null, + VectorDataType.DEFAULT ); // Mock the model dao to return null so that no exception is produced @@ -346,7 +357,8 @@ public void testValidation_invalid_trainingFieldDoesNotExist() { trainingIndex, trainingField, null, - null + null, + VectorDataType.DEFAULT ); // Mock the model dao to return null so that no exception is produced @@ -397,7 +409,8 @@ public void testValidation_invalid_trainingFieldNotKnnVector() { trainingIndex, trainingField, null, - null + null, + VectorDataType.DEFAULT ); // Mock the model dao to return null so that no exception is produced @@ -452,7 +465,8 @@ public void testValidation_invalid_dimensionDoesNotMatch() { trainingIndex, trainingField, null, - null + null, + VectorDataType.DEFAULT ); // Mock the model dao to return null so that no exception is produced @@ -509,7 +523,8 @@ public void testValidation_invalid_preferredNodeDoesNotExist() { trainingIndex, trainingField, preferredNode, - null + null, + VectorDataType.DEFAULT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -574,7 +589,8 @@ public void testValidation_invalid_descriptionToLong() { trainingIndex, trainingField, null, - description + description, + VectorDataType.DEFAULT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -618,7 +634,8 @@ public void testValidation_valid_trainingIndexBuiltFromMethod() { trainingIndex, trainingField, null, - null + null, + VectorDataType.DEFAULT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -655,7 +672,8 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { trainingIndex, trainingField, null, - null + null, + VectorDataType.DEFAULT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java index 950ce1fd0..9ca790350 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java @@ -17,6 +17,7 @@ import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; @@ -72,9 +73,10 @@ public void testDoExecute() throws InterruptedException, ExecutionException, IOE trainingIndexName, trainingFieldName, null, - "test-detector" + "test-detector", + VectorDataType.DEFAULT ); - trainingModelRequest.setTrainingDataSizeInKB(estimateVectorSetSizeInKB(trainingDataCount, dimension)); + trainingModelRequest.setTrainingDataSizeInKB(estimateVectorSetSizeInKB(trainingDataCount, dimension, VectorDataType.DEFAULT)); // Create listener that ensures that the initial model put succeeds ActionListener listener = ActionListener.wrap( diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java index 5be907ebd..bad8d368b 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java @@ -17,6 +17,7 @@ import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelGraveyard; @@ -210,7 +211,8 @@ public void testClusterManagerOperation_GetIndicesUsingModel() throws IOExceptio "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java index 3719d124a..d2291c4ea 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -42,7 +43,8 @@ public void testStreams() throws IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest(modelId, isRemoveRequest, modelMetadata); @@ -67,7 +69,8 @@ public void testValidate() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); UpdateModelMetadataRequest updateModelMetadataRequest1 = new UpdateModelMetadataRequest("test", true, null); @@ -107,7 +110,8 @@ public void testGetModelMetadata() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest("test", true, modelMetadata); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java index ab0e4f506..c35c7effb 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -68,7 +69,8 @@ public void testClusterManagerOperation() throws InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ); // Get update transport action diff --git a/src/test/java/org/opensearch/knn/training/TrainingDataConsumerTests.java b/src/test/java/org/opensearch/knn/training/FloatTrainingDataConsumerTests.java similarity index 88% rename from src/test/java/org/opensearch/knn/training/TrainingDataConsumerTests.java rename to src/test/java/org/opensearch/knn/training/FloatTrainingDataConsumerTests.java index d5a66c5b6..27e02b46b 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingDataConsumerTests.java +++ b/src/test/java/org/opensearch/knn/training/FloatTrainingDataConsumerTests.java @@ -23,7 +23,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -public class TrainingDataConsumerTests extends KNNTestCase { +public class FloatTrainingDataConsumerTests extends KNNTestCase { public void testAccept() { @@ -38,7 +38,7 @@ public void testAccept() { // Capture argument passed to set pointer ArgumentCaptor valueCapture = ArgumentCaptor.forClass(Long.class); - TrainingDataConsumer trainingDataConsumer = new TrainingDataConsumer(trainingDataAllocation); + FloatTrainingDataConsumer floatTrainingDataConsumer = new FloatTrainingDataConsumer(trainingDataAllocation); List vectorSet1 = new ArrayList<>(3); for (int i = 0; i < 3; i++) { @@ -47,10 +47,8 @@ public void testAccept() { vectorSet1.add(vector); } - when(trainingDataAllocation.getMemoryAddress()).thenReturn(0L); - // Transfer vectors - trainingDataConsumer.accept(vectorSet1); + floatTrainingDataConsumer.accept(vectorSet1); // Ensure that the pointer captured has been updated verify(trainingDataAllocation).setMemoryAddress(valueCapture.capture()); diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index 06b96c57c..0852c39de 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -18,6 +18,7 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; @@ -67,7 +68,8 @@ public void testGetModelId() { mock(NativeMemoryEntryContext.AnonymousEntryContext.class), 10, "", - "test-node" + "test-node", + VectorDataType.DEFAULT ); assertEquals(modelId, trainingJob.getModelId()); @@ -96,7 +98,8 @@ public void testGetModel() { mock(NativeMemoryEntryContext.AnonymousEntryContext.class), dimension, description, - nodeAssignment + nodeAssignment, + VectorDataType.DEFAULT ); Model model = new Model( @@ -109,7 +112,8 @@ public void testGetModel() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT ), null, modelID @@ -183,8 +187,8 @@ public void testRun_success() throws IOException, ExecutionException { modelContext, dimension, "", - "test-node" - + "test-node", + VectorDataType.DEFAULT ); trainingJob.run(); @@ -262,8 +266,8 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept modelContext, dimension, "", - - "test-node" + "test-node", + VectorDataType.DEFAULT ); trainingJob.run(); @@ -330,8 +334,8 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce modelContext, dimension, "", - - "test-node" + "test-node", + VectorDataType.DEFAULT ); trainingJob.run(); @@ -397,7 +401,8 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep mock(NativeMemoryEntryContext.AnonymousEntryContext.class), dimension, "", - "test-node" + "test-node", + VectorDataType.DEFAULT ); trainingJob.run(); @@ -470,7 +475,8 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { modelContext, dimension, "", - "test-node" + "test-node", + VectorDataType.DEFAULT ); trainingJob.run(); diff --git a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java index 209c9cc73..b69b43a39 100644 --- a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java +++ b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java @@ -8,25 +8,23 @@ * Modifications Copyright OpenSearch Contributors. See * GitHub history for details. */ - package org.opensearch.knn.training; +import lombok.Getter; import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.ValidationException; import org.opensearch.knn.KNNSingleNodeTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.common.ValidationException; +import org.opensearch.search.SearchHit; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Random; +import java.util.*; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; import java.util.stream.Collectors; public class VectorReaderTests extends KNNSingleNodeTestCase { @@ -36,9 +34,9 @@ public class VectorReaderTests extends KNNSingleNodeTestCase { private final static String DEFAULT_FIELD_NAME = "test-field"; private final static String DEFAULT_NESTED_FIELD_PATH = "a.b.test-field"; private final static int DEFAULT_DIMENSION = 16; - private final static int DEFAULT_NUM_VECTORS = 100; + private final static int DEFAULT_NUM_VECTORS = 50; private final static int DEFAULT_MAX_VECTOR_COUNT = 10000; - private final static int DEFAULT_SEARCH_SIZE = 10; + private final static int DEFAULT_SEARCH_SIZE = 120; public void testRead_valid_completeIndex() throws InterruptedException, ExecutionException, IOException { createIndex(DEFAULT_INDEX_NAME); @@ -56,9 +54,9 @@ public void testRead_valid_completeIndex() throws InterruptedException, Executio // Configure VectorReader ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); + TestFloatTrainingDataConsumer trainingDataConsumer = new TestFloatTrainingDataConsumer(createMockTrainingDataAllocation()); // Read all vectors and confirm they match vectors - TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); final CountDownLatch inProgressLatch = new CountDownLatch(1); vectorReader.read( clusterService, @@ -66,14 +64,14 @@ public void testRead_valid_completeIndex() throws InterruptedException, Executio DEFAULT_FIELD_NAME, DEFAULT_MAX_VECTOR_COUNT, DEFAULT_SEARCH_SIZE, - testVectorConsumer, + trainingDataConsumer, createOnSearchResponseCountDownListener(inProgressLatch) ); assertLatchDecremented(inProgressLatch); - List consumedVectors = testVectorConsumer.getVectorsConsumed(); - assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); + List consumedVectors = trainingDataConsumer.getTotalAddedVectors(); + assertEquals(DEFAULT_NUM_VECTORS, trainingDataConsumer.getTotalVectorsCountAdded()); List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); List flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); @@ -98,21 +96,21 @@ public void testRead_valid_trainVectorsIngestedAsIntegers() throws IOException, VectorReader vectorReader = new VectorReader(client()); // Read all vectors and confirm they match vectors - TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); final CountDownLatch inProgressLatch = new CountDownLatch(1); + TestFloatTrainingDataConsumer trainingDataConsumer = new TestFloatTrainingDataConsumer(createMockTrainingDataAllocation()); vectorReader.read( clusterService, DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_MAX_VECTOR_COUNT, DEFAULT_SEARCH_SIZE, - testVectorConsumer, + trainingDataConsumer, createOnSearchResponseCountDownListener(inProgressLatch) ); assertLatchDecremented(inProgressLatch); - List consumedVectors = testVectorConsumer.getVectorsConsumed(); + List consumedVectors = trainingDataConsumer.getTotalAddedVectors(); assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); List flatVectors = vectors.stream().flatMap(Arrays::stream).map(Integer::floatValue).collect(Collectors.toList()); @@ -149,21 +147,22 @@ public void testRead_valid_incompleteIndex() throws InterruptedException, Execut VectorReader vectorReader = new VectorReader(client()); // Read all vectors and confirm they match vectors - TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); final CountDownLatch inProgressLatch = new CountDownLatch(1); + TestFloatTrainingDataConsumer trainingDataConsumer = new TestFloatTrainingDataConsumer(createMockTrainingDataAllocation()); + vectorReader.read( clusterService, DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, DEFAULT_MAX_VECTOR_COUNT, DEFAULT_SEARCH_SIZE, - testVectorConsumer, + trainingDataConsumer, createOnSearchResponseCountDownListener(inProgressLatch) ); assertLatchDecremented(inProgressLatch); - List consumedVectors = testVectorConsumer.getVectorsConsumed(); + List consumedVectors = trainingDataConsumer.getTotalAddedVectors(); assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); @@ -192,21 +191,21 @@ public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, Exec VectorReader vectorReader = new VectorReader(client()); // Read maxNumVectorsRead vectors - TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); final CountDownLatch inProgressLatch = new CountDownLatch(1); + TestFloatTrainingDataConsumer trainingDataConsumer = new TestFloatTrainingDataConsumer(createMockTrainingDataAllocation()); vectorReader.read( clusterService, DEFAULT_INDEX_NAME, DEFAULT_FIELD_NAME, maxNumVectorsRead, DEFAULT_SEARCH_SIZE, - testVectorConsumer, + trainingDataConsumer, createOnSearchResponseCountDownListener(inProgressLatch) ); assertLatchDecremented(inProgressLatch); - List consumedVectors = testVectorConsumer.getVectorsConsumed(); + List consumedVectors = trainingDataConsumer.getTotalAddedVectors(); assertEquals(maxNumVectorsRead, consumedVectors.size()); } @@ -364,21 +363,21 @@ public void testRead_valid_NestedField() throws InterruptedException, ExecutionE VectorReader vectorReader = new VectorReader(client()); // Read all vectors and confirm they match vectors - TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); final CountDownLatch inProgressLatch = new CountDownLatch(1); + TestFloatTrainingDataConsumer trainingDataConsumer = new TestFloatTrainingDataConsumer(createMockTrainingDataAllocation()); vectorReader.read( clusterService, DEFAULT_INDEX_NAME, DEFAULT_NESTED_FIELD_PATH, DEFAULT_MAX_VECTOR_COUNT, DEFAULT_SEARCH_SIZE, - testVectorConsumer, + trainingDataConsumer, createOnSearchResponseCountDownListener(inProgressLatch) ); assertLatchDecremented(inProgressLatch); - List consumedVectors = testVectorConsumer.getVectorsConsumed(); + List consumedVectors = trainingDataConsumer.getTotalAddedVectors(); assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); @@ -386,29 +385,47 @@ public void testRead_valid_NestedField() throws InterruptedException, ExecutionE assertEquals(new HashSet<>(flatVectors), new HashSet<>(flatConsumedVectors)); } - private static class TestVectorConsumer implements Consumer> { + private void assertLatchDecremented(CountDownLatch countDownLatch) throws InterruptedException { + assertTrue(countDownLatch.await(DEFAULT_LATCH_TIMEOUT, TimeUnit.SECONDS)); + } + + private ActionListener createOnSearchResponseCountDownListener(CountDownLatch countDownLatch) { + return ActionListener.wrap(response -> countDownLatch.countDown(), Throwable::printStackTrace); + } + + private NativeMemoryAllocation.TrainingDataAllocation createMockTrainingDataAllocation() { + return new NativeMemoryAllocation.TrainingDataAllocation(null, 0, 0, VectorDataType.FLOAT); + } - List vectorsConsumed; + // create test float training data consumer class extending FloatTrainingDataConsumer + private static class TestFloatTrainingDataConsumer extends FloatTrainingDataConsumer { + @Getter + private List totalAddedVectors = new ArrayList<>(); - TestVectorConsumer() { - vectorsConsumed = new ArrayList<>(); + public TestFloatTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) { + super(trainingDataAllocation); } @Override - public void accept(List vectors) { - vectorsConsumed.addAll(vectors); - } + public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) { + SearchHit[] hits = searchResponse.getHits().getHits(); + List vectors = new ArrayList<>(); - public List getVectorsConsumed() { - return vectorsConsumed; - } - } + String[] fieldPath = fieldName.split("\\."); - private void assertLatchDecremented(CountDownLatch countDownLatch) throws InterruptedException { - assertTrue(countDownLatch.await(DEFAULT_LATCH_TIMEOUT, TimeUnit.SECONDS)); - } + for (int vector = 0; vector < vectorsToAdd; vector++) { + Object fieldValue = extractFieldValue(hits[vector], fieldPath); + if (!(fieldValue instanceof List)) { + continue; + } - private ActionListener createOnSearchResponseCountDownListener(CountDownLatch countDownLatch) { - return ActionListener.wrap(response -> countDownLatch.countDown(), Throwable::printStackTrace); + List fieldList = (List) fieldValue; + vectors.add(fieldList.stream().map(Number::floatValue).toArray(Float[]::new)); + } + + totalAddedVectors.addAll(vectors); + setTotalVectorsCountAdded(getTotalVectorsCountAdded() + vectors.size()); + accept(vectors); + } } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 860cd2efa..1a17559a8 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -6,6 +6,7 @@ package org.opensearch.knn; import com.google.common.primitives.Floats; +import com.google.common.primitives.Ints; import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.StringUtils; @@ -1000,6 +1001,28 @@ public void bulkIngestRandomVectors(String indexName, String fieldName, int numV } + /** + * Bulk ingest random binary vectors + * @param indexName index name + * @param fieldName field name + * @param numVectors number of vectors + * @param dimension vector dimension + */ + public void bulkIngestRandomBinaryVectors(String indexName, String fieldName, int numVectors, int dimension) throws IOException { + if (dimension % 8 != 0) { + throw new IllegalArgumentException("Dimension must be a multiple of 8"); + } + for (int i = 0; i < numVectors; i++) { + int binaryDimension = dimension / 8; + int[] vector = new int[binaryDimension]; + for (int j = 0; j < binaryDimension; j++) { + vector[j] = randomIntBetween(-128, 127); + } + + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Ints.asList(vector).toArray()); + } + } + /** * Bulk ingest random vectors with nested field * @@ -1337,6 +1360,18 @@ public Response trainModel( return client().performRequest(request); } + public Response trainModel(String modelId, XContentBuilder builder) throws IOException { + if (modelId == null) { + modelId = ""; + } else { + modelId = "/" + modelId; + } + + Request request = new Request("POST", "/_plugins/_knn/models" + modelId + "/_train"); + request.setJsonEntity(builder.toString()); + return client().performRequest(request); + } + /** * Retrieve the model *