From 583bb4b5c80e6fa1225046f41fe97bb60a660c78 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 1 Feb 2024 11:51:29 +0400 Subject: [PATCH] Some clean after remarks in merged #82 (#184) fix --- serve/mlc_serve/model/model_common.py | 9 +++++++++ serve/mlc_serve/model/tvm_model.py | 23 +++++++---------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index b9e23ddad0..dcbe3cd42b 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -39,6 +39,15 @@ def get_num_cache_blocks( ) +def get_logprob_infos( + i: int, + logprob_infos: Optional[RawLogprobsInfos], +) -> Optional[RawLogprobsInfos]: + if logprob_infos is None or logprob_infos[i] is None: + return None + return [logprob_infos[i]] + + def get_raw_logprob_info( logits, token_id, diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index eb5cfc30d1..c3d0b556f9 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -14,6 +14,7 @@ from .model_common import ( sample, prepare_inputs, + get_logprob_infos, get_num_cache_blocks, ) @@ -204,16 +205,6 @@ def profile_memory_usage(self, seq_lens): return self.get_used_memory() - def get_logprob_infos( - self, - i: int, - logprob_infos: Optional[RawLogprobsInfos], - ) -> Optional[RawLogprobsInfos]: - if logprob_infos is None or logprob_infos[i] is None: - return None - return [logprob_infos[i]] - - def generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], @@ -352,7 +343,7 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, - logprob_info=self.get_logprob_infos(i, logprob_infos), + logprob_info=get_logprob_infos(i, logprob_infos), ) ) else: @@ -361,7 +352,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], error=None, - logprob_info=self.get_logprob_infos(i, logprob_infos), + logprob_info=get_logprob_infos(i, logprob_infos), ) ) @@ -401,7 +392,7 @@ def generate( ), generated_tokens=[new_token], # type: ignore error=None, - logprob_info=self.get_logprob_infos(0, logprob_infos), + logprob_info=get_logprob_infos(0, logprob_infos), ) ) else: @@ -410,7 +401,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], # type: ignore error=None, - logprob_info=self.get_logprob_infos(0, logprob_infos), + logprob_info=get_logprob_infos(0, logprob_infos), ) ) else: @@ -423,7 +414,7 @@ def generate( ), generated_tokens=[], error=err_msg, - logprob_info=self.get_logprob_infos(0, logprob_infos), + logprob_info=get_logprob_infos(0, logprob_infos), ) ) else: @@ -432,7 +423,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[], error=err_msg, - logprob_info=self.get_logprob_infos(0, logprob_infos), + logprob_info=get_logprob_infos(0, logprob_infos), ) )