
Commit 06b74cb
Fix RestMLRAGSearchProcessorIT so that it no longer extends RestMLRemoteInferenceIT.
Signed-off-by: Austin Lee <[email protected]>
austintlee committed Sep 5, 2024
1 parent e6b21da commit 06b74cb
Showing 1 changed file with 25 additions and 23 deletions.
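For orientation, a minimal sketch (not part of the commit) of the pattern the diff below applies: instead of inheriting the remote-inference helpers by extending RestMLRemoteInferenceIT, the test class now extends MLCommonsRestTestCase and reaches those helpers through static imports or class-qualified calls. The helper names and rough signatures (createConnector, registerRemoteModel) come from the diff; the method body and placeholder arguments are illustrative only.

    package org.opensearch.ml.rest;

    // Sketch only: assumes createConnector and registerRemoteModel are public static
    // helpers on RestMLRemoteInferenceIT, as the static imports and qualified calls
    // in the diff below imply.
    import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.createConnector;

    import org.opensearch.client.Response;

    public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {

        public void testSketch() throws Exception {
            // Statically imported helper: no inheritance required.
            Response response = createConnector("<connector blueprint JSON>");
            // Helpers that are not statically imported are qualified with the owning class.
            response = RestMLRemoteInferenceIT.registerRemoteModel("example model", "<connector_id>");
        }
    }
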
@@ -17,6 +17,8 @@
*/
package org.opensearch.ml.rest;

+import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.createConnector;
+import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.deployRemoteModel;
import static org.opensearch.ml.utils.TestHelper.makeRequest;
import static org.opensearch.ml.utils.TestHelper.toHttpEntity;

@@ -41,7 +43,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

-public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
+public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {

private static final String OPENAI_KEY = System.getenv("OPENAI_KEY");
private static final String OPENAI_CONNECTOR_BLUEPRINT = "{\n"
@@ -526,11 +528,11 @@ public void testBM25WithOpenAI() throws Exception {
Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
@@ -580,11 +582,11 @@ public void testBM25WithOpenAIWithImage() throws Exception {
Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("openAI-GPT-4o-mini completions", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-4o-mini completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
@@ -667,11 +669,11 @@ public void testBM25WithBedrock() throws Exception {
Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("Bedrock Anthropic Claude", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
@@ -719,11 +721,11 @@ public void testBM25WithBedrockConverse() throws Exception {
Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("Bedrock Anthropic Claude", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
@@ -771,11 +773,11 @@ public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception {
Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("Bedrock Anthropic Claude", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
@@ -831,11 +833,11 @@ public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws
Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("Bedrock Anthropic Claude", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock Anthropic Claude", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
@@ -890,11 +892,11 @@ public void testBM25WithOpenAIWithConversation() throws Exception {
Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
@@ -947,11 +949,11 @@ public void testBM25WithOpenAIWithConversationAndImage() throws Exception {
Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("openAI-GPT-4 completions", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-4 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
@@ -1008,11 +1010,11 @@ public void testBM25WithBedrockWithConversation() throws Exception {
Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("Bedrock", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("Bedrock", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
@@ -1065,11 +1067,11 @@ public void testBM25WithCohere() throws Exception {
Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("Cohere Chat Completion v1", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("Cohere Chat Completion v1", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
@@ -1117,11 +1119,11 @@ public void testBM25WithCohereUsingLlmResponseField() throws Exception {
Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
-response = registerRemoteModel("Cohere Chat Completion v1", connectorId);
+response = RestMLRemoteInferenceIT.registerRemoteModel("Cohere Chat Completion v1", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
-response = getTask(taskId);
+response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
