diff --git a/build.gradle b/build.gradle index d608e18..aad564a 100644 --- a/build.gradle +++ b/build.gradle @@ -86,7 +86,7 @@ dependencies { implementation 'org.apache.httpcomponents:httpclient:4.5.13' implementation 'org.apache.httpcomponents:httpcore:4.4.15' implementation 'com.fasterxml.jackson.core:jackson-databind:2.15.0' - implementation 'com.fasterxml.jackson.core:jackson-core:2.15.1' + implementation 'com.fasterxml.jackson.core:jackson-core:2.15.2' implementation 'com.fasterxml.jackson.core:jackson-annotations:2.15.0' implementation 'commons-logging:commons-logging:1.2' implementation 'com.amazonaws:aws-java-sdk-sts:1.12.300' diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java index ba7ade5..5d2e5db 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/reranker/impl/AmazonPersonalizedRankerImpl.java @@ -9,8 +9,10 @@ import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingRequest; import com.amazonaws.services.personalizeruntime.model.GetPersonalizedRankingResult; +import com.amazonaws.services.personalizeruntime.model.PredictedItem; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.TotalHits; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; @@ -18,7 +20,11 @@ import org.opensearch.search.relevance.transformer.personalizeintelligentranking.requestparameter.PersonalizeRequestParameters; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.PersonalizedRanker; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -82,14 +88,82 @@ public SearchHits rerank(SearchHits hits, PersonalizeRequestParameters requestPa .withUserId(userId); GetPersonalizedRankingResult result = personalizeClient.getPersonalizedRanking(personalizeRequest); - //TODO: Combine Personalize and open search result. Change the result after transform logic is implemented - return hits; + List personalizeRrankingResult = result.getPersonalizedRanking(); + Map idToPersonalizeRankingScoreMap = new HashMap<>(); + Map idToOpenSearchScoreMap = new HashMap<>(); + Map itemIdToSearchHitMap = new HashMap<>(); + // Build a map with key as item id and value as personalize ranking score + for (PredictedItem item : personalizeRrankingResult) { + idToPersonalizeRankingScoreMap.put(item.getItemId(), item.getScore().floatValue()); + } + + // Build a map with key as item id and value as open search scores and another map + // with key as item id and value as corresponding search hit + for (SearchHit hit : originalHits) { + if (!itemIdfield.isEmpty()){ + idToOpenSearchScoreMap.put(hit.getSourceAsMap().get(itemIdfield).toString(), hit.getScore()); + itemIdToSearchHitMap.put(hit.getSourceAsMap().get(itemIdfield).toString(), hit); + } + else{ + idToOpenSearchScoreMap.put(hit.getId(), hit.getScore()); + itemIdToSearchHitMap.put(hit.getId(), hit); + } + } + + + float weight = (float) rankerConfig.getWeight(); + SearchHits newHits = combineScores(idToPersonalizeRankingScoreMap, idToOpenSearchScoreMap, + itemIdToSearchHitMap, hits.getTotalHits(), weight); + return newHits; } catch (Exception ex) { logger.error("Failed to re rank with Personalize. Returning original search results without Personalize re ranking.", ex); return hits; } } + //Combine open search hits and personalize campaign response + public SearchHits combineScores(Map idToPersonalizeRankingScoreMap, + Map idToOpenSearchScoreMap, + Map itemIdToSearchHitMap, + TotalHits totalHits, float weight) { + //Update open search score based on the personalize campaign response for each item id + List openSearchItemId = new ArrayList(idToOpenSearchScoreMap.keySet()); + for (String itemId : openSearchItemId) { + if(idToPersonalizeRankingScoreMap.containsKey(itemId)){ + float personalizedScore = idToPersonalizeRankingScoreMap.get(itemId); + float openSearchScore = idToOpenSearchScoreMap.get(itemId); + float combinedScore = (float) (weight / Math.log(openSearchScore + 1) + + (1 - weight) / Math.log(personalizedScore + 1)); + idToOpenSearchScoreMap.put(itemId, combinedScore); + } + } + + //Create a new list of search hits in the decreasing order of the combined scores + Map sortedScores = sortByValue(idToOpenSearchScoreMap); + + List rerankedHits = sortedScores.keySet().stream() + .map(itemId -> { + SearchHit hit = itemIdToSearchHitMap.get(itemId); + hit.score(sortedScores.get(itemId)); + return hit; + }) + .collect(Collectors.toList()); + float maxScore = sortedScores.values().stream().max(Float::compare).orElse(0f); + return new SearchHits(rerankedHits.toArray(new SearchHit[0]), totalHits, maxScore); + } + + + //Sort map by reverse order of the values + public Map sortByValue(Map map) { + return map.entrySet().stream() + .sorted(Map.Entry.comparingByValue().reversed()) + .collect(Collectors.toMap( + Map.Entry::getKey, + Map.Entry::getValue, + (oldValue, newValue) -> oldValue, LinkedHashMap::new)); + } + + /** * Validate Personalize configuration for calling Personalize service * @param requestParameters Request parameters for Personalize present in search request diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java index b5e5ac9..5e42712 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/ranker/impl/AmazonPersonalizeRankerImplTests.java @@ -9,6 +9,7 @@ package org.opensearch.search.relevance.transformer.personalizeintelligentranking.ranker.impl; import org.mockito.Mockito; +import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.client.PersonalizeClient; import org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.PersonalizeIntelligentRankerConfiguration; @@ -18,9 +19,14 @@ import org.opensearch.search.relevance.transformer.personalizeintelligentranking.utils.SearchTestUtil; import org.opensearch.test.OpenSearchTestCase; + import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static org.mockito.ArgumentMatchers.any; @@ -32,6 +38,7 @@ public class AmazonPersonalizeRankerImplTests extends OpenSearchTestCase { private String itemIdField = "ITEM_ID"; private String region = "us-west-2"; private double weight = 0.25; + private int numOfHits = 10; public void testReRank() throws IOException { PersonalizeIntelligentRankerConfiguration rankerConfig = @@ -42,7 +49,7 @@ public void testReRank() throws IOException { AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); requestParameters.setUserId("28"); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); assertEquals(responseHits.getHits().length, transformedHits.getHits().length); } @@ -57,7 +64,7 @@ public void testReRankWithoutItemIdFieldInConfig() throws IOException { AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); requestParameters.setUserId("28"); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); assertEquals(responseHits.getHits().length, transformedHits.getHits().length); } @@ -74,7 +81,7 @@ public void testReRankWithRequestParameterContext() throws IOException { PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); requestParameters.setUserId("28"); requestParameters.setContext(context); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); assertEquals(responseHits.getHits().length, transformedHits.getHits().length); } @@ -91,7 +98,7 @@ public void testReRankWithInvalidRequestParameterContext() throws IOException { PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); requestParameters.setUserId("28"); requestParameters.setContext(context); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); assertEquals(responseHits.getHits().length, transformedHits.getHits().length); } @@ -107,7 +114,7 @@ public void testReRankWithNoUserId() throws IOException { context.put("contextKey", "contextValue"); PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); requestParameters.setContext(context); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); assertEquals(responseHits.getHits().length, transformedHits.getHits().length); } @@ -122,7 +129,7 @@ public void testReRankWithEmptyItemIdField() throws IOException { AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); requestParameters.setUserId("28"); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); assertEquals(responseHits.getHits().length, transformedHits.getHits().length); } @@ -136,8 +143,275 @@ public void testReRankWithNullItemIdField() throws IOException { AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); requestParameters.setUserId("28"); - SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(10); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); assertEquals(responseHits.getHits().length, transformedHits.getHits().length); } + + + public void testReRankWithWeightAsZero() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, 0); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult(numOfHits)); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + + List originalHits = Arrays.asList(transformedHits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List rerankedDocumentIds; + + rerankedDocumentIds = originalHits.stream() + .filter(h -> h.getSourceAsMap().get(itemIdfield) != null) + .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZero(numOfHits); + + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + } + + + public void testReRankWithWeightAsOne() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, 1); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult(numOfHits)); + + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + + List originalHits = Arrays.asList(transformedHits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List rerankedDocumentIds; + + rerankedDocumentIds = originalHits.stream() + .filter(h -> h.getSourceAsMap().get(itemIdfield) != null) + .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsOne(numOfHits); + + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + } + + public void testReRankWithWeightAsNietherZeroOrOne() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, weight); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult(numOfHits)); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + + List originalHits = Arrays.asList(transformedHits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List rerankedDocumentIds; + + rerankedDocumentIds = originalHits.stream() + .filter(h -> h.getSourceAsMap().get(itemIdfield) != null) + .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) + .collect(Collectors.toList()); + + ArrayList rerankedDocumentIdsWhenWeightIsOne = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsOne(numOfHits); + ArrayList rerankedDocumentIdsWhenWeightIsZero = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZero(numOfHits); + + assertNotEquals(rerankedDocumentIdsWhenWeightIsOne, rerankedDocumentIds); + assertNotEquals(rerankedDocumentIdsWhenWeightIsZero, rerankedDocumentIds); + } + + public void testReRankWithWeightAsGreaterThanOne() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, 2); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult(numOfHits)); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + + List originalHits = Arrays.asList(transformedHits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List rerankedDocumentIds; + + rerankedDocumentIds = originalHits.stream() + .filter(h -> h.getSourceAsMap().get(itemIdfield) != null) + .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZero(numOfHits); + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + } + + public void testReRankWithWeightAsLessThanZero() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, itemIdField, region, -1); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResult(numOfHits)); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + + List originalHits = Arrays.asList(transformedHits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List rerankedDocumentIds; + + rerankedDocumentIds = originalHits.stream() + .filter(h -> h.getSourceAsMap().get(itemIdfield) != null) + .map(h -> h.getSourceAsMap().get(itemIdfield).toString()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZero(numOfHits); + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + } + + + public void testReRankWithWeightAsZeroWithNullItemIdField() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, "", region, 0); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(numOfHits)); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + + List originalHits = Arrays.asList(transformedHits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List rerankedDocumentIds; + + rerankedDocumentIds = originalHits.stream() + .filter(h -> h.getId() != null) + .map(h -> h.getId()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZeroWithNullItemIdField(numOfHits); + + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + } + + + public void testReRankWithWeightAsOneWithNullItemIdField() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, "", region, 1); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(numOfHits)); + + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + + List originalHits = Arrays.asList(transformedHits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List rerankedDocumentIds; + + rerankedDocumentIds = originalHits.stream() + .filter(h -> h.getId() != null) + .map(h -> h.getId()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsOneWithNullItemIdField(numOfHits); + + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + } + + public void testReRankWithWeightAsNietherZeroOrOneWithNullItemIdField() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, "", region, weight); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(numOfHits)); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + + List originalHits = Arrays.asList(transformedHits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List rerankedDocumentIds; + + rerankedDocumentIds = originalHits.stream() + .filter(h -> h.getId() != null) + .map(h -> h.getId()) + .collect(Collectors.toList()); + + ArrayList rerankedDocumentIdsWhenWeightIsOne = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsOneWithNullItemIdField(numOfHits); + ArrayList rerankedDocumentIdsWhenWeightIsZero = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZeroWithNullItemIdField(numOfHits); + + assertNotEquals(rerankedDocumentIdsWhenWeightIsOne, rerankedDocumentIds); + assertNotEquals(rerankedDocumentIdsWhenWeightIsZero, rerankedDocumentIds); + } + + public void testReRankWithWeightAsGreaterThanOneWithNullItemIdField() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, "", region, 2); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(numOfHits)); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + + List originalHits = Arrays.asList(transformedHits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List rerankedDocumentIds; + + rerankedDocumentIds = originalHits.stream() + .filter(h -> h.getId() != null) + .map(h -> h.getId()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZeroWithNullItemIdField(numOfHits); + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + } + + public void testReRankWithWeightAsLessThanZeroWithNullItemIdField() throws IOException { + PersonalizeIntelligentRankerConfiguration rankerConfig = + new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, recipe, "", region, -1); + PersonalizeClient client = Mockito.mock(PersonalizeClient.class); + Mockito.when(client.getPersonalizedRanking(any())).thenReturn(PersonalizeRuntimeTestUtil.buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(numOfHits)); + + AmazonPersonalizedRankerImpl ranker = new AmazonPersonalizedRankerImpl(rankerConfig, client); + PersonalizeRequestParameters requestParameters = new PersonalizeRequestParameters(); + requestParameters.setUserId("28"); + SearchHits responseHits = SearchTestUtil.getSampleSearchHitsForPersonalize(numOfHits); + SearchHits transformedHits = ranker.rerank(responseHits, requestParameters); + + List originalHits = Arrays.asList(transformedHits.getHits()); + String itemIdfield = rankerConfig.getItemIdField(); + List rerankedDocumentIds; + + rerankedDocumentIds = originalHits.stream() + .filter(h -> h.getId() != null) + .map(h -> h.getId()) + .collect(Collectors.toList()); + + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsWhenWeightIsZeroWithNullItemIdField(numOfHits); + assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); + } + } diff --git a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java index db68d2d..4ca4b34 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/utils/PersonalizeRuntimeTestUtil.java @@ -31,4 +31,66 @@ public static GetPersonalizedRankingResult buildGetPersonalizedRankingResult() { .withRecommendationId("sampleRecommendationId"); return result; } + + public static GetPersonalizedRankingResult buildGetPersonalizedRankingResult(int numOfHits) { + List predictedItems = new ArrayList<>(); + for(int i = numOfHits; i >= 1; i--){ + PredictedItem predictedItem = new PredictedItem(). + withScore((double) i/10). + withItemId(String.valueOf(i-1)); + predictedItems.add(predictedItem); + } + GetPersonalizedRankingResult result = new GetPersonalizedRankingResult() + .withPersonalizedRanking(predictedItems) + .withRecommendationId("sampleRecommendationId"); + return result; + } + + + public static GetPersonalizedRankingResult buildGetPersonalizedRankingResultWhenItemIdConfigIsEmpty(int numOfHits) { + List predictedItems = new ArrayList<>(); + for(int i = numOfHits; i >= 1; i--){ + PredictedItem predictedItem = new PredictedItem(). + withScore((double) i/10). + withItemId("doc"+ (i - 1)); + predictedItems.add(predictedItem); + } + + GetPersonalizedRankingResult result = new GetPersonalizedRankingResult() + .withPersonalizedRanking(predictedItems) + .withRecommendationId("sampleRecommendationId"); + return result; + } + + public static ArrayList expectedRankedItemIdsWhenWeightIsOne(int numOfHits){ + ArrayList expectedRankedItemIds = new ArrayList<>(); + for(int i = numOfHits; i >= 1; i--){ + expectedRankedItemIds.add(String.valueOf(i-1)); + } + return expectedRankedItemIds; + } + + public static ArrayList expectedRankedItemIdsWhenWeightIsZero(int numOfHits){ + ArrayList expectedRankedItemIds = new ArrayList<>(); + for(int i = 0; i expectedRankedItemIdsWhenWeightIsOneWithNullItemIdField(int numOfHits){ + ArrayList expectedRankedItemIds = new ArrayList<>(); + for(int i = numOfHits; i >= 1; i--){ + expectedRankedItemIds.add("doc" + (i - 1)); + } + return expectedRankedItemIds; + } + + public static ArrayList expectedRankedItemIdsWhenWeightIsZeroWithNullItemIdField(int numOfHits){ + ArrayList expectedRankedItemIds = new ArrayList<>(); + for(int i = 0; i