Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide a client for Azure OpenAI GPT model service #467

Merged
merged 11 commits into from
Jun 29, 2023
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.microsoft.hydralab.center.openai;

import com.alibaba.fastjson.JSON;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Objects;

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
public class AzureOpenAIServiceClient {
hydraxman marked this conversation as resolved.
Show resolved Hide resolved
private final Logger logger = LoggerFactory.getLogger(AzureOpenAIServiceClient.class);
private final String apiKey;
private final String endpoint;
private final String deployment;
private final String apiVersion;
private final OkHttpClient client = new OkHttpClient();

public AzureOpenAIServiceClient(String apiKey, String deployment, String endpoint, String apiVersion) {
this.apiKey = apiKey;
this.endpoint = endpoint.endsWith("/") ? endpoint.substring(0, endpoint.length() - 1) : endpoint;
this.deployment = deployment;
this.apiVersion = apiVersion;
}

public String chatCompletion(ChatRequest request) {
MediaType mediaType = MediaType.parse("application/json");
String url = String.format("%s/openai/deployments/%s/chat/completions?api-version=%s", endpoint, deployment, apiVersion);

String requestBodyString = JSON.toJSONString(request);
logger.info("Request body: {}", requestBodyString);

RequestBody body = RequestBody.create(requestBodyString, mediaType);
Request httpRequest = new Request.Builder().url(url).post(body)
// .addHeader("Content-Type", "application/json")
hydraxman marked this conversation as resolved.
Show resolved Hide resolved
.addHeader("api-key", apiKey).build();

try (Response response = client.newCall(httpRequest).execute()) {
if (!response.isSuccessful()) {
throw new RuntimeException("Unexpected response code: " + response);
}
return Objects.requireNonNull(response.body()).string();
} catch (Exception e) {
throw new RuntimeException("Error occurred while invoking Azure OpenAI API: " + e.getMessage(), e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.microsoft.hydralab.center.openai;

import lombok.Data;

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
@Data
hydraxman marked this conversation as resolved.
Show resolved Hide resolved
public class ChatMessage {
private String role;
private String content;

public ChatMessage(String role, String content) {
this.role = role;
this.content = content;
}

@SuppressWarnings("InterfaceIsType")
public interface Role {
String USER = "user";
String SYSTEM = "system";
String ASSISTANT = "assistant";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.microsoft.hydralab.center.openai;

import com.alibaba.fastjson.annotation.JSONField;
import lombok.Data;

import java.util.List;

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
@Data
hydraxman marked this conversation as resolved.
Show resolved Hide resolved
public class ChatRequest {
private List<ChatMessage> messages;
@JSONField(name = "max_tokens")
private int maxTokens = 800;
private double temperature = 0.75;
@JSONField(name = "frequency_penalty")
private double frequencyPenalty = 0;
@JSONField(name = "presence_penalty")
private double presencePenalty = 0;
@JSONField(name = "top_p")
private double topP = 0.95;
private String stop;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.microsoft.hydralab.center.openai;

import com.microsoft.hydralab.center.test.BaseTest;
import org.junit.jupiter.api.Test;
import org.junit.platform.commons.util.StringUtils;
import org.springframework.beans.factory.annotation.Value;

import java.util.Arrays;

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
public class AzureOpenAIServiceClientTest extends BaseTest {
@Value("${app.openai.chat-completion.api-key:}")
private String apiKey;
@Value("${app.openai.chat-completion.endpoint-url:}")
private String endpoint;
@Value("${app.openai.chat-completion.deployment:}")
private String deployment;
@Value("${app.openai.chat-completion.api-version:}")
private String apiVersion;

@Test
public void createAzureOpenAIServiceClientAndAsk() {
if (StringUtils.isBlank(apiKey)) {
return;
}
AzureOpenAIServiceClient azureOpenAIServiceClient = new AzureOpenAIServiceClient(apiKey, deployment, endpoint, apiVersion);
ChatRequest request = new ChatRequest();
request.setMessages(Arrays.asList(
new ChatMessage(ChatMessage.Role.SYSTEM, "You are an AI assistant that helps people find information."),
new ChatMessage(ChatMessage.Role.USER, "Could you tell me a joke?")
));
baseLogger.info(azureOpenAIServiceClient.chatCompletion(request));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
*/
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@ExtendWith(SpringExtension.class)
@ActiveProfiles("test")
@ActiveProfiles({"test", "local"})
@EnableCaching
@Transactional
@Rollback
Expand Down