Commit ea7a66c: fix mypy

vvchernov committed Dec 29, 2023
1 parent 43b4625

Showing 5 changed files with 13 additions and 37 deletions.

4 changes: 2 additions & 2 deletions serve/mlc_serve/api/handler.py
@@ -9,7 +9,7 @@
 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import JSONResponse, StreamingResponse
 
-# TODO(amalyshe): hadnle random_seed
+# TODO(amalyshe): handle random_seed
 # from .base import set_global_random_seed
 from ..api.protocol import (
     ChatCompletionRequest,
@@ -217,7 +217,7 @@ async def collect_result_stream(
     finish_reasons = [None] * num_sequences
     num_prompt_tokens = 0
     num_generated_tokens = [0 for _ in range(num_sequences)]
-    logprob_infos = [[] for _ in range(num_sequences)]
+    logprob_infos = [[] for _ in range(num_sequences)]  # type: ignore
     async for res in result_generator:
         # TODO: verify that the request cancellation happens after this returns
         if res.error:
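
Note: the handler change silences mypy with a type: ignore rather than annotating the list. A minimal typed alternative, assuming each inner list accumulates values of the logprob_info shape declared in base.py below (the LogprobInfo alias is hypothetical, introduced here only for illustration):

    from typing import Dict, List, Optional, Tuple

    # Hypothetical alias mirroring SequenceOutput.logprob_info after this commit:
    # ((chosen token, its logprob), detokenized top-token -> logprob map)
    LogprobInfo = Optional[Tuple[Tuple, Dict[str, float]]]

    num_sequences = 2
    # An explicit annotation gives mypy the element type, so no ignore is needed.
    logprob_infos: List[List[LogprobInfo]] = [[] for _ in range(num_sequences)]
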
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/base.py
@@ -161,7 +161,7 @@ class SequenceOutput:
     finish_reason: Optional[FinishReason] = None
     # Number of generated tokens so far
     num_generated_tokens: int = 0
-    logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None
+    logprob_info: Optional[Tuple[Tuple, Dict[str, float]]] = None
 
     @property
     def is_finished(self) -> bool:
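
Note: the second element of logprob_info is now a mapping from detokenized top tokens to their logprobs, matching what the helper removed from staging_engine.py below actually returned. A hypothetical value of the new shape (figures illustrative, not from the commit):

    logprob_info = (("Hello", -0.12), {"Hello": -0.12, "Hi": -2.31, "Hey": -3.05})
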
6 changes: 3 additions & 3 deletions serve/mlc_serve/engine/dummy.py
@@ -12,9 +12,9 @@
 
 
 class DummyInferenceEngine:
-    def __init__(self):
-        self.queue_lock = Lock()
-        self.has_new_requests = Condition(self.queue_lock)
+    def __init__(self) -> None:
+        self.queue_lock: Lock = Lock()
+        self.has_new_requests: Condition = Condition(self.queue_lock)
         self.request_queue: Dict[RequestId, int] = {}
 
     def add(self, requests: list[Request]):
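
Note on the -> None addition: by default mypy does not type-check the body of a function that carries no annotations at all, so annotating __init__ is what actually brings the constructor under checking; the attribute annotations then follow naturally. A minimal sketch of the behavior (class names are illustrative):

    from threading import Condition, Lock

    class Unchecked:
        def __init__(self):  # no annotations: mypy skips this body
            self.lock = Lock()

    class Checked:
        def __init__(self) -> None:  # annotated: body is type-checked
            self.lock: Lock = Lock()
            self.ready: Condition = Condition(self.lock)
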
26 changes: 1 addition & 25 deletions serve/mlc_serve/engine/staging_engine.py
@@ -6,7 +6,7 @@
 import queue
 from threading import Lock
 from collections import defaultdict
-from typing import Callable, Tuple, List, Dict
+from typing import Callable
 
 import structlog
 
@@ -40,30 +40,6 @@
 LOG = structlog.stdlib.get_logger(__name__)
 
 
-def logprob_detokenize(tokenizer: Tokenizer, logprob_info: Tuple[Tuple, List[Tuple]]) -> Tuple[Tuple, List[Tuple]]:
-    """Detokenize logprob information"""
-    if logprob_info is None:
-        return None
-    (res, res_logprob), top_tokens = logprob_info
-    top_tokens = list(top_tokens)
-    count = {}
-    logprob_dict = {}
-    # dedup duplicates
-    # Todo: Make sure decode can generate different tokens
-    for top_token, _ in top_tokens:
-        detokenized = tokenizer.decode(top_token)
-        if detokenized in count:
-            count[detokenized] += 1
-        else:
-            count[detokenized] = 1
-    for top_token, top_logprob in top_tokens:
-        detokenized = tokenizer.decode(top_token)
-        if count[detokenized] == 1:
-            logprob_dict[detokenized] = float(top_logprob)
-        else:
-            logprob_dict[f"{detokenized}_{top_token}"] = float(top_logprob)
-    return (str(tokenizer.decode(res)), res_logprob), logprob_dict
-
 class StagingInferenceEngine(ScopedInferenceEngine):
     """
     An implementation of InferenceEngine that offloads the text generation loop to another worker process,
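
Note: the deleted helper was a plausible source of the mypy errors this commit targets: it is annotated to return Tuple[Tuple, List[Tuple]] yet can return None, and the second element it builds is a dict, consistent with the Dict[str, float] now declared in base.py. A typed sketch of the same logic, assuming a Tokenizer whose decode signature is inferred from the deleted code (not confirmed by the commit):

    from collections import Counter
    from typing import Dict, List, Optional, Protocol, Tuple

    class Tokenizer(Protocol):
        # Assumed interface, inferred from how the deleted code used it.
        def decode(self, token_id) -> str: ...

    def logprob_detokenize(
        tokenizer: Tokenizer,
        logprob_info: Optional[Tuple[Tuple, List[Tuple]]],
    ) -> Optional[Tuple[Tuple, Dict[str, float]]]:
        """Detokenize logprob info; typed rewrite of the deleted helper."""
        if logprob_info is None:
            return None
        (res, res_logprob), top_tokens = logprob_info
        top_tokens = list(top_tokens)
        # Count collisions among detokenized tokens so duplicates get unique keys.
        count = Counter(tokenizer.decode(token) for token, _ in top_tokens)
        logprob_dict: Dict[str, float] = {}
        for token, logprob in top_tokens:
            text = tokenizer.decode(token)
            key = text if count[text] == 1 else f"{text}_{token}"
            logprob_dict[key] = float(logprob)
        return (str(tokenizer.decode(res)), res_logprob), logprob_dict
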
12 changes: 6 additions & 6 deletions serve/mlc_serve/model/paged_cache_manager.py
@@ -1,6 +1,6 @@
 import math
 from collections import defaultdict
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from ..engine import (
     RequestId,
@@ -104,18 +104,18 @@ def replace_head_prompt_block_with(self, new_block):
 class KVCache:
     def __init__(
         self,
-        cache_blocks,
-        block_size,
+        cache_blocks: Any,
+        block_size: int,
     ):
         self.cache_blocks = cache_blocks
         self.block_size = block_size
 
         # SequenceId -> list[int]
-        self.prompt_block_tables = defaultdict(list)
-        self.slot_mappings = defaultdict(list)
+        self.prompt_block_tables = defaultdict(list)  # type: ignore
+        self.slot_mappings = defaultdict(list)  # type: ignore
 
         # The core data structure
-        self.decode_block_tables = dict[SequenceId, DecodeBlockTable]()
+        self.decode_block_tables: dict = dict[SequenceId, DecodeBlockTable]()
 
         # Record indices of blocks to copy after prefill in the format [src1, dst1, src2, dst2, ...]
        self.pending_copy_from_to: list[int] = []
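
Note: the existing "SequenceId -> list[int]" comment already documents the intended table types, so explicit annotations could replace both ignores, and the decode table can be annotated instead of instantiating a parametrized dict at runtime. A hedged sketch reusing the file's own names (SequenceId and DecodeBlockTable are stubbed with Any here only to keep the snippet self-contained):

    from collections import defaultdict
    from typing import Any, DefaultDict, Dict, List

    SequenceId = Any        # stand-in; the real type lives in ..engine
    DecodeBlockTable = Any  # stand-in; the real class is defined in this file

    class KVCache:
        def __init__(self, cache_blocks: Any, block_size: int) -> None:
            self.cache_blocks = cache_blocks
            self.block_size = block_size
            # SequenceId -> list[int]; explicit types make the ignores unnecessary.
            self.prompt_block_tables: DefaultDict[SequenceId, List[int]] = defaultdict(list)
            self.slot_mappings: DefaultDict[SequenceId, List[int]] = defaultdict(list)
            # Annotate rather than instantiate dict[SequenceId, DecodeBlockTable]().
            self.decode_block_tables: Dict[SequenceId, DecodeBlockTable] = {}
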
