Skip to content

Commit

Permalink
Pulled search pipeline in MultiSearchRequest and updated test
Browse files Browse the repository at this point in the history
Signed-off-by: Owais <[email protected]>
  • Loading branch information
owaiskazi19 committed Sep 18, 2024
1 parent e7fc952 commit 141883d
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ public static void readMultiLineFormat(
) {
consumer.accept(searchRequest, parser);
}

if (searchRequest.source() != null && searchRequest.source().pipeline() != null) {
searchRequest.pipeline(searchRequest.source().pipeline());
}
// move pointers
from = nextMarker + 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ public static void parseSearchRequest(
searchRequest.routing(request.param("routing"));
searchRequest.preference(request.param("preference"));
searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions()));
searchRequest.pipeline(request.param("search_pipeline"));

checkRestTotalHits(request, searchRequest);
request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,6 @@ public PipelinedRequest resolvePipeline(SearchRequest searchRequest, IndexNameEx
if (searchRequest.pipeline() != null) {
// Named pipeline specified for the request
pipelineId = searchRequest.pipeline();
} else if (searchRequest.source() != null && searchRequest.source().pipeline() != null) {
// Inline pipeline specified for the request
pipelineId = searchRequest.source().pipeline();
} else if (state != null && searchRequest.indices() != null && searchRequest.indices().length != 0) {
try {
// Check for index default pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.Version;
import org.opensearch.action.search.DeleteSearchPipelineRequest;
import org.opensearch.action.search.MockSearchPhaseContext;
import org.opensearch.action.search.MultiSearchRequest;
import org.opensearch.action.search.PutSearchPipelineRequest;
import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
Expand Down Expand Up @@ -75,6 +76,8 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import static org.opensearch.search.RandomSearchRequestGenerator.randomSearchRequest;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -972,7 +975,8 @@ public void testInlinePipeline() throws Exception {
/**
* Tests a pipeline name defined in the search request source.
*/
public void testInlineDefinedPipeline() throws Exception {
public void testInlineDefinedPipelineForMultiSearch() throws Exception {
int numberOfSearchRequests = randomIntBetween(0, 32);
SearchPipelineService searchPipelineService = createWithProcessors();

SearchPipelineMetadata metadata = new SearchPipelineMetadata(
Expand All @@ -988,7 +992,6 @@ public void testInlineDefinedPipeline() throws Exception {
),
MediaTypeRegistry.JSON
)

)

);
Expand All @@ -999,34 +1002,49 @@ public void testInlineDefinedPipeline() throws Exception {
.build();
searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState));

SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().size(100).pipeline("p1");
SearchRequest searchRequest = new SearchRequest().source(sourceBuilder);

// Verify pipeline
PipelinedRequest pipelinedRequest = syncTransformRequest(
searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver)
);
Pipeline pipeline = pipelinedRequest.getPipeline();
assertEquals("p1", pipeline.getId());
assertEquals(1, pipeline.getSearchRequestProcessors().size());
assertEquals(1, pipeline.getSearchResponseProcessors().size());

// Verify that pipeline transforms request
assertEquals(200, pipelinedRequest.source().size());
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
for (int i = 0; i < numberOfSearchRequests; i++) {
SearchRequest searchRequest = randomSearchRequest(() -> {
// No need to return a very complex SearchSourceBuilder here, that is tested
// elsewhere
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.from(randomInt(10));
searchSourceBuilder.size(randomIntBetween(20, 100));
searchSourceBuilder.pipeline("p1");
return searchSourceBuilder;
});
multiSearchRequest.add(searchRequest);

// Verify pipeline
PipelinedRequest pipelinedRequest = syncTransformRequest(
searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver)
);
Pipeline pipeline = pipelinedRequest.getPipeline();
assertEquals("p1", pipeline.getId());
assertEquals(1, pipeline.getSearchRequestProcessors().size());
assertEquals(1, pipeline.getSearchResponseProcessors().size());

// Verify that pipeline transforms request
assertEquals(200, pipelinedRequest.source().size());

int size = 10;
SearchHit[] hits = new SearchHit[size];
for (int j = 0; j < size; j++) {
hits[j] = new SearchHit(j, "doc" + j, Collections.emptyMap(), Collections.emptyMap());
hits[j].score(j);
}
SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size);
SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null);

int size = 10;
SearchHit[] hits = new SearchHit[size];
for (int i = 0; i < size; i++) {
hits[i] = new SearchHit(i, "doc" + i, Collections.emptyMap(), Collections.emptyMap());
hits[i].score(i);
SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse);
for (int j = 0; j < size; j++) {
assertEquals(2.0, transformedResponse.getHits().getHits()[j].getScore(), 0.0001);
}
}
SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size);
SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null);

SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse);
for (int i = 0; i < size; i++) {
assertEquals(2.0, transformedResponse.getHits().getHits()[i].getScore(), 0.0001);
for (SearchRequest subReq : multiSearchRequest.requests()) {
assertThat(multiSearchRequest.toString(), containsString(subReq.toString()));
}
}

Expand Down

0 comments on commit 141883d

Please sign in to comment.