Commit 583bb4b
Some cleanup after remarks in merged #82 (#184)
fix
vvchernov committed Feb 1, 2024
1 parent 0fcfd18 commit 583bb4b
Showing 2 changed files with 16 additions and 16 deletions.

serve/mlc_serve/model/model_common.py (9 additions, 0 deletions)

@@ -39,6 +39,15 @@ def get_num_cache_blocks(
     )
 
 
+def get_logprob_infos(
+    i: int,
+    logprob_infos: Optional[RawLogprobsInfos],
+) -> Optional[RawLogprobsInfos]:
+    if logprob_infos is None or logprob_infos[i] is None:
+        return None
+    return [logprob_infos[i]]
+
+
 def get_raw_logprob_info(
     logits,
     token_id,
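
The relocated helper has no dependence on model state, which is why it can move out of the class in tvm_model.py and be shared from model_common.py; the call sites in the second file simply drop the self. prefix. The following is a minimal, hedged sketch (not part of the commit) of how the helper behaves; the dict stand-in for RawLogprobsInfo and the alias spelled out here are illustrative assumptions, since the real definitions live elsewhere in model_common.py.

# Minimal runnable sketch of the relocated helper's behavior.
# RawLogprobsInfo is stubbed as a plain dict here; the real record type and the
# RawLogprobsInfos alias are defined elsewhere in model_common.py.
from typing import List, Optional

RawLogprobsInfo = dict  # stand-in for the real per-token logprob record
RawLogprobsInfos = List[Optional[RawLogprobsInfo]]


def get_logprob_infos(
    i: int,
    logprob_infos: Optional[RawLogprobsInfos],
) -> Optional[RawLogprobsInfos]:
    if logprob_infos is None or logprob_infos[i] is None:
        return None
    return [logprob_infos[i]]


# Nothing recorded for the batch, or nothing for sequence i: returns None.
assert get_logprob_infos(0, None) is None
assert get_logprob_infos(1, [{"token": "a"}, None]) is None

# Otherwise the entry for sequence i is wrapped in a one-element list, which is
# the shape the per-sequence results in tvm_model.py pass as logprob_info.
assert get_logprob_infos(0, [{"token": "a"}, None]) == [{"token": "a"}]

Wrapping the single entry in a list keeps the return type consistent with the RawLogprobsInfos alias rather than returning a bare record.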

serve/mlc_serve/model/tvm_model.py (7 additions, 16 deletions)

@@ -14,6 +14,7 @@
 from .model_common import (
     sample,
     prepare_inputs,
+    get_logprob_infos,
     get_num_cache_blocks,
 )

@@ -204,16 +205,6 @@ def profile_memory_usage(self, seq_lens):
 
         return self.get_used_memory()
 
-    def get_logprob_infos(
-        self,
-        i: int,
-        logprob_infos: Optional[RawLogprobsInfos],
-    ) -> Optional[RawLogprobsInfos]:
-        if logprob_infos is None or logprob_infos[i] is None:
-            return None
-        return [logprob_infos[i]]
-
-
     def generate(
         self,
         requests: Sequence[Union[PrefillRequest, DecodeRequest]],
@@ -352,7 +343,7 @@ def generate(
                 sequence_id=SequenceId(sequence_id.request_id, seq_id),
                 generated_tokens=[new_token],
                 error=None,
-                logprob_info=self.get_logprob_infos(i, logprob_infos),
+                logprob_info=get_logprob_infos(i, logprob_infos),
             )
         )
     else:
@@ -361,7 +352,7 @@ def generate(
                 sequence_id=sequence_id,
                 generated_tokens=[new_token],
                 error=None,
-                logprob_info=self.get_logprob_infos(i, logprob_infos),
+                logprob_info=get_logprob_infos(i, logprob_infos),
             )
         )
 
@@ -401,7 +392,7 @@ def generate(
                 ),
                 generated_tokens=[new_token],  # type: ignore
                 error=None,
-                logprob_info=self.get_logprob_infos(0, logprob_infos),
+                logprob_info=get_logprob_infos(0, logprob_infos),
             )
         )
     else:
@@ -410,7 +401,7 @@ def generate(
                 sequence_id=sequence_id,
                 generated_tokens=[new_token],  # type: ignore
                 error=None,
-                logprob_info=self.get_logprob_infos(0, logprob_infos),
+                logprob_info=get_logprob_infos(0, logprob_infos),
             )
         )
     else:
@@ -423,7 +414,7 @@ def generate(
                 ),
                 generated_tokens=[],
                 error=err_msg,
-                logprob_info=self.get_logprob_infos(0, logprob_infos),
+                logprob_info=get_logprob_infos(0, logprob_infos),
             )
         )
     else:
@@ -432,7 +423,7 @@ def generate(
                 sequence_id=sequence_id,
                 generated_tokens=[],
                 error=err_msg,
-                logprob_info=self.get_logprob_infos(0, logprob_infos),
+                logprob_info=get_logprob_infos(0, logprob_infos),
             )
         )
 
