Skip to content

Commit

Permalink
Migrate RAG pipeline to async processing.
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee committed Apr 21, 2024
1 parent fc555c0 commit 59193e9
Show file tree
Hide file tree
Showing 7 changed files with 546 additions and 190 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,23 @@
import java.util.Map;
import java.util.function.BooleanSupplier;

import org.opensearch.OpenSearchException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.search.SearchHit;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
Expand All @@ -65,8 +68,6 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements

private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;

// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.

private final String llmModel;
private final List<String> contextFields;

Expand Down Expand Up @@ -106,8 +107,18 @@ protected GenerativeQAResponseProcessor(
}

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
public SearchResponse processResponse(SearchRequest searchRequest, SearchResponse searchResponse) {
// Synchronous call is no longer supported because this execution can occur on a transport thread.
throw new UnsupportedOperationException();
}

@Override
public void processResponseAsync(
SearchRequest request,
SearchResponse response,
PipelineProcessingContext requestContext,
ActionListener<SearchResponse> responseListener
) {
log.info("Entering processResponse.");

if (!this.featureFlagSupplier.getAsBoolean()) {
Expand All @@ -116,10 +127,12 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp

GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);

Integer timeout = params.getTimeout();
if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
Integer t = params.getTimeout();
if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) {
t = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
}
final int timeout = t;
log.info("Timeout for this request: {} seconds.", timeout);

String llmQuestion = params.getLlmQuestion();
String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
Expand All @@ -128,14 +141,16 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}
String conversationId = params.getConversationId();

if (conversationId != null && !Strings.hasText(conversationId)) {
throw new IllegalArgumentException("Empty conversation_id is not allowed.");
}
// log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);
Instant start = Instant.now();
Integer interactionSize = params.getInteractionSize();
if (interactionSize == null || interactionSize == GenerativeQAParameters.SIZE_NULL_VALUE) {
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW;
}
List<Interaction> chatHistory = (conversationId == null)
? Collections.emptyList()
: memoryClient.getInteractions(conversationId, interactionSize);
log.info("Using interaction size of {}", interactionSize);

Integer topN = params.getContextSize();
if (topN == null) {
Expand All @@ -153,10 +168,35 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
effectiveUserInstructions = params.getUserInstructions();
}

start = Instant.now();
try {
ChatCompletionOutput output = llm
.doChatCompletion(
// log.info("system_prompt: {}", systemPrompt);
// log.info("user_instructions: {}", userInstructions);

final List<Interaction> chatHistory = new ArrayList<>();
if (conversationId == null) {
doChatCompletion(
LlmIOUtil
.createChatCompletionInput(
systemPrompt,
userInstructions,
llmModel,
llmQuestion,
chatHistory,
searchResults,
timeout,
params.getLlmResponseField()
),
null,
llmQuestion,
searchResults,
response,
responseListener
);
} else {
final Instant memoryStart = Instant.now();
memoryClient.getInteractions(conversationId, interactionSize, ActionListener.wrap(r -> {
log.info("getInteractions complete. ({})", getDuration(memoryStart));
chatHistory.addAll(r);
doChatCompletion(
LlmIOUtil
.createChatCompletionInput(
systemPrompt,
Expand All @@ -167,53 +207,82 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
searchResults,
timeout,
params.getLlmResponseField()
)
),
conversationId,
llmQuestion,
searchResults,
response,
responseListener
);
log.info("doChatCompletion complete. ({})", getDuration(start));
}, responseListener::onFailure));
}
}

String answer = null;
String errorMessage = null;
String interactionId = null;
if (output.isErrorOccurred()) {
errorMessage = output.getErrors().get(0);
} else {
answer = (String) output.getAnswers().get(0);
private void doChatCompletion(
ChatCompletionInput input,
String conversationId,
String llmQuestion,
List<String> searchResults,
SearchResponse response,
ActionListener<SearchResponse> responseListener
) {

final Instant chatStart = Instant.now();
llm.doChatCompletion(input, new ActionListener<>() {
@Override
public void onResponse(ChatCompletionOutput output) {
log.info("doChatCompletion complete. ({})", getDuration(chatStart));

final String answer = getAnswer(output);
final String errorMessage = getError(output);

if (conversationId != null) {
start = Instant.now();
interactionId = memoryClient
final Instant memoryStart = Instant.now();
memoryClient
.createInteraction(
conversationId,
llmQuestion,
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
Collections.singletonMap("metadata", jsonArrayToString(searchResults))
Collections.singletonMap("metadata", jsonArrayToString(searchResults)),
ActionListener.wrap(r -> {
responseListener.onResponse(insertAnswer(response, answer, errorMessage, r));
log.info("Created a new interaction: {} ({})", r, getDuration(memoryStart));
}, responseListener::onFailure)
);
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));

} else {
responseListener.onResponse(insertAnswer(response, answer, errorMessage, null));
}

}

return insertAnswer(response, answer, errorMessage, interactionId);
} catch (NullPointerException nullPointerException) {
throw new IllegalArgumentException(IllegalArgumentMessage);
} catch (Exception e) {
throw new OpenSearchException("GenerativeQAResponseProcessor failed in precessing response");
}
}
@Override
public void onFailure(Exception e) {
responseListener.onFailure(e);
}

long getDuration(Instant start) {
return Duration.between(start, Instant.now()).toMillis();
private String getError(ChatCompletionOutput output) {
return output.isErrorOccurred() ? output.getErrors().get(0) : null;
}

private String getAnswer(ChatCompletionOutput output) {
return output.isErrorOccurred() ? null : (String) output.getAnswers().get(0);
}
});
}

@Override
public String getType() {
return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
}

private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {
private long getDuration(Instant start) {
return Duration.between(start, Instant.now()).toMillis();
}

// TODO return the interaction id in the response.
private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {

return new GenerativeSearchResponse(
answer,
Expand All @@ -240,9 +309,7 @@ private List<String> getSearchResults(SearchResponse response, Integer topN) {
for (String contextField : contextFields) {
Object context = docSourceMap.get(contextField);
if (context == null) {
log.error("Context " + contextField + " not found in search hit " + hits[i]);
// TODO throw a more meaningful error here?
throw new RuntimeException();
throw new RuntimeException("Context " + contextField + " not found in search hit " + hits[i]);
}
searchResults.add(context.toString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
Expand Down Expand Up @@ -83,6 +84,33 @@ public String createInteraction(
return res.getId();
}

public void createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
Map<String, String> additionalInfo,
ActionListener<String> listener
) {
client
.execute(
CreateInteractionAction.INSTANCE,
new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo),
new ActionListener<CreateInteractionResponse>() {
@Override
public void onResponse(CreateInteractionResponse createInteractionResponse) {
listener.onResponse(createInteractionResponse.getId());
}

@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
}
);
}

public List<Interaction> getInteractions(String conversationId, int lastN) {

Preconditions.checkArgument(lastN > 0, "lastN must be at least 1.");
Expand Down Expand Up @@ -113,4 +141,23 @@ public List<Interaction> getInteractions(String conversationId, int lastN) {

return interactions;
}

public void getInteractions(String conversationId, int lastN, ActionListener<List<Interaction>> listener) {
client
.execute(
GetInteractionsAction.INSTANCE,
new GetInteractionsRequest(conversationId, lastN, 0),
new ActionListener<GetInteractionsResponse>() {
@Override
public void onResponse(GetInteractionsResponse getInteractionsResponse) {
listener.onResponse(getInteractionsResponse.getInteractions());
}

@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
}
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public ActionFuture<MLOutput> predict(String modelId, MLInput mlInput) {
}

@VisibleForTesting
void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
validateMLInput(mlInput, true);

MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import java.util.Map;

import org.opensearch.client.Client;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
Expand Down Expand Up @@ -75,20 +75,36 @@ protected void setMlClient(MachineLearningInternalClient mlClient) {
* @return
*/
@Override
public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {

public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener<ChatCompletionOutput> listener) {
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
ActionFuture<MLOutput> future = mlClient.predict(this.openSearchModelId, mlInput);
ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000);

// Response from a remote model
Map<String, ?> dataAsMap = modelOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
// log.info("dataAsMap: {}", dataAsMap.toString());

// TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases.
mlClient.predict(this.openSearchModelId, mlInput, new ActionListener<>() {
@Override
public void onResponse(MLOutput mlOutput) {
// Response from a remote model
Map<String, ?> dataAsMap = ((ModelTensorOutput) mlOutput)
.getMlModelOutputs()
.get(0)
.getMlModelTensors()
.get(0)
.getDataAsMap();
// log.info("dataAsMap: {}", dataAsMap.toString());
listener
.onResponse(
buildChatCompletionOutput(
chatCompletionInput.getModelProvider(),
dataAsMap,
chatCompletionInput.getLlmResponseField()
)
);
}

return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap, chatCompletionInput.getLlmResponseField());
@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
});
}

protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative.llm;

import org.opensearch.core.action.ActionListener;

/**
* Capabilities of large language models, e.g. completion, embeddings, etc.
*/
Expand All @@ -29,5 +31,5 @@ enum ModelProvider {
COHERE
}

ChatCompletionOutput doChatCompletion(ChatCompletionInput input);
void doChatCompletion(ChatCompletionInput input, ActionListener<ChatCompletionOutput> listener);
}
Loading

0 comments on commit 59193e9

Please sign in to comment.