405 abstract ModelType - inline ModelType #477

Closed · wants to merge 22 commits
@@ -10,6 +10,8 @@ import io.ktor.utils.io.core.*
*/
interface AutoClose : AutoCloseable {
fun <A : AutoCloseable> autoClose(autoCloseable: A): A

fun <A : AutoCloseable> A.autoCloseBind(): A = autoClose(this)
}

/** DSL method to use AutoClose */
@@ -0,0 +1,25 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.xef.llm.models.MaxIoContextLength
import com.xebia.functional.xef.llm.models.chat.Message

interface BaseChat : LLM {

val contextLength: MaxIoContextLength

@Deprecated(
"will be removed from LLM in favor of abstracting former ModelType, use contextLength instead"
)
val maxContextLength
get() =
(contextLength as? MaxIoContextLength.Combined)?.total
?: error(
"accessing maxContextLength requires model's context length to be of type MaxIoContextLength.Combined"
)
Member Author:
As a side note, again: this is meant as an intermediary solution. Usages of this field with an OAI model will keep working as before. If the field is accessed on an instance of a (Google) model that does not use the combined context length, an exception is thrown; a short sketch of that behaviour follows this file's diff.
I found this to be the best way to handle it, so as not to change too much code in one PR.


fun countTokens(text: String): Int

fun truncateText(text: String, maxTokens: Int): String

fun tokensFromMessages(messages: List<Message>): Int
}
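
To make the behaviour described in the comment above concrete, here is a minimal, hypothetical sketch (a fabricated FakeChat model, not part of this PR) showing how the deprecated maxContextLength accessor reacts to the two MaxIoContextLength variants:

```kotlin
import com.xebia.functional.xef.llm.BaseChat
import com.xebia.functional.xef.llm.models.MaxIoContextLength
import com.xebia.functional.xef.llm.models.ModelID
import com.xebia.functional.xef.llm.models.chat.Message

// Hypothetical stand-in model, for illustration only.
class FakeChat(
  override val modelID: ModelID,
  override val contextLength: MaxIoContextLength
) : BaseChat {
  override fun copy(modelID: ModelID) = FakeChat(modelID, contextLength)
  // Naive token accounting, good enough for a sketch.
  override fun countTokens(text: String): Int = text.length / 4
  override fun truncateText(text: String, maxTokens: Int): String = text.take(maxTokens * 4)
  override fun tokensFromMessages(messages: List<Message>): Int =
    messages.sumOf { countTokens(it.role.name) + countTokens(it.content) }
}

fun main() {
  // OAI-style model with one combined budget: maxContextLength behaves as before.
  val combined = FakeChat(ModelID("fake-combined"), MaxIoContextLength.Combined(4096))
  println(combined.maxContextLength) // 4096

  // Model with separate input/output budgets: the deprecated accessor fails fast.
  val split = FakeChat(ModelID("fake-split"), MaxIoContextLength.Fix(input = 30720, output = 2048))
  runCatching { split.maxContextLength }.onFailure { println(it.message) }
}
```

Only the Message, ModelID and MaxIoContextLength types come from this PR; the model itself and its token arithmetic are fabricated for the example.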
@@ -8,7 +8,7 @@ import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.assistant
import kotlinx.coroutines.flow.*

interface Chat : LLM {
interface Chat : BaseChat {

suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse

@@ -19,7 +19,7 @@ import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.json.*

interface ChatWithFunctions : LLM {
interface ChatWithFunctions : BaseChat {

suspend fun createChatCompletionWithFunctions(
request: FunChatCompletionRequest
@@ -1,8 +1,15 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.xef.llm.models.MaxIoContextLength
import com.xebia.functional.xef.llm.models.text.CompletionRequest
import com.xebia.functional.xef.llm.models.text.CompletionResult

interface Completion : LLM {
val contextLength: MaxIoContextLength

suspend fun createCompletion(request: CompletionRequest): CompletionResult

fun countTokens(text: String): Int

fun truncateText(text: String, maxTokens: Int): String
}
@@ -19,7 +19,7 @@ interface Embeddings : LLM {
texts
.chunked(chunkSize ?: 400)
.parMap {
createEmbeddings(EmbeddingRequest(modelType.name, it, requestConfig.user.id)).data
createEmbeddings(EmbeddingRequest(modelID.value, it, requestConfig.user.id)).data
}
.flatten()

36 changes: 10 additions & 26 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt
@@ -1,38 +1,22 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.tokenizer.Encoding
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.ModelID

sealed interface LLM : AutoCloseable {
// sealed modifier temporarily removed as OAI's implementation of tokensFromMessages has to extend
// and override LLM
/*sealed */ interface LLM : AutoCloseable {

val modelType: ModelType
val modelID: ModelID

@Deprecated("use modelType.name instead", replaceWith = ReplaceWith("modelType.name"))
@Deprecated("use modelID.value instead", replaceWith = ReplaceWith("modelID.value"))
val name
get() = modelType.name
get() = modelID.value

/**
* Copies this instance and uses [modelType] for [LLM.modelType]. Has to return the most specific
* type of this instance!
* Copies this instance and uses [modelID] for the new instance's [LLM.modelID]. Has to return the
* most specific type of this instance!
*/
fun copy(modelType: ModelType): LLM

fun tokensFromMessages(
messages: List<Message>
): Int { // TODO: naive implementation with magic numbers
fun Encoding.countTokensFromMessages(tokensPerMessage: Int, tokensPerName: Int): Int =
messages.sumOf { message ->
countTokens(message.role.name) +
countTokens(message.content) +
tokensPerMessage +
tokensPerName
} + 3
return modelType.encoding.countTokensFromMessages(
tokensPerMessage = modelType.tokensPerMessage,
tokensPerName = modelType.tokensPerName
) + modelType.tokenPadding
}
fun copy(modelID: ModelID): LLM

override fun close() = Unit
}
@@ -8,13 +8,15 @@ import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.assistant
import com.xebia.functional.xef.store.Memory
import kotlin.math.floor
import kotlin.math.roundToInt

internal object PromptCalculator {

suspend fun adaptPromptToConversationAndModel(
prompt: Prompt,
scope: Conversation,
llm: LLM
llm: BaseChat
): Prompt =
when (prompt.configuration.messagePolicy.addMessagesFromConversation) {
MessagesFromHistory.ALL -> adaptPromptFromConversation(prompt, scope, llm)
@@ -24,7 +26,7 @@ internal object PromptCalculator {
private suspend fun adaptPromptFromConversation(
prompt: Prompt,
scope: Conversation,
llm: LLM
llm: BaseChat
): Prompt {

// calculate tokens for history and context
@@ -55,7 +57,7 @@ internal object PromptCalculator {
if (ctxInfo.isNotEmpty()) {
val ctx: String = ctxInfo.joinToString("\n")

val ctxTruncated: String = llm.modelType.encoding.truncateText(ctx, maxContextTokens)
val ctxTruncated: String = llm.truncateText(ctx, maxContextTokens)

Prompt { +assistant(ctxTruncated) }.messages
} else {
@@ -69,7 +71,7 @@
memories.map { it.content }

private fun calculateMessagesFromHistory(
llm: LLM,
llm: BaseChat,
memories: List<Memory>,
maxHistoryTokens: Int
) =
@@ -97,18 +99,18 @@

private fun calculateMaxContextTokens(prompt: Prompt, remainingTokensForContexts: Int): Int {
val contextPercent = prompt.configuration.messagePolicy.contextPercent
val maxContextTokens = (remainingTokensForContexts * contextPercent) / 100
val maxContextTokens = floor(remainingTokensForContexts * (contextPercent / 100f)).roundToInt()
return maxContextTokens
}

private fun calculateMaxHistoryTokens(prompt: Prompt, remainingTokensForContexts: Int): Int {
val historyPercent = prompt.configuration.messagePolicy.historyPercent
val maxHistoryTokens = (remainingTokensForContexts * historyPercent) / 100
val maxHistoryTokens = floor(remainingTokensForContexts * (historyPercent / 100f)).roundToInt()
return maxHistoryTokens
}

private fun calculateRemainingTokensForContext(llm: LLM, prompt: Prompt): Int {
val maxContextLength: Int = llm.modelType.maxContextLength
private fun calculateRemainingTokensForContext(llm: BaseChat, prompt: Prompt): Int {
val maxContextLength: Int = llm.maxContextLength
val remainingTokens: Int = maxContextLength - prompt.configuration.minResponseTokens

val messagesTokens = llm.tokensFromMessages(prompt.messages)
@@ -121,6 +123,6 @@
return remainingTokensForContexts
}

private suspend fun Conversation.memories(llm: LLM, limitTokens: Int): List<Memory> =
private suspend fun Conversation.memories(llm: BaseChat, limitTokens: Int): List<Memory> =
conversationId?.let { store.memories(llm, it, limitTokens) } ?: emptyList()
}
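
For a concrete feel of the updated percentage maths in calculateMaxContextTokens and calculateMaxHistoryTokens above, here is a small self-contained check with illustrative numbers (the values are made up; the real ones come from the prompt configuration):

```kotlin
import kotlin.math.floor
import kotlin.math.roundToInt

fun main() {
  // Illustrative inputs.
  val remainingTokensForContexts = 937
  val contextPercent = 25

  // New calculation: take the percentage as a float, floor it, then convert to Int.
  val maxContextTokens =
    floor(remainingTokensForContexts * (contextPercent / 100f)).roundToInt()
  println(maxContextTokens) // 234 (937 * 0.25 = 234.25, floored to 234)
}
```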
@@ -0,0 +1,14 @@
package com.xebia.functional.xef.llm.models

/**
 * Describes the maximum context length a model with text input and output may have.
 *
 * Some models from VertexAI (as of 2023-10) have both kinds of limit: a separate maximum for
 * input and for output.
 */
sealed interface MaxIoContextLength {
/** one total length of input and output combined */
data class Combined(val total: Int) : MaxIoContextLength

/** two separate max lengths for input and output respectively */
data class Fix(val input: Int, val output: Int) : MaxIoContextLength
}
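
A brief, hypothetical sketch of how a caller might branch on the two variants; the helper name and the reservedForOutput parameter are illustrative, not part of this PR:

```kotlin
import com.xebia.functional.xef.llm.models.MaxIoContextLength

// Illustrative helper: derive a single input budget from either variant.
fun maxInputTokens(length: MaxIoContextLength, reservedForOutput: Int): Int =
  when (length) {
    // One shared budget: whatever is not reserved for the answer may go into the prompt.
    is MaxIoContextLength.Combined -> length.total - reservedForOutput
    // Separate budgets: the input limit stands on its own.
    is MaxIoContextLength.Fix -> length.input
  }

fun main() {
  println(maxInputTokens(MaxIoContextLength.Combined(4096), reservedForOutput = 1024)) // 3072
  println(maxInputTokens(MaxIoContextLength.Fix(input = 30720, output = 2048), reservedForOutput = 1024)) // 30720
}
```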
@@ -0,0 +1,5 @@
package com.xebia.functional.xef.llm.models

import kotlin.jvm.JvmInline

@JvmInline value class ModelID(val value: String)
@@ -1,6 +1,6 @@
package com.xebia.functional.xef.store

import com.xebia.functional.xef.llm.LLM
import com.xebia.functional.xef.llm.BaseChat
import com.xebia.functional.xef.llm.models.embeddings.Embedding

/**
@@ -13,7 +13,7 @@ class CombinedVectorStore(private val top: VectorStore, private val bottom: Vect
VectorStore by top {

override suspend fun memories(
llm: LLM,
llm: BaseChat,
conversationId: ConversationId,
limitTokens: Int
): List<Memory> {
@@ -4,8 +4,8 @@ import arrow.atomic.Atomic
import arrow.atomic.AtomicInt
import arrow.atomic.getAndUpdate
import arrow.atomic.update
import com.xebia.functional.xef.llm.BaseChat
import com.xebia.functional.xef.llm.Embeddings
import com.xebia.functional.xef.llm.LLM
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import kotlin.math.sqrt
@@ -55,7 +57,7 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi
}

override suspend fun memories(
llm: LLM,
llm: BaseChat,
conversationId: ConversationId,
limitTokens: Int
): List<Memory> {
@@ -1,8 +1,8 @@
package com.xebia.functional.xef.store

import com.xebia.functional.xef.llm.LLM
import com.xebia.functional.xef.llm.BaseChat

fun List<Memory>.reduceByLimitToken(llm: LLM, limitTokens: Int): List<Memory> {
fun List<Memory>.reduceByLimitToken(llm: BaseChat, limitTokens: Int): List<Memory> {
val tokensFromMessages = llm.tokensFromMessages(map { it.content })
return if (tokensFromMessages <= limitTokens) this
else
@@ -1,7 +1,7 @@
package com.xebia.functional.xef.store

import arrow.atomic.AtomicInt
import com.xebia.functional.xef.llm.LLM
import com.xebia.functional.xef.llm.BaseChat
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import kotlin.jvm.JvmStatic

@@ -15,7 +15,7 @@ interface VectorStore {

suspend fun addMemories(memories: List<Memory>)

suspend fun memories(llm: LLM, conversationId: ConversationId, limitTokens: Int): List<Memory>
suspend fun memories(
llm: BaseChat,
conversationId: ConversationId,
limitTokens: Int
): List<Memory>

/**
* Add texts to the vector store after running them through the embeddings
@@ -56,7 +60,7 @@ interface VectorStore {
override suspend fun addMemories(memories: List<Memory>) {}

override suspend fun memories(
llm: LLM,
llm: BaseChat,
conversationId: ConversationId,
limitTokens: Int
): List<Memory> = emptyList()
@@ -1,10 +1,10 @@
package com.xebia.functional.xef.textsplitters

import com.xebia.functional.tokenizer.Encoding
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.tokenizer.EncodingType

fun TokenTextSplitter(modelType: ModelType, chunkSize: Int, chunkOverlap: Int): TextSplitter =
TokenTextSplitterImpl(modelType.encoding, chunkSize, chunkOverlap)
fun TokenTextSplitter(encodingType: EncodingType, chunkSize: Int, chunkOverlap: Int): TextSplitter =
TokenTextSplitterImpl(encodingType.encoding, chunkSize, chunkOverlap)

private class TokenTextSplitterImpl(
private val tokenizer: Encoding,
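
Finally, a hedged usage sketch of the changed TokenTextSplitter signature; the EncodingType constant shown is an assumption about what the tokenizer module exposes, so substitute whichever encoding you actually use:

```kotlin
import com.xebia.functional.tokenizer.EncodingType
import com.xebia.functional.xef.textsplitters.TextSplitter
import com.xebia.functional.xef.textsplitters.TokenTextSplitter

// EncodingType.CL100K_BASE is assumed here; check the tokenizer module for the available values.
val splitter: TextSplitter =
  TokenTextSplitter(encodingType = EncodingType.CL100K_BASE, chunkSize = 512, chunkOverlap = 64)
```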