Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
Signed-off-by: Kaushal Kumar <[email protected]>
  • Loading branch information
kaushalmahi12 committed Aug 27, 2024
1 parent d8e41e1 commit a3df783
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 35 deletions.
13 changes: 8 additions & 5 deletions server/src/main/java/org/opensearch/wlm/QueryGroupService.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

package org.opensearch.wlm;

import java.util.Optional;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;

/**
* This is stub at this point in time and will be replace by an acutal one in couple of days
Expand All @@ -25,13 +25,16 @@ public void requestFailedFor(final String queryGroupId) {
/**
*
* @param queryGroupId query group identifier
* @return whether the queryGroup is contended and should reject new incoming requests
*/
public Optional<String> shouldRejectFor(String queryGroupId) {
if (queryGroupId == null) return Optional.empty();
public void rejectIfNeeded(String queryGroupId) {
if (queryGroupId == null) return;
boolean reject = false;
final StringBuilder reason = new StringBuilder();
// TODO: At this point this is dummy and we need to decide whether to cancel the request based on last
// reported resource usage for the queryGroup. We also need to increment the rejection count here for the
// query group
return Optional.of("Possible reason. ");
if (reject) {
throw new OpenSearchRejectedExecutionException("QueryGroup " + queryGroupId + " is already contended." + reason.toString());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@

package org.opensearch.wlm;

import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportInterceptor;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestHandler;

import java.util.Optional;

/**
* This class is used to intercept search traffic requests and populate the queryGroupId header in task headers
*/
Expand Down Expand Up @@ -61,11 +58,7 @@ public void messageReceived(T request, TransportChannel channel, Task task) thro
if (isSearchWorkloadRequest(task)) {
((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext());
final String queryGroupId = ((QueryGroupTask) (task)).getQueryGroupId();
Optional<String> reason = queryGroupService.shouldRejectFor(queryGroupId);

if (reason.isPresent()) {
throw new OpenSearchRejectedExecutionException("QueryGroup " + queryGroupId + " is already contended." + reason.get());
}
queryGroupService.rejectIfNeeded(queryGroupId);
}
actualHandler.messageReceived(request, channel, task);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@

import org.opensearch.action.search.SearchRequestContext;
import org.opensearch.action.search.SearchRequestOperationsListener;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.wlm.QueryGroupService;
import org.opensearch.wlm.QueryGroupTask;

import java.util.Optional;

/**
* This listener is used to perform the rejections for incoming requests into a queryGroup
*/
Expand All @@ -37,9 +34,6 @@ public QueryGroupRequestRejectionOperationListener(QueryGroupService queryGroupS
@Override
protected void onRequestStart(SearchRequestContext searchRequestContext) {
final String queryGroupId = threadPool.getThreadContext().getHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER);
Optional<String> reason = queryGroupService.shouldRejectFor(queryGroupId);
if (reason.isPresent()) {
throw new OpenSearchRejectedExecutionException("QueryGroup " + queryGroupId + " is already contended." + reason.get());
}
queryGroupService.rejectIfNeeded(queryGroupId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
import org.opensearch.transport.TransportRequestHandler;

import java.util.Collections;
import java.util.Optional;

import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

public class WorkloadManagementTransportRequestHandlerTests extends OpenSearchTestCase {
private WorkloadManagementTransportInterceptor.RequestHandler<TransportRequest> sut;
Expand All @@ -51,16 +51,15 @@ public void tearDown() throws Exception {
public void testMessageReceivedForSearchWorkload_nonRejectionCase() throws Exception {
ShardSearchRequest request = mock(ShardSearchRequest.class);
QueryGroupTask spyTask = getSpyTask();
when(queryGroupService.shouldRejectFor(anyString())).thenReturn(Optional.empty());

doNothing().when(queryGroupService).rejectIfNeeded(anyString());
sut.messageReceived(request, mock(TransportChannel.class), spyTask);
assertTrue(sut.isSearchWorkloadRequest(spyTask));
}

public void testMessageReceivedForSearchWorkload_RejectionCase() throws Exception {
ShardSearchRequest request = mock(ShardSearchRequest.class);
QueryGroupTask spyTask = getSpyTask();
when(queryGroupService.shouldRejectFor(anyString())).thenReturn(Optional.of("QueryGroup is contended."));
doThrow(OpenSearchRejectedExecutionException.class).when(queryGroupService).rejectIfNeeded(anyString());

assertThrows(OpenSearchRejectedExecutionException.class, () -> sut.messageReceived(request, mock(TransportChannel.class), spyTask));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
import org.opensearch.wlm.QueryGroupService;
import org.opensearch.wlm.QueryGroupTask;

import java.util.Optional;

import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class QueryGroupRequestRejectionOperationListenerTests extends OpenSearchTestCase {
ThreadPool testThreadPool;
Expand All @@ -28,6 +27,8 @@ public class QueryGroupRequestRejectionOperationListenerTests extends OpenSearch
public void setUp() throws Exception {
super.setUp();
testThreadPool = new TestThreadPool("RejectionTestThreadPool");
queryGroupService = mock(QueryGroupService.class);
sut = new QueryGroupRequestRejectionOperationListener(queryGroupService, testThreadPool);
}

public void tearDown() throws Exception {
Expand All @@ -36,21 +37,16 @@ public void tearDown() throws Exception {
}

public void testRejectionCase() {
queryGroupService = mock(QueryGroupService.class);
sut = new QueryGroupRequestRejectionOperationListener(queryGroupService, testThreadPool);
final String testQueryGroupId = "asdgasgkajgkw3141_3rt4t";
testThreadPool.getThreadContext().putHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER, testQueryGroupId);
when(queryGroupService.shouldRejectFor(testQueryGroupId)).thenReturn(Optional.of("Test query group is contended"));

doThrow(OpenSearchRejectedExecutionException.class).when(queryGroupService).rejectIfNeeded(testQueryGroupId);
assertThrows(OpenSearchRejectedExecutionException.class, () -> sut.onRequestStart(null));
}

public void testNonRejectionCase() {
queryGroupService = mock(QueryGroupService.class);
sut = new QueryGroupRequestRejectionOperationListener(queryGroupService, testThreadPool);
final String testQueryGroupId = "asdgasgkajgkw3141_3rt4t";
testThreadPool.getThreadContext().putHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER, testQueryGroupId);
when(queryGroupService.shouldRejectFor(testQueryGroupId)).thenReturn(Optional.empty());
doNothing().when(queryGroupService).rejectIfNeeded(testQueryGroupId);

sut.onRequestStart(null);
}
Expand Down

0 comments on commit a3df783

Please sign in to comment.