From 8bea195f0a0dead06be4d101abf7eca44e2ba6f7 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Wed, 20 Mar 2024 18:35:26 -0700 Subject: [PATCH] Add test coverage. Signed-off-by: Austin Lee --- .../generative/llm/DefaultLlmImplTests.java | 6 +++-- .../generative/llm/LlmIOUtilTests.java | 25 ++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java index 6a0cc968cb..5a3978539c 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -429,7 +429,8 @@ public void testChatCompletionBedrockThrowingError() throws Exception { public void testIllegalArgument1() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Unknown/unsupported model provider: null. You must provide a valid model provider or llm_response_field."); + exceptionRule + .expectMessage("Unknown/unsupported model provider: null. You must provide a valid model provider or llm_response_field."); MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); @@ -459,7 +460,8 @@ public void testIllegalArgument1() { public void testIllegalArgument2() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Unknown/unsupported model provider: null. You must provide a valid model provider or llm_response_field."); + exceptionRule + .expectMessage("Unknown/unsupported model provider: null. You must provide a valid model provider or llm_response_field."); MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java index 18f49da42b..2f8fb0fca2 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java @@ -31,11 +31,34 @@ public void testChatCompletionInput() { ChatCompletionInput input = LlmIOUtil .createChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList(), 0, null); assertTrue(input instanceof ChatCompletionInput); + assertEquals(Llm.ModelProvider.OPENAI, input.getModelProvider()); } public void testChatCompletionInputForBedrock() { ChatCompletionInput input = LlmIOUtil - .createChatCompletionInput("bedrock/model", "question", Collections.emptyList(), Collections.emptyList(), 0, null); + .createChatCompletionInput( + LlmIOUtil.BEDROCK_PROVIDER_PREFIX + "model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + null + ); assertTrue(input instanceof ChatCompletionInput); + assertEquals(Llm.ModelProvider.BEDROCK, input.getModelProvider()); + } + + public void testChatCompletionInputForCohere() { + ChatCompletionInput input = LlmIOUtil + .createChatCompletionInput( + LlmIOUtil.COHERE_PROVIDER_PREFIX + "model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + null + ); + assertTrue(input instanceof ChatCompletionInput); + assertEquals(Llm.ModelProvider.COHERE, input.getModelProvider()); } }