Skip to content

Commit

Permalink
Fix update document with knnn_vector size not matching issue
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Jul 5, 2023
1 parent 7ee1eb3 commit 4af436a
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 39 deletions.
1 change: 0 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ dependencies {
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}"
compileOnly fileTree(dir: knnJarDirectory, include: '*.jar')
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.10'
// ml-common excluded reflection for runtime so we need to add it by ourselves.
// https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9
// TODO: Remove following three lines of dependencies if ml-common include them in their jar
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,9 @@

package org.opensearch.neuralsearch.processor;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.ActionListener;
import org.opensearch.env.Environment;
Expand All @@ -24,8 +16,14 @@
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;

/**
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
Expand Down Expand Up @@ -119,7 +117,7 @@ void appendVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Obj
Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
log.debug("Text embedding result fetched, starting build vector output!");
Map<String, Object> textEmbeddingResult = buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata());
textEmbeddingResult.forEach(ingestDocument::appendFieldValue);
textEmbeddingResult.forEach(ingestDocument::setFieldValue);
}

@SuppressWarnings({ "unchecked" })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,7 @@

package org.opensearch.neuralsearch.ml;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import com.google.common.collect.ImmutableMap;
import org.junit.Before;
import org.mockito.InjectMocks;
import org.mockito.Mock;
Expand All @@ -32,6 +24,15 @@
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.NodeNotConnectedException;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;

public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

@Mock
Expand Down Expand Up @@ -168,7 +169,9 @@ private ModelTensorOutput createModelTensorOutput(final Float[] output) {
output,
new long[] { 1, 2 },
MLResultDataType.FLOAT64,
ByteBuffer.wrap(new byte[12])
ByteBuffer.wrap(new byte[12]),
"mockResult",
ImmutableMap.of("mockKey", "mockValue")
);
mlModelTensorList.add(tensor);
final ModelTensors modelTensors = new ModelTensors(mlModelTensorList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,9 @@

package org.opensearch.neuralsearch.processor;

import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import lombok.SneakyThrows;

import org.junit.Before;
import org.mockito.InjectMocks;
import org.mockito.Mock;
Expand All @@ -33,8 +22,24 @@
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.test.OpenSearchTestCase;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.isA;
import static org.mockito.Mockito.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class TextEmbeddingProcessorTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -398,6 +403,20 @@ public void testBuildVectorOutput_withNestedMap_successful() {
assertNotNull(actionGamesKnn);
}

public void test_updateDocument_appendVectorFieldsToDocument_successful() {
Map<String, Object> config = createPlainStringConfiguration();
IngestDocument ingestDocument = createPlainIngestDocument();
TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
Map<String, Object> knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument);
List<List<Float>> modelTensorList = createMockVectorResult();
processor.appendVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList);

List<List<Float>> modelTensorList1 = createMockVectorResult();
processor.appendVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList1);
assertEquals(12, ingestDocument.getSourceAndMetadata().size());
assertEquals(2, ((List<?>)ingestDocument.getSourceAndMetadata().get("oriKey6_knn")).size());
}

private List<List<Float>> createMockVectorResult() {
List<List<Float>> modelTensorList = new ArrayList<>();
List<Float> number1 = ImmutableList.of(1.234f, 2.354f);
Expand Down

0 comments on commit 4af436a

Please sign in to comment.