Skip to content

Commit

Permalink
Add test coverage.
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee committed Mar 21, 2024
1 parent bb6f5c2 commit 8bea195
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,8 @@ public void testChatCompletionBedrockThrowingError() throws Exception {

public void testIllegalArgument1() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Unknown/unsupported model provider: null. You must provide a valid model provider or llm_response_field.");
exceptionRule
.expectMessage("Unknown/unsupported model provider: null. You must provide a valid model provider or llm_response_field.");
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
Expand Down Expand Up @@ -459,7 +460,8 @@ public void testIllegalArgument1() {

public void testIllegalArgument2() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Unknown/unsupported model provider: null. You must provide a valid model provider or llm_response_field.");
exceptionRule
.expectMessage("Unknown/unsupported model provider: null. You must provide a valid model provider or llm_response_field.");
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,34 @@ public void testChatCompletionInput() {
ChatCompletionInput input = LlmIOUtil
.createChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList(), 0, null);
assertTrue(input instanceof ChatCompletionInput);
assertEquals(Llm.ModelProvider.OPENAI, input.getModelProvider());
}

public void testChatCompletionInputForBedrock() {
ChatCompletionInput input = LlmIOUtil
.createChatCompletionInput("bedrock/model", "question", Collections.emptyList(), Collections.emptyList(), 0, null);
.createChatCompletionInput(
LlmIOUtil.BEDROCK_PROVIDER_PREFIX + "model",
"question",
Collections.emptyList(),
Collections.emptyList(),
0,
null
);
assertTrue(input instanceof ChatCompletionInput);
assertEquals(Llm.ModelProvider.BEDROCK, input.getModelProvider());
}

public void testChatCompletionInputForCohere() {
ChatCompletionInput input = LlmIOUtil
.createChatCompletionInput(
LlmIOUtil.COHERE_PROVIDER_PREFIX + "model",
"question",
Collections.emptyList(),
Collections.emptyList(),
0,
null
);
assertTrue(input instanceof ChatCompletionInput);
assertEquals(Llm.ModelProvider.COHERE, input.getModelProvider());
}
}

0 comments on commit 8bea195

Please sign in to comment.