Skip to content

Commit

Permalink
Fix unit tests failures
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jul 9, 2024
1 parent 70cb23f commit 404056a
Show file tree
Hide file tree
Showing 23 changed files with 342 additions and 144 deletions.
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ public static ModelMetadata fromString(String modelMetadataString) {
String timestamp = modelMetadataArray[4];
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];
VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[9]);
VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[7]);
return new ModelMetadata(
knnEngine,
spaceType,
Expand All @@ -339,7 +339,7 @@ public static ModelMetadata fromString(String modelMetadataString) {
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];
String trainingNodeAssignment = modelMetadataArray[7];
VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[10]);
VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[8]);
return new ModelMetadata(
knnEngine,
spaceType,
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public static void createIndexFromTemplate(
KNNEngine knnEngine
) {
if (KNNEngine.FAISS == knnEngine) {
if (faissUtil.isBinaryIndex(parameters)) {
if (IndexUtil.isBinaryIndex(knnEngine, parameters)) {
FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters);
return;
} else {
Expand Down Expand Up @@ -313,7 +313,7 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE
*/
public static byte[] trainIndex(Map<String, Object> indexParameters, int dimension, long trainVectorsPointer, KNNEngine knnEngine) {
if (KNNEngine.FAISS == knnEngine) {
if (faissUtil.isBinaryIndex(indexParameters)) {
if (IndexUtil.isBinaryIndex(knnEngine, indexParameters)) {
return FaissService.trainBinaryIndex(indexParameters, dimension, trainVectorsPointer);
} else {
return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,5 +345,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeInt(this.maximumVectorCount);
out.writeInt(this.searchSize);
out.writeInt(this.trainingDataSizeInKB);
out.writeOptionalString(this.vectorDataType.getValue());
}
}
3 changes: 2 additions & 1 deletion src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ protected void writeModelToModelSystemIndex(Model model) throws IOException, Exe
.field(MODEL_STATE, modelMetadata.getState().getName())
.field(MODEL_TIMESTAMP, modelMetadata.getTimestamp().toString())
.field(MODEL_DESCRIPTION, modelMetadata.getDescription())
.field(MODEL_ERROR, modelMetadata.getError());
.field(MODEL_ERROR, modelMetadata.getError())
.field(VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType().getValue());

if (model.getModelBlob() != null) {
builder.field(MODEL_BLOB_PARAMETER, Base64.getEncoder().encodeToString(model.getModelBlob()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException
"",
"",
"test-node",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
);

Model model = new Model(modelMetadata, modelBlob, modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,8 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio
"Empty description",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
modelBytes,
modelId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,12 @@
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.*;
import org.opensearch.knn.index.query.KNNQueryFactory;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorField;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
Expand Down Expand Up @@ -213,7 +209,8 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
);

Model mockModel = new Model(modelMetadata1, modelBlob, modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ public void testBuilder_build_fromModel() {
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
);
builder.modelId.setValue(modelId);
Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath());
Expand Down Expand Up @@ -674,7 +675,8 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException {
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
);
when(mockModelDao.getMetadata(modelId)).thenReturn(mockModelMetadata);

Expand Down Expand Up @@ -745,7 +747,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() {
parseContext,
TEST_DIMENSION,
luceneFieldMapper.fieldType().spaceType,
luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext()
luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(),
VectorDataType.FLOAT
);

// Document should have 2 fields: one for VectorField (binary doc values) and one for KnnVectorField
Expand Down Expand Up @@ -789,7 +792,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() {
parseContext,
TEST_DIMENSION,
luceneFieldMapper.fieldType().spaceType,
luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext()
luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(),
VectorDataType.FLOAT
);

// Document should have 1 field: one for KnnVectorField
Expand Down Expand Up @@ -824,7 +828,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() {
parseContext,
TEST_DIMENSION,
luceneFieldMapper.fieldType().spaceType,
luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext()
luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(),
VectorDataType.BYTE
);

// Document should have 2 fields: one for VectorField (binary doc values) and one for KnnByteVectorField
Expand Down Expand Up @@ -867,7 +872,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() {
parseContext,
TEST_DIMENSION,
luceneFieldMapper.fieldType().spaceType,
luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext()
luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(),
VectorDataType.BYTE
);

// Document should have 1 field: one for KnnByteVectorField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,7 @@ public void testDoToQuery_FromModel() {
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap()));
when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down Expand Up @@ -940,6 +941,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap()));
when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down Expand Up @@ -975,6 +977,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap()));
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
Expand Down
11 changes: 5 additions & 6 deletions src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNNCodecVersion;
Expand All @@ -62,6 +63,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static java.util.Collections.emptyMap;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
Expand All @@ -75,12 +77,7 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.*;

public class KNNWeightTests extends KNNTestCase {
private static final String FIELD_NAME = "target_field";
Expand Down Expand Up @@ -199,6 +196,8 @@ public void testQueryScoreForFaissWithModel() {
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(spaceType);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap()));
when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata);

KNNWeight.initialize(modelDao);
Expand Down
38 changes: 25 additions & 13 deletions src/test/java/org/opensearch/knn/indices/ModelCacheTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.time.ZoneOffset;
Expand All @@ -45,7 +46,8 @@ public void testGet_normal() throws ExecutionException, InterruptedException {
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
"hello".getBytes(),
modelId
Expand Down Expand Up @@ -82,7 +84,8 @@ public void testGet_modelDoesNotFitInCache() throws ExecutionException, Interrup
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
new byte[BYTES_PER_KILOBYTES + 1],
modelId
Expand Down Expand Up @@ -140,7 +143,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
new byte[size1],
modelId1
Expand All @@ -156,7 +160,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
new byte[size2],
modelId2
Expand Down Expand Up @@ -200,7 +205,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
new byte[size1],
modelId1
Expand All @@ -216,8 +222,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException
"",
"",
"",
MethodComponentContext.EMPTY

MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
new byte[size2],
modelId2
Expand Down Expand Up @@ -266,7 +272,8 @@ public void testRebuild_normal() throws ExecutionException, InterruptedException
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
"hello".getBytes(),
modelId
Expand Down Expand Up @@ -312,7 +319,8 @@ public void testRebuild_afterSettingUpdate() throws ExecutionException, Interrup
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
new byte[modelSize],
modelId
Expand Down Expand Up @@ -381,7 +389,8 @@ public void testContains() throws ExecutionException, InterruptedException {
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
new byte[modelSize1],
modelId1
Expand Down Expand Up @@ -423,7 +432,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException {
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
new byte[modelSize1],
modelId1
Expand All @@ -441,7 +451,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException {
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
new byte[modelSize2],
modelId2
Expand Down Expand Up @@ -487,7 +498,8 @@ public void testModelCacheEvictionDueToSize() throws ExecutionException, Interru
"",
"",
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
),
new byte[BYTES_PER_KILOBYTES * 2],
modelId
Expand Down
Loading

0 comments on commit 404056a

Please sign in to comment.