Skip to content

Commit

Permalink
API Consistency (octoml#199)
Browse files Browse the repository at this point in the history
This PR makes API to be consistent across iOS and C++ implementation.

- Use ChatModule as the wrapper API.
- Use the same function name across implementations
  modulo style to match language native preference
- Initial round of documentations about these API
  • Loading branch information
tqchen committed May 21, 2023
1 parent ecea7b5 commit b76d5d3
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 143 deletions.
190 changes: 131 additions & 59 deletions cpp/cli_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,43 @@ struct ModelPaths {
const std::string& local_id);
};

struct LLMChatModule {
/*!
* \brief Helper class to implement chat features.
*
* A common streaming chat flow can be implemented as follows:
*
* \code
*
* void SingleRound(const std::string& input) {
* // prefill and decode first token for given input
* chat->Prefill(input);
* // check if the current round stops
* while (!chat->Stopped()) {
* // get the latest message and display it
* RefreshCurrentReply(chat->GetMessage());
* // decode the next token
* chat->Decode();
* }
* }
*
* \endcode
*
* \note GetMessage function will return the complete latest message.
* This is useful in most UIs that directly replaces the entire
* textbox content.
*
* Implementation detail: this class is a thin wrapper of TVM runtime
* API that can also be exposed in other language runtimes.
* Look for the name ChatModule in other apps(android, iOS) and you will
* find functions with similar names.
*/
class ChatModule {
public:
explicit LLMChatModule(const DLDevice& device) {
/*!
* \brief Constructor
* \param device the device to run the chat on.
*/
explicit ChatModule(const DLDevice& device) {
this->chat_mod_ = mlc::llm::CreateChatModule(device);
this->prefill_ = this->chat_mod_->GetFunction("prefill");
this->decode_ = this->chat_mod_->GetFunction("decode");
Expand All @@ -190,65 +224,53 @@ struct LLMChatModule {
ICHECK(runtime_stats_text_ != nullptr);
ICHECK(reset_chat_ != nullptr);
}

/*!
* \brief Reload the module to a new model path.
* \param model The model path spec.
*/
void Reload(const ModelPaths& model) {
std::string model_path = model.config.parent_path().string();
tvm::runtime::Module executable = tvm::runtime::Module::LoadFromFile(model.lib.string());
reload_(executable, tvm::String(model_path));
}

std::string Role0() { return get_role0_(); }

std::string Role1() { return get_role1_(); }

std::string Stats() { return runtime_stats_text_(); }
/*!
* \brief Reset the current chat session.
* \note The model remains the same, to change model, call Reload.
*/
void ResetChat() { reset_chat_(); }

void Reset() { reset_chat_(); }
/*! \return Role0(user) name in the chat template. */
std::string GetRole0() { return get_role0_(); }

void Converse(const std::string& input, int stream_interval, std::ostream& os) {
this->Prefill(input);
/*! \return Role1(bot) name in the chat template. */
std::string GetRole1() { return get_role1_(); }

std::string cur_msg = "";
std::vector<std::string> cur_utf8_chars = CountUTF8(cur_msg);
/*! \return A text describing the runtime statistics. */
std::string RuntimeStatsText() { return runtime_stats_text_(); }

os << this->Role1() << ": " << std::flush;
for (size_t i = 0; !this->IsStopped(); ++i) {
this->Decode();
if (i % stream_interval == 0 || this->IsStopped()) {
std::string new_msg = GetMessage();
std::vector<std::string> new_utf8_chars = CountUTF8(new_msg);
// Step 1. Find the index of the first UTF8 char that differs
size_t pos = std::mismatch(cur_utf8_chars.begin(), cur_utf8_chars.end(),
new_utf8_chars.begin(), new_utf8_chars.end())
.first -
cur_utf8_chars.begin();
// Step 2. Delete the previous message since `pos`
std::string print = "";
for (size_t j = pos; j < cur_utf8_chars.size(); ++j) {
print += "\b \b";
}
// Step 3. Print the new message since `pos`
for (size_t j = pos; j < new_utf8_chars.size(); ++j) {
print += new_utf8_chars[j];
}
os << print << std::flush;
cur_msg = std::move(new_msg);
cur_utf8_chars = std::move(new_utf8_chars);
}
}
os << std::endl << std::flush;
}

protected:
// Low-level APIs
/*!
* \brief Run prefill stage for a given input and decode the first output token.
* \param input the user input.
*/
void Prefill(const std::string& input) { prefill_(input); }

/*!
* \brief Run one decode step to decode the next token.
*/
void Decode() { decode_(); }

bool IsStopped() { return stopped_(); }
/*! \return Whether the current round stopped. */
bool Stopped() { return stopped_(); }

/*!
* \return Get the output message in the current round.
* \note This function returns the message that corresponds to
* all the tokens decoded so far.
*/
std::string GetMessage() { return get_message_(); }

protected:
// TVM Modules and functions with TVM's calling convention
tvm::runtime::Module chat_mod_;
tvm::runtime::PackedFunc prefill_;
Expand All @@ -267,11 +289,11 @@ std::optional<std::filesystem::path> TryInferMLCChatConfig(const std::string& ar
return FindFile(
{
//
artifact_path + "/" + local_id + "/params", //
artifact_path + "/prebuilt/" + local_id, //
artifact_path + "/prebuilt/mlc-chat-" + local_id, //
}, //
{"mlc-chat-config"}, //
artifact_path + "/" + local_id + "/params", //
artifact_path + "/prebuilt/" + local_id, //
artifact_path + "/prebuilt/mlc-chat-" + local_id, //
}, //
{"mlc-chat-config"}, //
{".json"});
}

Expand Down Expand Up @@ -341,32 +363,82 @@ ModelPaths ModelPaths::Find(const std::string& artifact_path, const std::string&
return ModelPaths{config_path, params_json, lib_path};
}

/*!
* \brief Implementation of one round chat.
* \param chat The chat module.
* \param input The input prompt.
* \param stream_interval Refresh rate
* \param os output stream
*/
void Converse(ChatModule* chat, const std::string& input, int stream_interval,
std::ostream& os) { // NOLINT(*)
chat->Prefill(input);

std::string cur_msg = "";
std::vector<std::string> cur_utf8_chars = CountUTF8(cur_msg);

os << chat->GetRole1() << ": " << std::flush;
for (size_t i = 0; !chat->Stopped(); ++i) {
chat->Decode();
if (i % stream_interval == 0 || chat->Stopped()) {
std::string new_msg = chat->GetMessage();
// NOTE: display the new message.
// The main complication here is that new_msg can be different
// from prevous message, so we need to find the diff,
// delete previous messages that are different, then print it out.
// This logic is only needed for simple stdout.
//
// For UI apps that can directly update output text
// we can simply do last_reply.text = chat->GetMessage();
std::vector<std::string> new_utf8_chars = CountUTF8(new_msg);
// Step 1. Find the index of the first UTF8 char that differs
size_t pos = std::mismatch(cur_utf8_chars.begin(), cur_utf8_chars.end(),
new_utf8_chars.begin(), new_utf8_chars.end())
.first -
cur_utf8_chars.begin();
// Step 2. Delete the previous message since `pos`
std::string print = "";
for (size_t j = pos; j < cur_utf8_chars.size(); ++j) {
print += "\b \b";
}
// Step 3. Print the new message since `pos`
for (size_t j = pos; j < new_utf8_chars.size(); ++j) {
print += new_utf8_chars[j];
}
os << print << std::flush;
cur_msg = std::move(new_msg);
cur_utf8_chars = std::move(new_utf8_chars);
}
}
os << std::endl << std::flush;
}

/*!
* \brief Start a chat conversation.
*
* \param chat_mod The chat module.
* \param chat The chat module.
* \param executable The model library to initialize the chat module.
* \param model_path The model path with contains the model config, tokenizer and parameters.
*/
void Chat(LLMChatModule chat, const std::string& artifact_path, const std::string& device_name,
void Chat(ChatModule* chat, const std::string& artifact_path, const std::string& device_name,
std::string local_id, int stream_interval = 2) {
ModelPaths model = ModelPaths::Find(artifact_path, device_name, local_id);
std::cout << "Loading model..." << std::endl;
PrintSpecialCommands();
chat.Reload(model);
chat->Reload(model);
while (true) {
std::string input;
std::cout << chat.Role0() << ": " << std::flush;
std::cout << chat->GetRole0() << ": " << std::flush;
std::getline(std::cin, input);
if (!std::cin.good()) {
break;
} else if (input.substr(0, 6) == "/reset") {
chat.Reset();
chat->ResetChat();
std::cout << "RESET CHAT SUCCESS" << std::endl << std::flush;
} else if (input.substr(0, 5) == "/exit") {
break;
} else if (input.substr(0, 6) == "/stats") {
std::cout << chat.Stats() << std::endl << std::flush;
std::cout << chat->RuntimeStatsText() << std::endl << std::flush;
} else if (input.substr(0, 7) == "/reload") {
std::string new_local_id;
{
Expand All @@ -379,13 +451,13 @@ void Chat(LLMChatModule chat, const std::string& artifact_path, const std::strin
}
model = ModelPaths::Find(artifact_path, device_name, new_local_id);
std::cout << "Loading model..." << std::endl;
chat.Reload(model);
chat->Reload(model);
local_id = new_local_id;
std::cout << "LOAD MODEL " << local_id << " SUCCESS" << std::endl << std::flush;
} else if (input.substr(0, 5) == "/help") {
PrintSpecialCommands();
} else {
chat.Converse(input, stream_interval, std::cout);
Converse(chat, input, stream_interval, std::cout);
}
}
}
Expand Down Expand Up @@ -441,7 +513,7 @@ int main(int argc, char* argv[]) {
}

try {
LLMChatModule chat(GetDevice(device_name, device_id));
ChatModule chat(GetDevice(device_name, device_id));
if (args.get<bool>("--evaluate")) {
// `--evaluate` is only used for performance debugging, and thus will call low-level APIs that
// are not supposed to be used in chat app setting
Expand All @@ -452,7 +524,7 @@ int main(int argc, char* argv[]) {
chat_mod.GetFunction("reload")(lib, tvm::String(model_path));
chat_mod.GetFunction("evaluate")();
} else {
Chat(chat, artifact_path, device_name, local_id);
Chat(&chat, artifact_path, device_name, local_id);
}
} catch (const std::runtime_error& err) {
std::cerr << err.what() << std::endl;
Expand Down
7 changes: 7 additions & 0 deletions cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,13 @@ class LLMChat {
NDArray logits_on_cpu_{nullptr};
};

/*!
* \brief A chat module implementation that exposes
* the functions as tvm::runtime::Module.
*
* We do it so that the module is accessible to any
* language that tvm runtime can access.
*/
class LLMChatModule : public ModuleNode {
public:
// clear global memory manager
Expand Down
4 changes: 2 additions & 2 deletions ios/MLCChat/ChatState.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ChatState : ObservableObject {
@Published var unfinishedRespondRole = MessageRole.bot;
@Published var unfinishedRespondMessage = "";
private var threadWorker = ThreadWorker();
private var backend = LLMChatInstance();
private var backend = ChatModule();

private var stopLock = NSLock();
private var requestedReset = false;
Expand Down Expand Up @@ -190,7 +190,7 @@ class ChatState : ObservableObject {
self.requestedReset = true;

threadWorker.push {
self.backend.reset()
self.backend.resetChat();
DispatchQueue.main.sync {
self.mainResetChat();
}
Expand Down
Loading

0 comments on commit b76d5d3

Please sign in to comment.