Skip to content

Commit

Permalink
review feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Samuel Herman <[email protected]>
  • Loading branch information
samuel-oci committed Oct 21, 2023
1 parent 266a34b commit 0effd07
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ public void execute(
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
log.info("Entering normalization processor workflow");
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@

package org.opensearch.neuralsearch.processor.normalization;

import com.google.common.primitives.Floats;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

import lombok.ToString;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import com.google.common.primitives.Floats;

/**
* Implementation of z-score normalization technique for hybrid query
Expand All @@ -24,24 +27,26 @@
TODO: Some todo items that apply here but also on the original normalization techniques on which it is modeled {@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}
1. Random access to abstract list object is a bad practice both stylistically and from performance perspective and should be removed
2. Identical sub queries and their distribution between shards is currently completely implicit based on ordering and should be explicit based on identifier
3. Weird calculation of numOfSubQueries instead of having a more explicit indicator
3. Implicit calculation of numOfSubQueries instead of having a more explicit upstream indicator/metadata regarding it
*/
@ToString(onlyExplicitlyIncluded = true)
public class ZScoreNormalizationTechnique implements ScoreNormalizationTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "z_score";
private static final float SINGLE_RESULT_SCORE = 1.0f;

@Override
public void normalize(List<CompoundTopDocs> queryTopDocs) {
// why are we doing that? is List<CompoundTopDocs> the list of subqueries for a single shard? or a global list of all subqueries across shards?
// If a subquery comes from each shard then when is it combined? that seems weird that combination will do combination of normalized results that each is normalized just based on shard level result
int numOfSubQueries = queryTopDocs.stream()
.filter(Objects::nonNull)
.filter(topDocs -> topDocs.getTopDocs().size() > 0)
.findAny()
.get()
.getTopDocs()
.size();
public void normalize(final List<CompoundTopDocs> queryTopDocs) {
/*
TODO: There is an implicit assumption in this calculation that probably need to be made clearer by passing some metadata with the results.
Currently assuming that finding a single non empty shard result will contain all sub query results with 0 hits.
*/
final Optional<CompoundTopDocs> maybeCompoundTopDocs = queryTopDocs.stream()
.filter(Objects::nonNull)
.filter(topDocs -> topDocs.getTopDocs().size() > 0)
.findAny();

final int numOfSubQueries = maybeCompoundTopDocs.map(compoundTopDocs -> compoundTopDocs.getTopDocs().size()).orElse(0);

// to be done for each subquery
float[] sumPerSubquery = findScoreSumPerSubQuery(queryTopDocs, numOfSubQueries);
Expand All @@ -67,9 +72,7 @@ public void normalize(List<CompoundTopDocs> queryTopDocs) {
static private float[] findScoreSumPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
final float[] sumOfScorePerSubQuery = new float[numOfScores];
Arrays.fill(sumOfScorePerSubQuery, 0);
//TODO: make this better, currently
// this is a horrible implementation in particular when it comes to the topDocsPerSubQuery.get(j)
// which does a random search on an abstract list type.
// TODO: make this syntactically clearer regarding performance by avoiding List.get(j) with an abstract List type
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
Expand All @@ -86,9 +89,7 @@ static private float[] findScoreSumPerSubQuery(final List<CompoundTopDocs> query
static private long[] findNumberOfElementsPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
final long[] numberOfElementsPerSubQuery = new long[numOfScores];
Arrays.fill(numberOfElementsPerSubQuery, 0);
//TODO: make this better, currently
// this is a horrible implementation in particular when it comes to the topDocsPerSubQuery.get(j)
// which does a random search on an abstract list type.
// TODO: make this syntactically clearer regarding performance by avoiding List.get(j) with an abstract List type
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
Expand All @@ -108,21 +109,22 @@ static private float[] findMeanPerSubquery(final float[] sumPerSubquery, final l
if (elementsPerSubquery[i] == 0) {
meanPerSubQuery[i] = 0;
} else {
meanPerSubQuery[i] = sumPerSubquery[i]/elementsPerSubquery[i];
meanPerSubQuery[i] = sumPerSubquery[i] / elementsPerSubquery[i];
}
}

return meanPerSubQuery;
}

static private float[] findStdPerSubquery(final List<CompoundTopDocs> queryTopDocs, final float[] meanPerSubQuery, final long[] elementsPerSubquery, final int numOfScores) {
static private float[] findStdPerSubquery(
final List<CompoundTopDocs> queryTopDocs,
final float[] meanPerSubQuery,
final long[] elementsPerSubquery,
final int numOfScores
) {
final double[] deltaSumPerSubquery = new double[numOfScores];
Arrays.fill(deltaSumPerSubquery, 0);


//TODO: make this better, currently
// this is a horrible implementation in particular when it comes to the topDocsPerSubQuery.get(j)
// which does a random search on an abstract list type.
// TODO: make this syntactically clearer regarding performance by avoiding List.get(j) with an abstract List type
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
Expand All @@ -147,7 +149,7 @@ static private float[] findStdPerSubquery(final List<CompoundTopDocs> queryTopDo
return stdPerSubQuery;
}

static private float sumScoreDocsArray(ScoreDoc[] scoreDocs) {
static private float sumScoreDocsArray(final ScoreDoc[] scoreDocs) {
float sum = 0;
for (ScoreDoc scoreDoc : scoreDocs) {
sum += scoreDoc.score;
Expand All @@ -161,6 +163,6 @@ private static float normalizeSingleScore(final float score, final float standar
if (Floats.compare(mean, score) == 0) {
return SINGLE_RESULT_SCORE;
}
return (score - mean) / standardDeviation;
return (score - mean) / standardDeviation;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -760,4 +760,19 @@ private String registerModelGroup() {
assertNotNull(modelGroupId);
return modelGroupId;
}

protected List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (List<Map<String, Object>>) hitsMap.get("hits");
}

protected Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (Map<String, Object>) hitsMap.get("total");
}

protected Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.query;
package org.opensearch.neuralsearch.processor;

import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION;
import static org.opensearch.neuralsearch.TestUtils.createRandomVector;

import java.io.IOException;
import java.util.*;
import java.util.stream.IntStream;

import com.google.common.primitives.Floats;
import lombok.SneakyThrows;

import org.junit.After;
import org.junit.Before;
import org.opensearch.index.query.BoolQueryBuilder;
Expand All @@ -15,13 +22,10 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
import org.opensearch.neuralsearch.processor.normalization.ZScoreNormalizationTechnique;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

import java.io.IOException;
import java.util.*;
import java.util.stream.IntStream;

import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION;
import static org.opensearch.neuralsearch.TestUtils.createRandomVector;
import com.google.common.primitives.Floats;

public class HybridQueryZScoreIT extends BaseNeuralSearchIT {
private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index";
Expand All @@ -46,7 +50,12 @@ public void setUp() throws Exception {
super.setUp();
updateClusterSettings();
prepareModel();
createSearchPipeline(SEARCH_PIPELINE, ZScoreNormalizationTechnique.TECHNIQUE_NAME, DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, "[0.5,0.5]"));
createSearchPipeline(
SEARCH_PIPELINE,
ZScoreNormalizationTechnique.TECHNIQUE_NAME,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, "[0.5,0.5]")
);
}

@After
Expand Down Expand Up @@ -114,25 +123,24 @@ public void testComplexQuery_withZScoreNormalization() {

String modelId = getDeployedModelId();
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_QUERY_TEXT,
modelId,
5,
null,
null
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_QUERY_TEXT,
modelId,
5,
null,
null
);

HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
hybridQueryBuilderNeuralThenTerm.add(neuralQueryBuilder);
hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder);


final Map<String, Object> searchResponseAsMap = search(
TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
hybridQueryBuilderNeuralThenTerm,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
hybridQueryBuilderNeuralThenTerm,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertEquals(2, getHitCount(searchResponseAsMap));
Expand All @@ -146,10 +154,11 @@ public void testComplexQuery_withZScoreNormalization() {
}

assertEquals(2, scores.size());
// by design when there are only two results with z score since it's z-score normalized we would expect 1 , -1 to be the corresponding score,
// by design when there are only two results with z score since it's z-score normalized we would expect 1 , -1 to be the
// corresponding score,
// furthermore the combination logic with weights should make it doc1Score: (1 * w1 + 0.98 * w2)/(w1 + w2), doc2Score: -1 ~ 0
assertEquals(0.9999, scores.get(0).floatValue(), DELTA_FOR_SCORE_ASSERTION);
assertEquals(0 , scores.get(1).floatValue(), DELTA_FOR_SCORE_ASSERTION);
assertEquals(0, scores.get(1).floatValue(), DELTA_FOR_SCORE_ASSERTION);

// verify that scores are in desc order
assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
Expand Down Expand Up @@ -193,19 +202,4 @@ private void initializeIndexIfNotExist() throws IOException {
assertEquals(2, getDocCount(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME));
}
}

private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (List<Map<String, Object>>) hitsMap.get("hits");
}

private Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (Map<String, Object>) hitsMap.get("total");
}

private Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -341,21 +340,6 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
}
}

private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (List<Map<String, Object>>) hitsMap.get("hits");
}

private Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (Map<String, Object>) hitsMap.get("total");
}

private Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
}

private void assertQueryResults(Map<String, Object> searchResponseAsMap, int totalExpectedDocQty, boolean assertMinScore) {
assertQueryResults(searchResponseAsMap, totalExpectedDocQty, assertMinScore, Range.between(0.5f, 1.0f));
}
Expand Down
Loading

0 comments on commit 0effd07

Please sign in to comment.