diff --git a/plugin/build.gradle b/plugin/build.gradle index 47ce5599cb..e835ba4e0f 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -572,8 +572,3 @@ task bwcTestSuite(type: StandaloneRestIntegTestTask) { dependsOn tasks.named("${baseName}#rollingUpgradeClusterTask") dependsOn tasks.named("${baseName}#fullRestartClusterTask") } - -forbiddenPatterns { - exclude '**/*.pdf' - exclude '**/*.jpg' -} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index b97d73ae8f..63c2a7634f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -17,24 +17,19 @@ */ package org.opensearch.ml.rest; -import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.createConnector; -import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.deployRemoteModel; import static org.opensearch.ml.utils.TestHelper.makeRequest; import static org.opensearch.ml.utils.TestHelper.toHttpEntity; import java.nio.file.Files; import java.nio.file.Path; -import java.util.Base64; import java.util.Locale; import java.util.Map; import java.util.Set; -import org.apache.commons.io.FileUtils; import org.apache.http.HttpHeaders; import org.apache.http.message.BasicHeader; import org.apache.http.util.EntityUtils; import org.junit.Before; -import org.junit.Ignore; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.MLTaskState; @@ -44,7 +39,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { +public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT { private static final String OPENAI_KEY = System.getenv("OPENAI_KEY"); private static final String OPENAI_CONNECTOR_BLUEPRINT = "{\n" @@ -75,42 +70,11 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " ]\n" + "}"; - private static final String OPENAI_4o_CONNECTOR_BLUEPRINT = "{\n" - + " \"name\": \"OpenAI Chat Connector\",\n" - + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" - + " \"version\": 2,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"model\": \"gpt-4o-mini\",\n" - + " \"temperature\": 0\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" - + OPENAI_KEY - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/chat/completions\",\n" - + " \"headers\": {\n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages}, \\\"temperature\\\": ${parameters.temperature} , \\\"max_tokens\\\": 300 }\"\n" - + " }\n" - + " ]\n" - + "}"; - private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); private static final String GITHUB_CI_AWS_REGION = "us-west-2"; - private static final String BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0"; - private static final String BEDROCK_ANTHROPIC_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"; - private static final String BEDROCK_CONNECTOR_BLUEPRINT1 = "{\n" + " \"name\": \"Bedrock Connector: claude2\",\n" + " \"description\": \"The connector to bedrock claude2 model\",\n" @@ -181,100 +145,10 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " ]\n" + "}"; - private static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2 = "{\n" - + " \"name\": \"Bedrock Connector: claude 3.5\",\n" - + " \"description\": \"The connector to bedrock claude 3.5 model\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"aws_sigv4\",\n" - + " \"parameters\": {\n" - + " \"region\": \"" - + GITHUB_CI_AWS_REGION - + "\",\n" - + " \"service_name\": \"bedrock\",\n" - + " \"model\": \"" - + BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET - + "\",\n" - + " \"system_prompt\": \"You are a helpful assistant.\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"access_key\": \"" - + AWS_ACCESS_KEY_ID - + "\",\n" - + " \"secret_key\": \"" - + AWS_SECRET_ACCESS_KEY - + "\",\n" - + " \"session_token\": \"" - + AWS_SESSION_TOKEN - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\"\n" - + " },\n" - + " \"url\": \"https://bedrock-runtime." - + GITHUB_CI_AWS_REGION - + ".amazonaws.com/model/" - + BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET - + "/converse\",\n" - + " \"request_body\": \"{ \\\"system\\\": [{\\\"text\\\": \\\"you are a helpful assistant.\\\"}], \\\"messages\\\": ${parameters.messages} , \\\"inferenceConfig\\\": {\\\"temperature\\\": 0.0, \\\"topP\\\": 0.9, \\\"maxTokens\\\": 1000} }\"\n" - + " }\n" - + " ]\n" - + "}"; - - private static final String BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2 = "{\n" - + " \"name\": \"Bedrock Connector: claude 3\",\n" - + " \"description\": \"The connector to bedrock claude 3 model\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"aws_sigv4\",\n" - + " \"parameters\": {\n" - + " \"region\": \"" - + GITHUB_CI_AWS_REGION - + "\",\n" - + " \"service_name\": \"bedrock\",\n" - + " \"model\": \"" - + BEDROCK_ANTHROPIC_CLAUDE_3_SONNET - + "\",\n" - + " \"system_prompt\": \"You are a helpful assistant.\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"access_key\": \"" - + AWS_ACCESS_KEY_ID - + "\",\n" - + " \"secret_key\": \"" - + AWS_SECRET_ACCESS_KEY - + "\",\n" - + " \"session_token\": \"" - + AWS_SESSION_TOKEN - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\"\n" - + " },\n" - + " \"url\": \"https://bedrock-runtime." - + GITHUB_CI_AWS_REGION - + ".amazonaws.com/model/" - + BEDROCK_ANTHROPIC_CLAUDE_3_SONNET - + "/converse\",\n" - + " \"request_body\": \"{ \\\"messages\\\": ${parameters.messages} , \\\"inferenceConfig\\\": {\\\"temperature\\\": 0.0, \\\"topP\\\": 0.9, \\\"maxTokens\\\": 1000} }\"\n" - + " }\n" - + " ]\n" - + "}"; - private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null ? BEDROCK_CONNECTOR_BLUEPRINT2 : BEDROCK_CONNECTOR_BLUEPRINT1; - private static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null - ? BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2 - : BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2; - private static final String COHERE_KEY = System.getenv("COHERE_KEY"); private static final String COHERE_CONNECTOR_BLUEPRINT = "{\n" + " \"name\": \"Cohere Chat Model\",\n" @@ -318,22 +192,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " ]\n" + "}"; - // In some cases, we do not want a system prompt to be sent to an LLM. - private static final String PIPELINE_TEMPLATE2 = "{\n" - + " \"response_processors\": [\n" - + " {\n" - + " \"retrieval_augmented_generation\": {\n" - + " \"tag\": \"%s\",\n" - + " \"description\": \"%s\",\n" - + " \"model_id\": \"%s\",\n" - // + " \"system_prompt\": \"%s\",\n" - + " \"user_instructions\": \"%s\",\n" - + " \"context_field_list\": [\"%s\"]\n" - + " }\n" - + " }\n" - + " ]\n" - + "}"; - private static final String BM25_SEARCH_REQUEST_TEMPLATE = "{\n" + " \"_source\": [\"%s\"],\n" + " \"query\" : {\n" @@ -352,63 +210,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " }\n" + "}"; - private static final String BM25_SEARCH_REQUEST_WITH_IMAGE_TEMPLATE = "{\n" - + " \"_source\": [\"%s\"],\n" - + " \"query\" : {\n" - + " \"match\": {\"%s\": \"%s\"}\n" - + " },\n" - + " \"ext\": {\n" - + " \"generative_qa_parameters\": {\n" - + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" - + " \"system_prompt\": \"%s\",\n" - + " \"user_instructions\": \"%s\",\n" - + " \"context_size\": %d,\n" - + " \"message_size\": %d,\n" - + " \"timeout\": %d,\n" - + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"image\": {\"format\": \"%s\", \"%s\": \"%s\"}}] }]\n" - + " }\n" - + " }\n" - + "}"; - - private static final String BM25_SEARCH_REQUEST_WITH_DOCUMENT_TEMPLATE = "{\n" - + " \"_source\": [\"%s\"],\n" - + " \"query\" : {\n" - + " \"match\": {\"%s\": \"%s\"}\n" - + " },\n" - + " \"ext\": {\n" - + " \"generative_qa_parameters\": {\n" - + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" - // + " \"system_prompt\": \"%s\",\n" - + " \"user_instructions\": \"%s\",\n" - + " \"context_size\": %d,\n" - + " \"message_size\": %d,\n" - + " \"timeout\": %d,\n" - + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"document\": {\"format\": \"%s\", \"name\": \"%s\", \"data\": \"%s\"}}] }]\n" - + " }\n" - + " }\n" - + "}"; - - private static final String BM25_SEARCH_REQUEST_WITH_IMAGE_AND_DOCUMENT_TEMPLATE = "{\n" - + " \"_source\": [\"%s\"],\n" - + " \"query\" : {\n" - + " \"match\": {\"%s\": \"%s\"}\n" - + " },\n" - + " \"ext\": {\n" - + " \"generative_qa_parameters\": {\n" - + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" - + " \"system_prompt\": \"%s\",\n" - + " \"user_instructions\": \"%s\",\n" - + " \"context_size\": %d,\n" - + " \"message_size\": %d,\n" - + " \"timeout\": %d,\n" - + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"image\": {\"format\": \"%s\", \"%s\": \"%s\"}} , {\"document\": {\"format\": \"%s\", \"name\": \"%s\", \"data\": \"%s\"}}] }]\n" - + " }\n" - + " }\n" - + "}"; - private static final String BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE = "{\n" + " \"_source\": [\"%s\"],\n" + " \"query\" : {\n" @@ -428,26 +229,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " }\n" + "}"; - private static final String BM25_SEARCH_REQUEST_WITH_CONVO_AND_IMAGE_TEMPLATE = "{\n" - + " \"_source\": [\"%s\"],\n" - + " \"query\" : {\n" - + " \"match\": {\"%s\": \"%s\"}\n" - + " },\n" - + " \"ext\": {\n" - + " \"generative_qa_parameters\": {\n" - + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" - + " \"memory_id\": \"%s\",\n" - + " \"system_prompt\": \"%s\",\n" - + " \"user_instructions\": \"%s\",\n" - + " \"context_size\": %d,\n" - + " \"message_size\": %d,\n" - + " \"timeout\": %d,\n" - + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"image\": {\"format\": \"%s\", \"%s\": \"%s\"}}] }]\n" - + " }\n" - + " }\n" - + "}"; - private static final String BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE = "{\n" + " \"_source\": [\"%s\"],\n" + " \"query\" : {\n" @@ -466,28 +247,18 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + "}"; private static final String OPENAI_MODEL = "gpt-3.5-turbo"; - private static final String OPENAI_40_MODEL = "gpt-4o-mini"; private static final String BEDROCK_ANTHROPIC_CLAUDE = "bedrock/anthropic-claude"; - private static final String BEDROCK_CONVERSE_ANTHROPIC_CLAUDE = "bedrock-converse/" + BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET; - private static final String BEDROCK_CONVERSE_ANTHROPIC_CLAUDE_3 = "bedrock-converse/" + BEDROCK_ANTHROPIC_CLAUDE_3_SONNET; private static final String TEST_DOC_PATH = "org/opensearch/ml/rest/test_data/"; private static Set testDocs = Set.of("qa_doc1.json", "qa_doc2.json", "qa_doc3.json"); private static final String DEFAULT_USER_AGENT = "Kibana"; protected ClassLoader classLoader = RestMLRAGSearchProcessorIT.class.getClassLoader(); private static final String INDEX_NAME = "test"; - private static final String ML_RAG_REMOTE_MODEL_GROUP = "rag_remote_model_group"; - // "client" gets initialized by the test framework at the instance level // so we perform this per test case, not via @BeforeClass. @Before public void init() throws Exception { - RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); - // TODO Do we really need to wait this long? This adds 20s to every test case run. - // Can we instead check the cluster state and move on? - Thread.sleep(20000); - Response response = TestHelper .makeRequest( client(), @@ -536,11 +307,11 @@ public void testBM25WithOpenAI() throws Exception { Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-3.5 completions", connectorId); + response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); + response = getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -582,94 +353,6 @@ public void testBM25WithOpenAI() throws Exception { assertNotNull(answer); } - @Ignore - public void testBM25WithOpenAIWithImage() throws Exception { - // Skip test if key is null - if (OPENAI_KEY == null) { - return; - } - Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-4o-mini completions", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - - PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithOpenAIWithImage"; - pipelineParameters.description = "desc"; - pipelineParameters.modelId = modelId; - pipelineParameters.systemPrompt = "You are a helpful assistant"; - pipelineParameters.userInstructions = "none"; - pipelineParameters.context_field = "text"; - Response response1 = createSearchPipeline("pipeline_test", pipelineParameters); - assertEquals(200, response1.getStatusLine().getStatusCode()); - - byte[] rawImage = FileUtils - .readFileToByteArray(Path.of(classLoader.getResource(TEST_DOC_PATH + "openai_boardwalk.jpg").toURI()).toFile()); - String imageContent = Base64.getEncoder().encodeToString(rawImage); - - SearchRequestParameters requestParameters = new SearchRequestParameters(); - requestParameters.source = "text"; - requestParameters.match = "president"; - requestParameters.llmModel = OPENAI_40_MODEL; - requestParameters.llmQuestion = "what is this image"; - requestParameters.systemPrompt = "You are great at answering questions"; - requestParameters.userInstructions = "Follow my instructions as best you can"; - requestParameters.contextSize = 5; - requestParameters.interactionSize = 5; - requestParameters.timeout = 60; - requestParameters.imageFormat = "jpeg"; - requestParameters.imageType = "data"; - requestParameters.imageData = imageContent; - Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); - assertEquals(200, response2.getStatusLine().getStatusCode()); - - Map responseMap2 = parseResponseToMap(response2); - Map ext = (Map) responseMap2.get("ext"); - assertNotNull(ext); - Map rag = (Map) ext.get("retrieval_augmented_generation"); - assertNotNull(rag); - - // TODO handle errors such as throttling - String answer = (String) rag.get("answer"); - assertNotNull(answer); - - requestParameters = new SearchRequestParameters(); - requestParameters.source = "text"; - requestParameters.match = "president"; - requestParameters.llmModel = OPENAI_40_MODEL; - requestParameters.llmQuestion = "what is this image"; - requestParameters.systemPrompt = "You are great at answering questions"; - requestParameters.userInstructions = "Follow my instructions as best you can"; - requestParameters.contextSize = 5; - requestParameters.interactionSize = 5; - requestParameters.timeout = 60; - requestParameters.imageFormat = "jpeg"; - requestParameters.imageType = "url"; - requestParameters.imageData = - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"; // imageContent; - Response response3 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); - assertEquals(200, response2.getStatusLine().getStatusCode()); - - Map responseMap3 = parseResponseToMap(response3); - ext = (Map) responseMap2.get("ext"); - assertNotNull(ext); - rag = (Map) ext.get("retrieval_augmented_generation"); - assertNotNull(rag); - - answer = (String) rag.get("answer"); - assertNotNull(answer); - } - public void testBM25WithBedrock() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null) { @@ -678,11 +361,11 @@ public void testBM25WithBedrock() throws Exception { Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId); + response = registerRemoteModel("Bedrock Anthropic Claude", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); + response = getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -691,7 +374,7 @@ public void testBM25WithBedrock() throws Exception { waitForTask(taskId, MLTaskState.COMPLETED); PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithBedrock"; + pipelineParameters.tag = "testBM25WithOpenAI"; pipelineParameters.description = "desc"; pipelineParameters.modelId = modelId; pipelineParameters.systemPrompt = "You are a helpful assistant"; @@ -722,180 +405,6 @@ public void testBM25WithBedrock() throws Exception { assertNotNull(answer); } - @Ignore - public void testBM25WithBedrockConverse() throws Exception { - // Skip test if key is null - if (AWS_ACCESS_KEY_ID == null) { - return; - } - Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - - PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithBedrockConverse"; - pipelineParameters.description = "desc"; - pipelineParameters.modelId = modelId; - pipelineParameters.systemPrompt = "You are a helpful assistant"; - pipelineParameters.userInstructions = "none"; - pipelineParameters.context_field = "text"; - Response response1 = createSearchPipeline("pipeline_test", pipelineParameters); - assertEquals(200, response1.getStatusLine().getStatusCode()); - - SearchRequestParameters requestParameters = new SearchRequestParameters(); - requestParameters.source = "text"; - requestParameters.match = "president"; - requestParameters.llmModel = BEDROCK_CONVERSE_ANTHROPIC_CLAUDE; - requestParameters.llmQuestion = "who is lincoln"; - requestParameters.contextSize = 5; - requestParameters.interactionSize = 5; - requestParameters.timeout = 60; - Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); - assertEquals(200, response2.getStatusLine().getStatusCode()); - - Map responseMap2 = parseResponseToMap(response2); - Map ext = (Map) responseMap2.get("ext"); - assertNotNull(ext); - Map rag = (Map) ext.get("retrieval_augmented_generation"); - assertNotNull(rag); - - // TODO handle errors such as throttling - String answer = (String) rag.get("answer"); - assertNotNull(answer); - } - - @Ignore - public void testBM25WithBedrockConverseUsingLlmMessages() throws Exception { - // Skip test if key is null - if (AWS_ACCESS_KEY_ID == null) { - return; - } - Response response = createConnector(BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - - PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithBedrockConverseUsingLlmMessages"; - pipelineParameters.description = "desc"; - pipelineParameters.modelId = modelId; - pipelineParameters.systemPrompt = "You are a helpful assistant"; - pipelineParameters.userInstructions = "none"; - pipelineParameters.context_field = "text"; - Response response1 = createSearchPipeline("pipeline_test", pipelineParameters); - assertEquals(200, response1.getStatusLine().getStatusCode()); - - byte[] rawImage = FileUtils - .readFileToByteArray(Path.of(classLoader.getResource(TEST_DOC_PATH + "openai_boardwalk.jpg").toURI()).toFile()); - String imageContent = Base64.getEncoder().encodeToString(rawImage); - - SearchRequestParameters requestParameters = new SearchRequestParameters(); - - requestParameters.source = "text"; - requestParameters.match = "president"; - requestParameters.llmModel = BEDROCK_CONVERSE_ANTHROPIC_CLAUDE; - requestParameters.llmQuestion = "describe the image and answer the question: would lincoln have liked this place"; - requestParameters.contextSize = 5; - requestParameters.interactionSize = 5; - requestParameters.timeout = 60; - requestParameters.imageFormat = "jpeg"; - requestParameters.imageType = "data"; // Bedrock does not support URLs - requestParameters.imageData = imageContent; - Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); - assertEquals(200, response2.getStatusLine().getStatusCode()); - - Map responseMap2 = parseResponseToMap(response2); - Map ext = (Map) responseMap2.get("ext"); - assertNotNull(ext); - Map rag = (Map) ext.get("retrieval_augmented_generation"); - assertNotNull(rag); - - // TODO handle errors such as throttling - String answer = (String) rag.get("answer"); - assertNotNull(answer); - } - - @Ignore - public void testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat() throws Exception { - // Skip test if key is null - if (AWS_ACCESS_KEY_ID == null) { - return; - } - Response response = createConnector(BEDROCK_DOCUMENT_CONVERSE_CONNECTOR_BLUEPRINT2); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock Anthropic Claude", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - - PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithBedrockConverseUsingLlmMessagesForDocumentChat"; - pipelineParameters.description = "desc"; - pipelineParameters.modelId = modelId; - // pipelineParameters.systemPrompt = "You are a helpful assistant"; - pipelineParameters.userInstructions = "none"; - pipelineParameters.context_field = "text"; - Response response1 = createSearchPipeline2("pipeline_test", pipelineParameters); - assertEquals(200, response1.getStatusLine().getStatusCode()); - - byte[] docBytes = FileUtils.readFileToByteArray(Path.of(classLoader.getResource(TEST_DOC_PATH + "lincoln.pdf").toURI()).toFile()); - String docContent = Base64.getEncoder().encodeToString(docBytes); - - SearchRequestParameters requestParameters; - requestParameters = new SearchRequestParameters(); - requestParameters.source = "text"; - requestParameters.match = "president"; - requestParameters.llmModel = BEDROCK_CONVERSE_ANTHROPIC_CLAUDE_3; - requestParameters.llmQuestion = "use the information from the attached document to tell me something interesting about lincoln"; - requestParameters.contextSize = 5; - requestParameters.interactionSize = 5; - requestParameters.timeout = 60; - requestParameters.documentFormat = "pdf"; - requestParameters.documentName = "lincoln"; - requestParameters.documentData = docContent; - Response response3 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); - assertEquals(200, response3.getStatusLine().getStatusCode()); - - Map responseMap3 = parseResponseToMap(response3); - Map ext = (Map) responseMap3.get("ext"); - assertNotNull(ext); - Map rag = (Map) ext.get("retrieval_augmented_generation"); - assertNotNull(rag); - - // TODO handle errors such as throttling - String answer = (String) rag.get("answer"); - assertNotNull(answer); - } - public void testBM25WithOpenAIWithConversation() throws Exception { // Skip test if key is null if (OPENAI_KEY == null) { @@ -904,11 +413,11 @@ public void testBM25WithOpenAIWithConversation() throws Exception { Response response = createConnector(OPENAI_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-3.5 completions", connectorId); + response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); + response = getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -917,7 +426,7 @@ public void testBM25WithOpenAIWithConversation() throws Exception { waitForTask(taskId, MLTaskState.COMPLETED); PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithOpenAIWithConversation"; + pipelineParameters.tag = "testBM25WithOpenAI"; pipelineParameters.description = "desc"; pipelineParameters.modelId = modelId; pipelineParameters.systemPrompt = "You are a helpful assistant"; @@ -953,68 +462,6 @@ public void testBM25WithOpenAIWithConversation() throws Exception { assertNotNull(interactionId); } - @Ignore - public void testBM25WithOpenAIWithConversationAndImage() throws Exception { - // Skip test if key is null - if (OPENAI_KEY == null) { - return; - } - Response response = createConnector(OPENAI_4o_CONNECTOR_BLUEPRINT); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "openAI-GPT-4 completions", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - - PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithOpenAIWithConversationAndImage"; - pipelineParameters.description = "desc"; - pipelineParameters.modelId = modelId; - pipelineParameters.systemPrompt = "You are a helpful assistant"; - pipelineParameters.userInstructions = "none"; - pipelineParameters.context_field = "text"; - Response response1 = createSearchPipeline("pipeline_test", pipelineParameters); - assertEquals(200, response1.getStatusLine().getStatusCode()); - - String conversationId = createConversation("test_convo_1"); - SearchRequestParameters requestParameters = new SearchRequestParameters(); - requestParameters.source = "text"; - requestParameters.match = "president"; - requestParameters.llmModel = OPENAI_40_MODEL; - requestParameters.llmQuestion = "describe the image and answer the question: can you picture lincoln enjoying himself there"; - requestParameters.contextSize = 5; - requestParameters.interactionSize = 5; - requestParameters.timeout = 60; - requestParameters.conversationId = conversationId; - requestParameters.imageFormat = "jpeg"; - requestParameters.imageType = "url"; - requestParameters.imageData = - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"; - Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters); - assertEquals(200, response2.getStatusLine().getStatusCode()); - - Map responseMap2 = parseResponseToMap(response2); - Map ext = (Map) responseMap2.get("ext"); - assertNotNull(ext); - Map rag = (Map) ext.get("retrieval_augmented_generation"); - assertNotNull(rag); - - // TODO handle errors such as throttling - String answer = (String) rag.get("answer"); - assertNotNull(answer); - - String interactionId = (String) rag.get("message_id"); - assertNotNull(interactionId); - } - public void testBM25WithBedrockWithConversation() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null) { @@ -1023,11 +470,11 @@ public void testBM25WithBedrockWithConversation() throws Exception { Response response = createConnector(BEDROCK_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Bedrock", connectorId); + response = registerRemoteModel("Bedrock", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); + response = getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -1036,7 +483,7 @@ public void testBM25WithBedrockWithConversation() throws Exception { waitForTask(taskId, MLTaskState.COMPLETED); PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithBedrockWithConversation"; + pipelineParameters.tag = "testBM25WithBedrock"; pipelineParameters.description = "desc"; pipelineParameters.modelId = modelId; pipelineParameters.systemPrompt = "You are a helpful assistant"; @@ -1080,11 +527,11 @@ public void testBM25WithCohere() throws Exception { Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Cohere Chat Completion v1", connectorId); + response = registerRemoteModel("Cohere Chat Completion v1", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); + response = getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -1132,11 +579,11 @@ public void testBM25WithCohereUsingLlmResponseField() throws Exception { Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = RestMLRemoteInferenceIT.registerRemoteModel(ML_RAG_REMOTE_MODEL_GROUP, "Cohere Chat Completion v1", connectorId); + response = registerRemoteModel("Cohere Chat Completion v1", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - response = RestMLRemoteInferenceIT.getTask(taskId); + response = getTask(taskId); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); @@ -1145,7 +592,7 @@ public void testBM25WithCohereUsingLlmResponseField() throws Exception { waitForTask(taskId, MLTaskState.COMPLETED); PipelineParameters pipelineParameters = new PipelineParameters(); - pipelineParameters.tag = "testBM25WithCohereUsingLlmResponseField"; + pipelineParameters.tag = "testBM25WithCohereLlmResponseField"; pipelineParameters.description = "desc"; pipelineParameters.modelId = modelId; pipelineParameters.systemPrompt = "You are a helpful assistant"; @@ -1200,33 +647,9 @@ private Response createSearchPipeline(String pipeline, PipelineParameters parame ); } - // No system prompt - private Response createSearchPipeline2(String pipeline, PipelineParameters parameters) throws Exception { - return makeRequest( - client(), - "PUT", - String.format(Locale.ROOT, "/_search/pipeline/%s", pipeline), - null, - toHttpEntity( - String - .format( - Locale.ROOT, - PIPELINE_TEMPLATE2, - parameters.tag, - parameters.description, - parameters.modelId, - parameters.userInstructions, - parameters.context_field - ) - ), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) - ); - } - private Response performSearch(String indexName, String pipeline, int size, SearchRequestParameters requestParameters) throws Exception { - // TODO build these templates dynamically String httpEntity = requestParameters.llmResponseField != null ? String .format( @@ -1242,90 +665,6 @@ private Response performSearch(String indexName, String pipeline, int size, Sear requestParameters.timeout, requestParameters.llmResponseField ) - : (requestParameters.documentData != null && requestParameters.imageType != null) - ? String - .format( - Locale.ROOT, - BM25_SEARCH_REQUEST_WITH_IMAGE_AND_DOCUMENT_TEMPLATE, - requestParameters.source, - requestParameters.source, - requestParameters.match, - requestParameters.llmModel, - requestParameters.llmQuestion, - requestParameters.systemPrompt, - requestParameters.userInstructions, - requestParameters.contextSize, - requestParameters.interactionSize, - requestParameters.timeout, - requestParameters.llmQuestion, - requestParameters.imageFormat, - requestParameters.imageType, - requestParameters.imageData, - requestParameters.documentFormat, - requestParameters.documentName, - requestParameters.documentData - ) - : (requestParameters.documentData != null) - ? String - .format( - Locale.ROOT, - BM25_SEARCH_REQUEST_WITH_DOCUMENT_TEMPLATE, - requestParameters.source, - requestParameters.source, - requestParameters.match, - requestParameters.llmModel, - requestParameters.llmQuestion, - // requestParameters.systemPrompt, - requestParameters.userInstructions, - requestParameters.contextSize, - requestParameters.interactionSize, - requestParameters.timeout, - requestParameters.llmQuestion, - requestParameters.documentFormat, - requestParameters.documentName, - requestParameters.documentData - ) - : (requestParameters.conversationId != null && requestParameters.imageType != null) - ? String - .format( - Locale.ROOT, - BM25_SEARCH_REQUEST_WITH_CONVO_AND_IMAGE_TEMPLATE, - requestParameters.source, - requestParameters.source, - requestParameters.match, - requestParameters.llmModel, - requestParameters.llmQuestion, - requestParameters.conversationId, - requestParameters.systemPrompt, - requestParameters.userInstructions, - requestParameters.contextSize, - requestParameters.interactionSize, - requestParameters.timeout, - requestParameters.llmQuestion, - requestParameters.imageFormat, - requestParameters.imageType, - requestParameters.imageData - ) - : (requestParameters.imageType != null) - ? String - .format( - Locale.ROOT, - BM25_SEARCH_REQUEST_WITH_IMAGE_TEMPLATE, - requestParameters.source, - requestParameters.source, - requestParameters.match, - requestParameters.llmModel, - requestParameters.llmQuestion, - requestParameters.systemPrompt, - requestParameters.userInstructions, - requestParameters.contextSize, - requestParameters.interactionSize, - requestParameters.timeout, - requestParameters.llmQuestion, - requestParameters.imageFormat, - requestParameters.imageType, - requestParameters.imageData - ) : (requestParameters.conversationId == null) ? String .format( @@ -1402,11 +741,5 @@ static class SearchRequestParameters { String conversationId; String llmResponseField; - String imageFormat; - String imageType; - String imageData; - String documentFormat; - String documentName; - String documentData; } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/test_data/lincoln.pdf b/plugin/src/test/resources/org/opensearch/ml/rest/test_data/lincoln.pdf deleted file mode 100644 index 16eddb91fd..0000000000 Binary files a/plugin/src/test/resources/org/opensearch/ml/rest/test_data/lincoln.pdf and /dev/null differ diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/test_data/openai_boardwalk.jpg b/plugin/src/test/resources/org/opensearch/ml/rest/test_data/openai_boardwalk.jpg deleted file mode 100644 index 19fa158886..0000000000 Binary files a/plugin/src/test/resources/org/opensearch/ml/rest/test_data/openai_boardwalk.jpg and /dev/null differ diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index 6e8a5544b3..7b1814c2a5 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -179,8 +179,7 @@ public void processResponseAsync( chatHistory, searchResults, timeout, - params.getLlmResponseField(), - params.getLlmMessages() + params.getLlmResponseField() ), null, llmQuestion, @@ -203,8 +202,7 @@ public void processResponseAsync( chatHistory, searchResults, timeout, - params.getLlmResponseField(), - params.getLlmMessages() + params.getLlmResponseField() ), conversationId, llmQuestion, diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index ba4f1c9b03..01dc97db75 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -18,8 +18,6 @@ package org.opensearch.searchpipelines.questionanswering.generative.ext; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; import java.util.Objects; import org.opensearch.core.ParseField; @@ -32,7 +30,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; -import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import com.google.common.base.Preconditions; @@ -84,8 +81,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { // that contains the chat completion text, i.e. "answer". private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field"); - private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages"); - public static final int SIZE_NULL_VALUE = -1; static { @@ -99,7 +94,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE); PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT); PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD); - PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD); } @Setter @@ -138,10 +132,6 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject { @Getter private String llmResponseField; - @Setter - @Getter - private List llmMessages = new ArrayList<>(); - public GenerativeQAParameters( String conversationId, String llmModel, @@ -152,32 +142,6 @@ public GenerativeQAParameters( Integer interactionSize, Integer timeout, String llmResponseField - ) { - this( - conversationId, - llmModel, - llmQuestion, - systemPrompt, - userInstructions, - contextSize, - interactionSize, - timeout, - llmResponseField, - null - ); - } - - public GenerativeQAParameters( - String conversationId, - String llmModel, - String llmQuestion, - String systemPrompt, - String userInstructions, - Integer contextSize, - Integer interactionSize, - Integer timeout, - String llmResponseField, - List llmMessages ) { this.conversationId = conversationId; this.llmModel = llmModel; @@ -192,9 +156,6 @@ public GenerativeQAParameters( this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize; this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout; this.llmResponseField = llmResponseField; - if (llmMessages != null) { - this.llmMessages.addAll(llmMessages); - } } public GenerativeQAParameters(StreamInput input) throws IOException { @@ -207,7 +168,6 @@ public GenerativeQAParameters(StreamInput input) throws IOException { this.interactionSize = input.readInt(); this.timeout = input.readInt(); this.llmResponseField = input.readOptionalString(); - this.llmMessages.addAll(input.readList(MessageBlock::new)); } @Override @@ -221,8 +181,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params .field(CONTEXT_SIZE.getPreferredName(), this.contextSize) .field(INTERACTION_SIZE.getPreferredName(), this.interactionSize) .field(TIMEOUT.getPreferredName(), this.timeout) - .field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField) - .field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages); + .field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField); } @Override @@ -238,7 +197,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(interactionSize); out.writeInt(timeout); out.writeOptionalString(llmResponseField); - out.writeList(llmMessages); } public static GenerativeQAParameters parse(XContentParser parser) throws IOException { @@ -265,8 +223,4 @@ public boolean equals(Object o) { && (this.timeout == other.getTimeout()) && Objects.equals(this.llmResponseField, other.getLlmResponseField()); } - - public void setMessageBlock(List blockList) { - this.llmMessages = blockList; - } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java index 3202d56455..66c635b211 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java @@ -44,5 +44,4 @@ public class ChatCompletionInput { private String userInstructions; private Llm.ModelProvider modelProvider; private String llmResponseField; - private List llmMessages; } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java index 6793253480..f6cdfec816 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java @@ -75,6 +75,7 @@ protected void setMlClient(MachineLearningInternalClient mlClient) { * @return */ @Override + public void doChatCompletion(ChatCompletionInput chatCompletionInput, ActionListener listener) { MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build(); @@ -112,15 +113,14 @@ protected Map getInputParameters(ChatCompletionInput chatComplet inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); String messages = PromptUtil .getChatCompletionPrompt( - chatCompletionInput.getModelProvider(), chatCompletionInput.getSystemPrompt(), chatCompletionInput.getUserInstructions(), chatCompletionInput.getQuestion(), chatCompletionInput.getChatHistory(), - chatCompletionInput.getContexts(), - chatCompletionInput.getLlmMessages() + chatCompletionInput.getContexts() ); inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); + // log.info("Messages to LLM: {}", messages); } else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK || chatCompletionInput.getModelProvider() == ModelProvider.COHERE || chatCompletionInput.getLlmResponseField() != null) { @@ -136,19 +136,6 @@ protected Map getInputParameters(ChatCompletionInput chatComplet chatCompletionInput.getContexts() ) ); - } else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK_CONVERSE) { - // Bedrock Converse API does not include the system prompt as part of the Messages block. - String messages = PromptUtil - .getChatCompletionPrompt( - chatCompletionInput.getModelProvider(), - null, - chatCompletionInput.getUserInstructions(), - chatCompletionInput.getQuestion(), - chatCompletionInput.getChatHistory(), - chatCompletionInput.getContexts(), - chatCompletionInput.getLlmMessages() - ); - inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); } else { throw new IllegalArgumentException( "Unknown/unsupported model provider: " @@ -157,6 +144,7 @@ protected Map getInputParameters(ChatCompletionInput chatComplet ); } + // log.info("LLM input parameters: {}", inputParameters.toString()); return inputParameters; } @@ -196,20 +184,6 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, } else if (provider == ModelProvider.COHERE) { answerField = "text"; fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField); - } else if (provider == ModelProvider.BEDROCK_CONVERSE) { - Map output = (Map) dataAsMap.get("output"); - Map message = (Map) output.get("message"); - if (message != null) { - List content = (List) message.get("content"); - String answer = (String) ((Map) content.get(0)).get("text"); - answers.add(answer); - } else { - Map error = (Map) output.get("error"); - if (error == null) { - throw new RuntimeException("Unexpected output: " + output); - } - errors.add((String) error.get("message")); - } } else { throw new IllegalArgumentException( "Unknown/unsupported model provider: " + provider + ". You must provide a valid model provider or llm_response_field." diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java index 9318b681d2..1099b1e21f 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java @@ -28,8 +28,7 @@ public interface Llm { enum ModelProvider { OPENAI, BEDROCK, - COHERE, - BEDROCK_CONVERSE + COHERE } void doChatCompletion(ChatCompletionInput input, ActionListener listener); diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java index 24e38ac368..ef9e9948db 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java @@ -29,7 +29,6 @@ public class LlmIOUtil { public static final String BEDROCK_PROVIDER_PREFIX = "bedrock/"; public static final String COHERE_PROVIDER_PREFIX = "cohere/"; - public static final String BEDROCK_CONVERSE__PROVIDER_PREFIX = "bedrock-converse/"; public static ChatCompletionInput createChatCompletionInput( String llmModel, @@ -50,8 +49,7 @@ public static ChatCompletionInput createChatCompletionInput( chatHistory, contexts, timeoutInSeconds, - llmResponseField, - null + llmResponseField ); } @@ -63,8 +61,7 @@ public static ChatCompletionInput createChatCompletionInput( List chatHistory, List contexts, int timeoutInSeconds, - String llmResponseField, - List llmMessages + String llmResponseField ) { Llm.ModelProvider provider = null; if (llmResponseField == null) { @@ -74,8 +71,6 @@ public static ChatCompletionInput createChatCompletionInput( provider = Llm.ModelProvider.BEDROCK; } else if (llmModel.startsWith(COHERE_PROVIDER_PREFIX)) { provider = Llm.ModelProvider.COHERE; - } else if (llmModel.startsWith(BEDROCK_CONVERSE__PROVIDER_PREFIX)) { - provider = Llm.ModelProvider.BEDROCK_CONVERSE; } } } @@ -88,8 +83,7 @@ public static ChatCompletionInput createChatCompletionInput( systemPrompt, userInstructions, provider, - llmResponseField, - llmMessages + llmResponseField ); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlock.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlock.java deleted file mode 100644 index 1dbfd4d13b..0000000000 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlock.java +++ /dev/null @@ -1,325 +0,0 @@ -/* - * Copyright 2023 Aryn - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.opensearch.searchpipelines.questionanswering.generative.llm; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParseException; -import org.opensearch.core.xcontent.XContentParser; - -import com.google.common.base.Preconditions; - -import lombok.Getter; -import lombok.Setter; - -public class MessageBlock implements Writeable, ToXContent { - - private static final String TEXT_BLOCK = "text"; - private static final String IMAGE_BLOCK = "image"; - private static final String DOCUMENT_BLOCK = "document"; - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(this.role); - out.writeList(this.blockList); - } - - public MessageBlock(StreamInput in) throws IOException { - this.role = in.readString(); - Writeable.Reader reader = input -> { - String type = input.readString(); - if (type.equals("text")) { - return new TextBlock(input); - } else if (type.equals("image")) { - return new ImageBlock(input); - } else if (type.equals("document")) { - return new DocumentBlock(input); - } else { - throw new RuntimeException("Unexpected type: " + type); - } - }; - this.blockList = in.readList(reader); - } - - public static MessageBlock fromXContent(XContentParser parser) throws IOException { - if (parser.currentToken() == XContentParser.Token.START_OBJECT) { - return new MessageBlock(parser.map()); - } - throw new XContentParseException(parser.getTokenLocation(), "Expected [START_OBJECT], got " + parser.currentToken()); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("role", this.role); - builder.startArray("content"); - for (AbstractBlock block : this.blockList) { - block.toXContent(builder, params); - } - builder.endArray(); - builder.endObject(); - return builder; - } - - public interface Block { - String getType(); - } - - public static abstract class AbstractBlock implements Block, Writeable, ToXContent { - - @Override - abstract public String getType(); - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new UnsupportedOperationException("Not implemented."); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - throw new UnsupportedOperationException("Not implemented."); - } - } - - public static class TextBlock extends AbstractBlock { - - @Getter - String type = "text"; - - @Getter - @Setter - String text; - - public TextBlock(String text) { - Preconditions.checkNotNull(text, "text cannot be null."); - this.text = text; - } - - public TextBlock(StreamInput in) throws IOException { - this.text = in.readString(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(this.type); - out.writeString(this.text); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - - builder.startObject(); - builder.field("type", "text"); - builder.field("text", this.text); - builder.endObject(); - return builder; - } - } - - public static class ImageBlock extends AbstractBlock { - - @Getter - String type = "image"; - - @Getter - @Setter - String format; - - @Getter - @Setter - String data; - - @Getter - @Setter - String url; - - public ImageBlock(Map imageBlock) { - this.format = (String) imageBlock.get("format"); - Object tmp = imageBlock.get("data"); - if (tmp != null) { - this.data = (String) tmp; - } else { - tmp = imageBlock.get("url"); - if (tmp == null) { - throw new IllegalArgumentException("data or url not found in imageBlock."); - } - this.url = (String) tmp; - } - - } - - public ImageBlock(String format, String data, String url) { - Preconditions.checkNotNull(format, "format cannot be null."); - if (data == null && url == null) { - throw new IllegalArgumentException("data and url cannot both be null."); - } - this.format = format; - this.data = data; - this.url = url; - } - - public ImageBlock(StreamInput in) throws IOException { - format = in.readString(); - data = in.readOptionalString(); - url = in.readOptionalString(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(this.type); - out.writeString(this.format); - out.writeOptionalString(this.data); - out.writeOptionalString(this.url); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - Map imageMap = new HashMap<>(); - imageMap.put("format", this.format); - if (this.data != null) { - imageMap.put("data", this.data); - } else if (this.url != null) { - imageMap.put("url", this.url); - } - builder.field("image", imageMap); - builder.endObject(); - return builder; - } - } - - public static class DocumentBlock extends AbstractBlock { - - @Getter - String type = "document"; - - @Getter - @Setter - String format; - - @Getter - @Setter - String name; - - @Getter - @Setter - String data; - - public DocumentBlock(Map documentBlock) { - Preconditions.checkState(documentBlock.containsKey("format"), "format not found in the document block."); - Preconditions.checkState(documentBlock.containsKey("name"), "name not found in the document block."); - Preconditions.checkState(documentBlock.containsKey("data"), "data not found in the document block"); - - this.format = (String) documentBlock.get("format"); - this.name = (String) documentBlock.get("name"); - this.data = (String) documentBlock.get("data"); - } - - public DocumentBlock(String format, String name, String data) { - Preconditions.checkNotNull(format, "format cannot be null."); - Preconditions.checkNotNull(name, "name cannot be null."); - Preconditions.checkNotNull(data, "data cannot be null."); - - this.format = format; - this.name = name; - this.data = data; - } - - public DocumentBlock(StreamInput in) throws IOException { - format = in.readString(); - name = in.readString(); - data = in.readString(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(this.type); - out.writeString(this.format); - out.writeString(this.name); - out.writeString(this.data); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.startObject("document"); - builder.field("format", this.format); - builder.field("name", this.name); - builder.field("data", this.data); - builder.endObject(); - builder.endObject(); - return builder; - } - } - - @Getter - @Setter - private String role; - - @Getter - @Setter - private List blockList = new ArrayList<>(); - - public MessageBlock() {} - - public MessageBlock(Map map) { - setMessageBlock(map); - } - - public void setMessageBlock(Map message) { - Preconditions.checkNotNull(message, "message cannot be null."); - Preconditions.checkState(message.containsKey("role"), "message must have role."); - Preconditions.checkState(message.containsKey("content"), "message must have content."); - - this.role = (String) message.get("role"); - List> contents = (List) message.get("content"); - - for (Map content : contents) { - if (content.containsKey(TEXT_BLOCK)) { - this.blockList.add(new TextBlock((String) content.get(TEXT_BLOCK))); - } else if (content.containsKey(IMAGE_BLOCK)) { - Map imageBlock = (Map) content.get(IMAGE_BLOCK); - this.blockList.add(new ImageBlock(imageBlock)); - } else if (content.containsKey(DOCUMENT_BLOCK)) { - Map documentBlock = (Map) content.get(DOCUMENT_BLOCK); - this.blockList.add(new DocumentBlock(documentBlock)); - } - } - } - - @Override - public boolean equals(Object o) { - // TODO - return true; - } - - @Override - public int hashCode() { - return Objects.hashCode(this.role) + Objects.hashCode(this.blockList); - } -} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java index 9b875c6f7a..3a8a21614e 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -19,20 +19,15 @@ import java.util.ArrayList; import java.util.Collections; -import java.util.EnumSet; import java.util.List; import java.util.Locale; import org.apache.commons.text.StringEscapeUtils; import org.opensearch.core.common.Strings; import org.opensearch.ml.common.conversation.Interaction; -import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; -import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import com.google.gson.JsonArray; -import com.google.gson.JsonElement; import com.google.gson.JsonObject; import com.google.gson.JsonPrimitive; @@ -66,27 +61,20 @@ public static String getQuestionRephrasingPrompt(String originalQuestion, List chatHistory, - List contexts - ) { - return getChatCompletionPrompt(provider, DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts, null); + public static String getChatCompletionPrompt(String question, List chatHistory, List contexts) { + return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts); } // TODO Currently, this is OpenAI specific. Change this to indicate as such or address it as part of // future prompt template management work. public static String getChatCompletionPrompt( - Llm.ModelProvider provider, String systemPrompt, String userInstructions, String question, List chatHistory, - List contexts, - List llmMessages + List contexts ) { - return buildMessageParameter(provider, systemPrompt, userInstructions, question, chatHistory, contexts, llmMessages); + return buildMessageParameter(systemPrompt, userInstructions, question, chatHistory, contexts); } enum ChatRole { @@ -146,132 +134,37 @@ public static String buildSingleStringPrompt( return bldr.toString(); } - /** - * Message APIs such as OpenAI's Chat Completion API and Anthropic's Messages API - * use an array of messages as input to the LLM and they are better suited for - * multi-modal interactions using text and images. - * - * @param provider - * @param systemPrompt - * @param userInstructions - * @param question - * @param chatHistory - * @param contexts - * @return - */ @VisibleForTesting static String buildMessageParameter( - Llm.ModelProvider provider, String systemPrompt, String userInstructions, String question, List chatHistory, List contexts ) { - return buildMessageParameter(provider, systemPrompt, userInstructions, question, chatHistory, contexts, null); - } - - static String buildMessageParameter( - Llm.ModelProvider provider, - String systemPrompt, - String userInstructions, - String question, - List chatHistory, - List contexts, - List llmMessages - ) { // TODO better prompt template management is needed here. if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) { - // Some model providers such as Anthropic do not allow the system prompt as part of the message body. - userInstructions = DEFAULT_SYSTEM_PROMPT; - } - - MessageArrayBuilder messageArrayBuilder = new MessageArrayBuilder(provider); - - // Build the system prompt (only one per conversation/session) - if (!Strings.isNullOrEmpty(systemPrompt)) { - messageArrayBuilder.startMessage(ChatRole.SYSTEM); - messageArrayBuilder.addTextContent(systemPrompt); - messageArrayBuilder.endMessage(); + systemPrompt = DEFAULT_SYSTEM_PROMPT; } - // Anthropic does not allow two consecutive messages of the same role - // so we combine all user messages and an array of contents. - messageArrayBuilder.startMessage(ChatRole.USER); - boolean lastRoleIsAssistant = false; - if (!Strings.isNullOrEmpty(userInstructions)) { - messageArrayBuilder.addTextContent(userInstructions); - } + JsonArray messageArray = new JsonArray(); + messageArray.addAll(getPromptTemplateAsJsonArray(systemPrompt, userInstructions)); for (int i = 0; i < contexts.size(); i++) { - messageArrayBuilder.addTextContent("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i)); + messageArray.add(new Message(ChatRole.USER, "SEARCH RESULT " + (i + 1) + ": " + contexts.get(i)).toJson()); } - if (!chatHistory.isEmpty()) { // The oldest interaction first - int idx = chatHistory.size() - 1; - Interaction firstInteraction = chatHistory.get(idx); - messageArrayBuilder.addTextContent(firstInteraction.getInput()); - messageArrayBuilder.endMessage(); - messageArrayBuilder.startMessage(ChatRole.ASSISTANT, firstInteraction.getResponse()); - messageArrayBuilder.endMessage(); - - if (chatHistory.size() > 1) { - for (int i = --idx; i >= 0; i--) { - Interaction interaction = chatHistory.get(i); - messageArrayBuilder.startMessage(ChatRole.USER, interaction.getInput()); - messageArrayBuilder.endMessage(); - messageArrayBuilder.startMessage(ChatRole.ASSISTANT, interaction.getResponse()); - messageArrayBuilder.endMessage(); - } - } - - lastRoleIsAssistant = true; - } - - if (llmMessages != null && !llmMessages.isEmpty()) { - // TODO MessageBlock can have assistant roles for few-shot prompting. - if (lastRoleIsAssistant) { - messageArrayBuilder.startMessage(ChatRole.USER); - } - for (MessageBlock message : llmMessages) { - List blockList = message.getBlockList(); - for (MessageBlock.Block block : blockList) { - switch (block.getType()) { - case "text": - messageArrayBuilder.addTextContent(((MessageBlock.TextBlock) block).getText()); - break; - case "image": - MessageBlock.ImageBlock ib = (MessageBlock.ImageBlock) block; - if (ib.getData() != null) { - messageArrayBuilder.addImageData(ib.getFormat(), ib.getData()); - } else if (ib.getUrl() != null) { - messageArrayBuilder.addImageUrl(ib.getFormat(), ib.getUrl()); - } - break; - case "document": - MessageBlock.DocumentBlock db = (MessageBlock.DocumentBlock) block; - messageArrayBuilder.addDocumentContent(db.getFormat(), db.getName(), db.getData()); - break; - default: - break; - } - } - } - } else { - if (lastRoleIsAssistant) { - messageArrayBuilder.startMessage(ChatRole.USER, "QUESTION: " + question + "\n"); - } else { - messageArrayBuilder.addTextContent("QUESTION: " + question + "\n"); - } - messageArrayBuilder.addTextContent("ANSWER:"); + List messages = Messages.fromInteractions(chatHistory).getMessages(); + Collections.reverse(messages); + messages.forEach(m -> messageArray.add(m.toJson())); } + messageArray.add(new Message(ChatRole.USER, "QUESTION: " + question).toJson()); + messageArray.add(new Message(ChatRole.USER, "ANSWER:").toJson()); - messageArrayBuilder.endMessage(); - - return messageArrayBuilder.toJsonArray().toString(); + return messageArray.toString(); } public static String getPromptTemplate(String systemPrompt, String userInstructions) { @@ -290,24 +183,6 @@ static JsonArray getPromptTemplateAsJsonArray(String systemPrompt, String userIn return messageArray; } - /* - static JsonArray getPromptTemplateAsJsonArray(Llm.ModelProvider provider, String systemPrompt, String userInstructions) { - - MessageArrayBuilder bldr = new MessageArrayBuilder(provider); - - if (!Strings.isNullOrEmpty(systemPrompt)) { - bldr.startMessage(ChatRole.SYSTEM); - bldr.addTextContent(systemPrompt); - bldr.endMessage(); - } - if (!Strings.isNullOrEmpty(userInstructions)) { - bldr.startMessage(ChatRole.USER); - bldr.addTextContent(userInstructions); - bldr.endMessage(); - } - return bldr.toJsonArray(); - }*/ - @Getter static class Messages { @@ -334,207 +209,6 @@ public static Messages fromInteractions(final List interactions) { } } - interface Content { - - // All content blocks accept text - void addText(String text); - - JsonElement toJson(); - } - - interface ImageContent extends Content { - - void addImageData(String format, String data); - - void addImageUrl(String format, String url); - } - - interface DocumentContent extends Content { - void addDocument(String format, String name, String data); - } - - interface MultimodalContent extends ImageContent, DocumentContent { - - } - - private final static String CONTENT_FIELD_TEXT = "text"; - private final static String CONTENT_FIELD_TYPE = "type"; - - static class OpenAIContent implements ImageContent { - - private JsonArray json; - - public OpenAIContent() { - this.json = new JsonArray(); - } - - @Override - public void addText(String text) { - JsonObject content = new JsonObject(); - content.add(CONTENT_FIELD_TYPE, new JsonPrimitive(CONTENT_FIELD_TEXT)); - content.add(CONTENT_FIELD_TEXT, new JsonPrimitive(text)); - json.add(content); - } - - @Override - public void addImageData(String format, String data) { - JsonObject content = new JsonObject(); - content.add("type", new JsonPrimitive("image_url")); - JsonObject urlContent = new JsonObject(); - String imageData = String.format(Locale.ROOT, "data:image/%s;base64,%s", format, data); - urlContent.add("url", new JsonPrimitive(imageData)); - content.add("image_url", urlContent); - json.add(content); - } - - @Override - public void addImageUrl(String format, String url) { - JsonObject content = new JsonObject(); - content.add("type", new JsonPrimitive("image_url")); - JsonObject urlContent = new JsonObject(); - urlContent.add("url", new JsonPrimitive(url)); - content.add("image_url", urlContent); - json.add(content); - } - - @Override - public JsonElement toJson() { - return this.json; - } - } - - static class BedrockContent implements MultimodalContent { - - private JsonArray json; - - public BedrockContent() { - this.json = new JsonArray(); - } - - public BedrockContent(String type, String value) { - this.json = new JsonArray(); - if (type.equals("text")) { - addText(value); - } - } - - @Override - public void addText(String text) { - JsonObject content = new JsonObject(); - content.add(CONTENT_FIELD_TEXT, new JsonPrimitive(text)); - json.add(content); - } - - @Override - public JsonElement toJson() { - return this.json; - } - - @Override - public void addImageData(String format, String data) { - JsonObject imageData = new JsonObject(); - imageData.add("bytes", new JsonPrimitive(data)); - JsonObject image = new JsonObject(); - image.add("format", new JsonPrimitive(format)); - image.add("source", imageData); - JsonObject content = new JsonObject(); - content.add("image", image); - json.add(content); - } - - @Override - public void addImageUrl(String format, String url) { - // Bedrock does not support image URLs. - } - - @Override - public void addDocument(String format, String name, String data) { - JsonObject documentData = new JsonObject(); - documentData.add("bytes", new JsonPrimitive(data)); - JsonObject document = new JsonObject(); - document.add("format", new JsonPrimitive(format)); - document.add("name", new JsonPrimitive(name)); - document.add("source", documentData); - JsonObject content = new JsonObject(); - content.add("document", document); - json.add(content); - } - } - - static class MessageArrayBuilder { - - private final Llm.ModelProvider provider; - private List messages = new ArrayList<>(); - private Message message = null; - private Content content = null; - - public MessageArrayBuilder(Llm.ModelProvider provider) { - // OpenAI or Bedrock Converse API - if (!EnumSet.of(Llm.ModelProvider.OPENAI, Llm.ModelProvider.BEDROCK_CONVERSE).contains(provider)) { - throw new IllegalArgumentException("Unsupported provider: " + provider); - } - this.provider = provider; - } - - public void startMessage(ChatRole role) { - this.message = new Message(); - this.message.setChatRole(role); - if (this.provider == Llm.ModelProvider.OPENAI) { - content = new OpenAIContent(); - } else if (this.provider == Llm.ModelProvider.BEDROCK_CONVERSE) { - content = new BedrockContent(); - } - } - - public void startMessage(ChatRole role, String text) { - startMessage(role); - addTextContent(text); - } - - public void endMessage() { - this.message.setContent(this.content); - this.messages.add(this.message); - message = null; - content = null; - } - - public void addTextContent(String content) { - if (this.message == null || this.content == null) { - throw new RuntimeException("You must call startMessage before calling addTextContent !!"); - } - this.content.addText(content); - } - - public void addImageData(String format, String data) { - if (this.content != null && this.content instanceof ImageContent) { - ((ImageContent) this.content).addImageData(format, data); - } - } - - public void addImageUrl(String format, String url) { - if (this.content != null && this.content instanceof ImageContent) { - ((ImageContent) this.content).addImageUrl(format, url); - } - } - - public void addDocumentContent(String format, String name, String data) { - if (this.content != null && this.content instanceof DocumentContent) { - ((DocumentContent) this.content).addDocument(format, name, data); - } - } - - public JsonArray toJsonArray() { - Preconditions - .checkState(this.message == null && this.content == null, "You must call endMessage before calling toJsonArray !!"); - - JsonArray ja = new JsonArray(); - for (Message message : messages) { - ja.add(message.toJson()); - } - return ja; - } - } - // TODO This is OpenAI specific. Either change this to OpenAiMessage or have it handle // vendor specific messages. static class Message { @@ -559,12 +233,6 @@ public Message(ChatRole chatRole, String content) { setContent(content); } - public Message(ChatRole chatRole, Content content) { - this(); - setChatRole(chatRole); - setContent(content); - } - public void setChatRole(ChatRole chatRole) { this.chatRole = chatRole; json.remove(MESSAGE_FIELD_ROLE); @@ -577,11 +245,6 @@ public void setContent(String content) { json.add(MESSAGE_FIELD_CONTENT, new JsonPrimitive(this.content)); } - public void setContent(Content content) { - json.remove(MESSAGE_FIELD_CONTENT); - json.add(MESSAGE_FIELD_CONTENT, content.toJson()); - } - public JsonObject toJson() { return json; } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 23eb6f3d3a..49f164cdb5 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -25,8 +25,6 @@ import java.io.EOFException; import java.io.IOException; -import java.util.List; -import java.util.Map; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; @@ -35,22 +33,10 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentHelper; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import org.opensearch.test.OpenSearchTestCase; public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase { - private List messageList = null; - - public GenerativeQAParamExtBuilderTests() { - Map imageMap = Map.of("image", Map.of("format", "jpg", "url", "https://xyz.com/file.jpg")); - Map textMap = Map.of("text", "what is this"); - Map contentMap = Map.of(); - Map map = Map.of("role", "user", "content", List.of(textMap, imageMap)); - MessageBlock mb = new MessageBlock(map); - messageList = List.of(mb); - } - public void testCtor() throws IOException { GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder(); GenerativeQAParameters parameters = new GenerativeQAParameters( @@ -129,7 +115,7 @@ public void testParse() throws IOException { } public void testXContentRoundTrip() throws IOException { - GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null, messageList); + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); XContentType xContentType = randomFrom(XContentType.values()); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java index e5caa70ed7..c36dcdb2a5 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -24,7 +24,6 @@ import java.io.OutputStream; import java.util.ArrayList; import java.util.List; -import java.util.Map; import org.opensearch.action.search.SearchRequest; import org.opensearch.core.common.io.stream.StreamOutput; @@ -32,22 +31,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentGenerator; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import org.opensearch.test.OpenSearchTestCase; public class GenerativeQAParametersTests extends OpenSearchTestCase { - private List messageList = null; - - public GenerativeQAParametersTests() { - Map imageMap = Map.of("image", Map.of("format", "jpg", "url", "https://xyz.com/file.jpg")); - Map textMap = Map.of("text", "what is this"); - Map contentMap = Map.of(); - Map map = Map.of("role", "user", "content", List.of(textMap, imageMap)); - MessageBlock mb = new MessageBlock(map); - messageList = List.of(mb); - } - public void testGenerativeQAParameters() { GenerativeQAParameters params = new GenerativeQAParameters( "conversation_id", @@ -68,29 +55,6 @@ public void testGenerativeQAParameters() { assertEquals(params, actual); } - public void testGenerativeQAParametersWithLlmMessages() { - - GenerativeQAParameters params = new GenerativeQAParameters( - "conversation_id", - "llm_model", - "llm_question", - "system_prompt", - "user_instructions", - null, - null, - null, - null, - this.messageList - ); - GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); - extBuilder.setParams(params); - SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); - SearchRequest request = new SearchRequest("my_index").source(srcBulder); - GenerativeQAParameters actual = GenerativeQAParamUtil.getGenerativeQAParameters(request); - // MessageBlock messageBlock = actual.getMessageBlock(); - assertEquals(params, actual); - } - static class DummyStreamOutput extends StreamOutput { List list = new ArrayList<>(); @@ -98,7 +62,6 @@ static class DummyStreamOutput extends StreamOutput { @Override public void writeString(String str) { - System.out.println("Adding string: " + str); list.add(str); } @@ -160,13 +123,12 @@ public void testWriteTo() throws IOException { contextSize, interactionSize, timeout, - llmResponseField, - messageList + llmResponseField ); StreamOutput output = new DummyStreamOutput(); parameters.writeTo(output); List actual = ((DummyStreamOutput) output).getList(); - assertEquals(12, actual.size()); + assertEquals(6, actual.size()); assertEquals(conversationId, actual.get(0)); assertEquals(llmModel, actual.get(1)); assertEquals(llmQuestion, actual.get(2)); @@ -228,8 +190,7 @@ public void testToXConent() throws IOException { null, null, null, - null, - messageList + null ); XContent xc = mock(XContent.class); OutputStream os = mock(OutputStream.class); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java index d70739b8cd..f3a4bf8284 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java @@ -43,7 +43,6 @@ public void testCtor() { systemPrompt, userInstructions, Llm.ModelProvider.OPENAI, - null, null ); @@ -82,7 +81,6 @@ public void testGettersSetters() { systemPrompt, userInstructions, Llm.ModelProvider.OPENAI, - null, null ); assertEquals(model, input.getModel()); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java index 5e5f72b59a..2dc06366f8 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -93,7 +93,7 @@ public void testBuildMessageParameter() { ) ) ); - String parameter = PromptUtil.getChatCompletionPrompt(Llm.ModelProvider.BEDROCK_CONVERSE, question, chatHistory, contexts); + String parameter = PromptUtil.getChatCompletionPrompt(question, chatHistory, contexts); Map parameters = Map.of("model", "foo", "messages", parameter); assertTrue(isJson(parameter)); } @@ -120,7 +120,6 @@ public void testChatCompletionApi() throws Exception { "prompt", "instructions", Llm.ModelProvider.OPENAI, - null, null ); doAnswer(invocation -> { @@ -165,56 +164,6 @@ public void testChatCompletionApiForBedrock() throws Exception { "prompt", "instructions", Llm.ModelProvider.BEDROCK, - null, - null - ); - doAnswer(invocation -> { - ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); - return null; - }).when(mlClient).predict(any(), any(), any()); - connector.doChatCompletion(input, new ActionListener<>() { - @Override - public void onResponse(ChatCompletionOutput output) { - assertEquals("answer", output.getAnswers().get(0)); - } - - @Override - public void onFailure(Exception e) { - - } - }); - verify(mlClient, times(1)).predict(any(), captor.capture(), any()); - MLInput mlInput = captor.getValue(); - assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); - } - - public void testMessageApiForBedrockConverse() throws Exception { - MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); - ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); - DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); - connector.setMlClient(mlClient); - - Map messageMap = Map.of("role", "agent", "content", "answer"); - Map text = Map.of("text", "answer"); - List list = List.of(text); - Map content = Map.of("content", list); - Map message = Map.of("message", content); - Map dataAsMap = Map.of("output", message); - ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); - ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); - ActionFuture future = mock(ActionFuture.class); - when(future.actionGet(anyLong())).thenReturn(mlOutput); - when(mlClient.predict(any(), any())).thenReturn(future); - ChatCompletionInput input = new ChatCompletionInput( - "bedrock-converse/model", - "question", - Collections.emptyList(), - Collections.emptyList(), - 0, - "prompt", - "instructions", - Llm.ModelProvider.BEDROCK_CONVERSE, - null, null ); doAnswer(invocation -> { @@ -259,7 +208,6 @@ public void testChatCompletionApiForCohere() throws Exception { "prompt", "instructions", Llm.ModelProvider.COHERE, - null, null ); doAnswer(invocation -> { @@ -305,7 +253,6 @@ public void testChatCompletionApiForCohereWithError() throws Exception { "prompt", "instructions", Llm.ModelProvider.COHERE, - null, null ); doAnswer(invocation -> { @@ -353,8 +300,7 @@ public void testChatCompletionApiForFoo() throws Exception { "prompt", "instructions", null, - llmRespondField, - null + llmRespondField ); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); @@ -401,8 +347,7 @@ public void testChatCompletionApiForFooWithError() throws Exception { "prompt", "instructions", null, - llmRespondField, - null + llmRespondField ); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); @@ -450,8 +395,7 @@ public void testChatCompletionApiForFooWithErrorUnknownMessageField() throws Exc "prompt", "instructions", null, - llmRespondField, - null + llmRespondField ); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); @@ -499,8 +443,7 @@ public void testChatCompletionApiForFooWithErrorUnknownErrorField() throws Excep "prompt", "instructions", null, - llmRespondField, - null + llmRespondField ); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); @@ -546,7 +489,6 @@ public void testChatCompletionThrowingError() throws Exception { "prompt", "instructions", Llm.ModelProvider.OPENAI, - null, null ); @@ -594,7 +536,6 @@ public void testChatCompletionBedrockThrowingError() throws Exception { "prompt", "instructions", Llm.ModelProvider.BEDROCK, - null, null ); doAnswer(invocation -> { @@ -644,7 +585,6 @@ public void testIllegalArgument1() { "prompt", "instructions", null, - null, null ); connector.doChatCompletion(input, ActionListener.wrap(r -> {}, e -> {})); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlockTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlockTests.java deleted file mode 100644 index 62c6381b55..0000000000 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/MessageBlockTests.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright 2023 Aryn - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.opensearch.searchpipelines.questionanswering.generative.llm; - -import java.io.IOException; -import java.util.List; -import java.util.Map; - -import org.junit.Rule; -import org.junit.rules.ExpectedException; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParseException; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.test.OpenSearchTestCase; - -public class MessageBlockTests extends OpenSearchTestCase { - - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - - public void testStreamRoundTrip() throws Exception { - MessageBlock.TextBlock tb = new MessageBlock.TextBlock("text"); - MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock("jpeg", "data", null); - MessageBlock.ImageBlock ib2 = new MessageBlock.ImageBlock("jpeg", null, "https://xyz/foo.jpg"); - MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock("pdf", "doc1", "data"); - List blocks = List.of(tb, ib, ib2, db); - MessageBlock mb = new MessageBlock(); - mb.setRole("user"); - mb.setBlockList(blocks); - BytesStreamOutput bso = new BytesStreamOutput(); - mb.writeTo(bso); - MessageBlock read = new MessageBlock(bso.bytes().streamInput()); - assertEquals(mb, read); - } - - public void testFromXContentParseError() throws IOException { - exceptionRule.expect(XContentParseException.class); - - MessageBlock.TextBlock tb = new MessageBlock.TextBlock("text"); - MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock("jpeg", "data", null); - // MessageBlock.ImageBlock ib2 = new MessageBlock.ImageBlock("jpeg", null, "https://xyz/foo.jpg"); - MessageBlock.ImageBlock ib2 = new MessageBlock.ImageBlock(Map.of("format", "png", "data", "xyz")); - MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock("pdf", "doc1", "data"); - List blocks = List.of(tb, ib, ib2, db); - MessageBlock mb = new MessageBlock(); - mb.setRole("user"); - mb.setBlockList(blocks); - try (XContentBuilder builder = XContentBuilder.builder(randomFrom(XContentType.values()).xContent())) { - mb.toXContent(builder, ToXContent.EMPTY_PARAMS); - try (XContentBuilder shuffled = shuffleXContent(builder); XContentParser parser = createParser(shuffled)) { - // read = TaskResult.PARSER.apply(parser, null); - MessageBlock.fromXContent(parser); - } - } finally { - // throw new IOException("Error processing [" + mb + "]", e); - } - } - - public void testInvalidImageBlock1() { - exceptionRule.expect(IllegalArgumentException.class); - MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock(Map.of("format", "png")); - } - - public void testInvalidImageBlock2() { - exceptionRule.expect(IllegalArgumentException.class); - MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock("jpeg", null, null); - } - - public void testInvalidDocumentBlock1() { - exceptionRule.expect(NullPointerException.class); - MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock(null, null, null); - } - - public void testInvalidDocumentBlock2() { - exceptionRule.expect(IllegalStateException.class); - MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock(Map.of()); - } - - public void testDocumentBlockCtor1() { - MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock(Map.of("format", "pdf", "name", "doc", "data", "xyz")); - assertEquals(db.format, "pdf"); - assertEquals(db.name, "doc"); - assertEquals(db.data, "xyz"); - } -} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java index 0d82a18a15..a3aedf4e5d 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java @@ -26,19 +26,12 @@ import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; -import org.junit.Rule; -import org.junit.rules.ExpectedException; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; -import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; -import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import org.opensearch.test.OpenSearchTestCase; public class PromptUtilTests extends OpenSearchTestCase { - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - public void testPromptUtilStaticMethods() { assertNull(PromptUtil.getQuestionRephrasingPrompt("question", Collections.emptyList())); } @@ -79,50 +72,7 @@ public void testBuildMessageParameter() { ); contexts.add("context 1"); contexts.add("context 2"); - String parameter = PromptUtil - .buildMessageParameter(Llm.ModelProvider.BEDROCK_CONVERSE, systemPrompt, userInstructions, question, chatHistory, contexts); - Map parameters = Map.of("model", "foo", "messages", parameter); - assertTrue(isJson(parameter)); - } - - public void testBuildMessageParameterForOpenAI() { - String systemPrompt = "You are the best."; - String userInstructions = null; - String question = "Who am I"; - List contexts = new ArrayList<>(); - List chatHistory = List - .of( - Interaction - .fromMap( - "convo1", - Map - .of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - "message 1", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - "answer1" - ) - ), - Interaction - .fromMap( - "convo1", - Map - .of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - "message 2", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - "answer2" - ) - ) - ); - contexts.add("context 1"); - contexts.add("context 2"); - String parameter = PromptUtil - .buildMessageParameter(Llm.ModelProvider.OPENAI, systemPrompt, userInstructions, question, chatHistory, contexts); + String parameter = PromptUtil.buildMessageParameter(systemPrompt, userInstructions, question, chatHistory, contexts); Map parameters = Map.of("model", "foo", "messages", parameter); assertTrue(isJson(parameter)); } @@ -167,139 +117,6 @@ public void testBuildBedrockInputParameter() { assertTrue(parameter.contains(systemPrompt)); } - public void testBuildBedrockConverseInputParameter() { - String systemPrompt = "You are the best."; - String userInstructions = null; - String question = "Who am I"; - List contexts = new ArrayList<>(); - List chatHistory = List - .of( - Interaction - .fromMap( - "convo1", - Map - .of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - "message 1", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - "answer1" - ) - ), - Interaction - .fromMap( - "convo1", - Map - .of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - "message 2", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - "answer2" - ) - ) - ); - contexts.add("context 1"); - contexts.add("context 2"); - MessageBlock.TextBlock tb = new MessageBlock.TextBlock("text"); - MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock("jpeg", "data", null); - MessageBlock.DocumentBlock db = new MessageBlock.DocumentBlock("pdf", "file1", "data"); - List blocks = List.of(tb, ib, db); - MessageBlock mb = new MessageBlock(); - mb.setBlockList(blocks); - List llmMessages = List.of(mb); - String parameter = PromptUtil - .buildMessageParameter( - Llm.ModelProvider.BEDROCK_CONVERSE, - systemPrompt, - userInstructions, - question, - chatHistory, - contexts, - llmMessages - ); - assertTrue(parameter.contains(systemPrompt)); - } - - public void testBuildOpenAIInputParameter() { - String systemPrompt = "You are the best."; - String userInstructions = null; - String question = "Who am I"; - List contexts = new ArrayList<>(); - List chatHistory = List - .of( - Interaction - .fromMap( - "convo1", - Map - .of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - "message 1", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - "answer1" - ) - ), - Interaction - .fromMap( - "convo1", - Map - .of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - Instant.now().toString(), - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - "message 2", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - "answer2" - ) - ) - ); - contexts.add("context 1"); - contexts.add("context 2"); - MessageBlock.TextBlock tb = new MessageBlock.TextBlock("text"); - MessageBlock.ImageBlock ib = new MessageBlock.ImageBlock("jpeg", "data", null); - MessageBlock.ImageBlock ib2 = new MessageBlock.ImageBlock("jpeg", null, "https://xyz/foo.jpg"); - List blocks = List.of(tb, ib, ib2); - MessageBlock mb = new MessageBlock(); - mb.setBlockList(blocks); - List llmMessages = List.of(mb); - String parameter = PromptUtil - .buildMessageParameter(Llm.ModelProvider.OPENAI, systemPrompt, userInstructions, question, chatHistory, contexts, llmMessages); - assertTrue(parameter.contains(systemPrompt)); - } - - public void testGetPromptTemplate() { - String systemPrompt = "you are a helpful assistant."; - String userInstructions = "lay out your answer as a sequence of steps."; - String actual = PromptUtil.getPromptTemplate(systemPrompt, userInstructions); - assertTrue(actual.contains(systemPrompt)); - assertTrue(actual.contains(userInstructions)); - } - - public void testMessageCtor() { - PromptUtil.Message message = new PromptUtil.Message(PromptUtil.ChatRole.USER, new PromptUtil.OpenAIContent()); - assertEquals(message.getChatRole(), PromptUtil.ChatRole.USER); - } - - public void testBedrockContentCtor() { - PromptUtil.Content content = new PromptUtil.BedrockContent("text", "foo"); - assertTrue(content.toJson().toString().contains("foo")); - } - - public void testMessageArrayBuilderCtor1() { - exceptionRule.expect(IllegalArgumentException.class); - PromptUtil.MessageArrayBuilder builder = new PromptUtil.MessageArrayBuilder(Llm.ModelProvider.COHERE); - } - - public void testMessageArrayBuilderInvalidUsage1() { - exceptionRule.expect(RuntimeException.class); - PromptUtil.MessageArrayBuilder builder = new PromptUtil.MessageArrayBuilder(Llm.ModelProvider.OPENAI); - builder.addTextContent("boom"); - } - private boolean isJson(String Json) { try { new JSONObject(Json);