From 9b053e872088db68a1919390b399e79d08f7ad94 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 22 Nov 2023 00:47:10 +0000 Subject: [PATCH] Squashed commit for logprobs implementation. Init with tests. Server working. Major fix, serve working great. Minor fix and tests. Remove extra line. fix log_softmax use constant for number of top logprobs small clean upstream to new OpenAI API Co-authored-by: Valery Chernov --- serve/mlc_serve/api/handler.py | 36 +++++-- serve/mlc_serve/api/protocol.py | 10 +- serve/mlc_serve/engine/__init__.py | 2 +- serve/mlc_serve/engine/async_connector.py | 5 +- serve/mlc_serve/engine/base.py | 3 +- serve/mlc_serve/engine/model_module.py | 6 +- serve/mlc_serve/engine/sampling_params.py | 16 +++ serve/mlc_serve/engine/staging_engine.py | 29 +++++- .../mlc_serve/engine/staging_engine_worker.py | 21 ++-- serve/mlc_serve/engine/sync_engine.py | 1 + serve/mlc_serve/model/paged_cache_model.py | 98 ++++++++++++++----- .../unittest/test_engine_with_samplers.py | 46 ++++++++- 12 files changed, 223 insertions(+), 50 deletions(-) diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index 8457502c38..b2370eebd3 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -42,7 +42,6 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse router = APIRouter() - def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params = SamplingParams( # These params came from vllm @@ -60,6 +59,9 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params.temperature = request.temperature if request.top_p is not None: sampling_params.top_p = request.top_p + if request.logprobs is not None: + sampling_params.top_logprobs = request.top_logprobs + sampling_params.logprobs = request.logprobs return sampling_params @@ -152,7 +154,7 @@ async def generate_completion_stream( created_time = int(time.time()) def create_stream_response( - choices: list[ChatCompletionResponseStreamChoice], + choices: List[ChatCompletionResponseStreamChoice], ) -> ChatCompletionStreamResponse: return ChatCompletionStreamResponse( id=request_id, @@ -172,7 +174,6 @@ def create_stream_response( ], ) yield f"data: {json.dumps(first_chunk.dict(exclude_unset=True), ensure_ascii=False)}\n\n" - async for res in result_generator: if res.error: raise RuntimeError(f"Error when generating: {res.error}") @@ -188,6 +189,7 @@ def create_stream_response( finish_reason=seq.finish_reason.value if seq.finish_reason is not None else None, + logprob_info=seq.logprob_info[0] if seq.logprob_info else None ) for seq in res.sequences ] @@ -208,6 +210,7 @@ async def collect_result_stream( finish_reasons = [None] * num_sequences num_prompt_tokens = 0 num_generated_tokens = [0 for _ in range(num_sequences)] + logprob_infos = [[] for _ in range(num_sequences)] async for res in result_generator: # TODO: verify that the request cancellation happens after this returns if res.error: @@ -215,6 +218,8 @@ async def collect_result_stream( if res.num_prompt_tokens is not None: num_prompt_tokens = res.num_prompt_tokens for seq in res.sequences: + if seq.logprob_info: + logprob_infos[seq.index].append(seq.logprob_info) if seq.index >= len(sequences): raise RuntimeError(f"Unexpected sequence index: {seq.index}.") num_generated_tokens[seq.index] = seq.num_generated_tokens @@ -224,15 +229,30 @@ async def collect_result_stream( else: assert seq.delta is not None sequences[seq.index].append(seq.delta) - - choices = [ - 
ChatCompletionResponseChoice( + + choices = [] + for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)): + choice = ChatCompletionResponseChoice( index=index, message=ChatMessage(role="assistant", content="".join(chunks)), finish_reason=finish_reason, ) - for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)) - ] + content = [] + if logprob_infos[index] != []: + for logprob_info in logprob_infos[index]: + content.append({ + "token": str(logprob_info[0][0]), + "logprob": float(logprob_info[0][1]), + # TODO(vvchernov): implement bytes bases on https://platform.openai.com/docs/api-reference/chat/object + "bytes": None, + "top_logprobs": [{ + "token": top_logprob[0], + "logprob": top_logprob[1], + "bytes": None, + } for top_logprob in logprob_info[1]], + }) + choice.logprobs.content = content + choices.append(choice) usage = UsageInfo( prompt_tokens=num_prompt_tokens, diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index 5271272e63..971a9f56d5 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -2,7 +2,7 @@ # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/vllm-project/vllm/blob/acbed3ef40f015fcf64460e629813922fab90380/vllm/entrypoints/openai/protocol.py import time -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union, Tuple from pydantic import BaseModel, Field @@ -70,11 +70,18 @@ class ChatCompletionRequest(BaseModel): logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None ignore_eos: Optional[bool] = False + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None + + +class Logprobs(BaseModel): + content: Optional[List[Dict]] class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage + logprobs: Optional[Logprobs] finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None @@ -95,6 +102,7 @@ class DeltaMessage(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage + logprob_info: Optional[Tuple[Tuple, List[Tuple]]] finish_reason: Optional[Literal["stop", "length"]] = None diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py index 129b7c05ed..5068200b95 100644 --- a/serve/mlc_serve/engine/__init__.py +++ b/serve/mlc_serve/engine/__init__.py @@ -17,4 +17,4 @@ PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, ) -from .sampling_params import SamplingParams, SamplingType +from .sampling_params import SamplingParams, SamplingType, TOP_LOGPROBS_NUMBER diff --git a/serve/mlc_serve/engine/async_connector.py b/serve/mlc_serve/engine/async_connector.py index c7d5d3d7b0..1bf261be10 100644 --- a/serve/mlc_serve/engine/async_connector.py +++ b/serve/mlc_serve/engine/async_connector.py @@ -1,6 +1,7 @@ import asyncio import structlog -from typing import AsyncIterator, Any +from typing import AsyncIterator, Any, Dict +import logging from .base import ( InferenceEngine, @@ -29,7 +30,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1): self.engine_loop_task = None self.engine_loop_exception = None self.shutdown_event = asyncio.Event() - self.result_queues = dict[RequestId, ResultQueue]() + self.result_queues: Dict[RequestId, ResultQueue] = {} async def start(self): """ diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 9551b6c446..cfd3f465d9 100644 --- 
a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -4,7 +4,7 @@ from enum import Enum from abc import ABC, abstractmethod -from typing import List, Callable, Any, Optional, Dict +from typing import List, Callable, Any, Optional, Dict, Tuple import inspect from .sampling_params import SamplingParams, SamplingType @@ -161,6 +161,7 @@ class SequenceOutput: finish_reason: Optional[FinishReason] = None # Number of generated tokens so far num_generated_tokens: int = 0 + logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None @property def is_finished(self) -> bool: diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 9b018c6cc4..320e0bc72c 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -2,13 +2,16 @@ Required interfaces for the actual inference capability in InferenceEngine. """ from dataclasses import dataclass -from typing import Optional, Protocol, Union, List +from typing import Optional, Protocol, Union, Tuple, List from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId from ..model.base import ModelArtifactConfig from .sampling_params import SamplingParams +LOGPROBS_TYPE = Tuple[Tuple, List[Tuple]] +# ((token, logprob), [(top1_token, top1_logprob), ...]) + @dataclass class PrefillRequest: request_id: RequestId @@ -44,6 +47,7 @@ class TextGenerationResult: # making this a list of token ids to leave room for speculative decoding generated_tokens: List[int] error: Optional[str] + logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None class KVCache(Protocol): diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index fbe153283d..b721881197 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -7,8 +7,10 @@ from enum import IntEnum from functools import cached_property +from typing import Optional _SAMPLING_EPS = 1e-5 +TOP_LOGPROBS_NUMBER = 5 class SamplingType(IntEnum): @@ -37,6 +39,13 @@ class SamplingParams: to consider. Must be in (0, 1]. Set to 1 to consider all tokens. top_k: Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens. + logprobs: Optional[bool] Whether to return log probabilities of the output + tokens or not. If true, returns the log probabilities of each output + token returned in the content of message. + top_logprobs: Optional[Integer] An integer between 0 and 5 specifying + the number of most likely tokens to return at each token position, + each with an associated log probability. logprobs must be set to + true if this parameter is used. """ presence_penalty: float = 0.0 @@ -44,6 +53,8 @@ class SamplingParams: temperature: float = 1.0 top_p: float = 1.0 top_k: int = -1 + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None def __post_init__(self): self._verify_args() @@ -71,6 +82,11 @@ def _verify_args(self) -> None: raise ValueError( f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." ) + if self.logprobs is not None and self.logprobs: + if (self.top_logprobs < 0 or self.top_logprobs > TOP_LOGPROBS_NUMBER): + raise ValueError( + f"top_logprobs must be between 0 and {TOP_LOGPROBS_NUMBER}, got {self.top_logprobs}." 
+ ) def _verify_greedy_sampling(self) -> None: if self.top_p < 1.0 - _SAMPLING_EPS: diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index 5aa7d823a6..3fd9c5152d 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -5,8 +5,8 @@ import multiprocessing import queue from threading import Lock -from typing import Callable from collections import defaultdict +from typing import Callable, Tuple, List import structlog @@ -23,7 +23,7 @@ get_new_request_state, update_sequence, ) -from .model_module import ModelModule, TokenizerModule +from .model_module import ModelModule, TokenizerModule, Tokenizer from .staging_engine_worker import ( AddRequestsCommand, CancelRequestCommand, @@ -37,6 +37,30 @@ LOG = structlog.stdlib.get_logger(__name__) +def logprob_detokenize(tokenizer: Tokenizer, logprob_info: Tuple[Tuple, List[Tuple]]) -> Tuple[Tuple, List[Tuple]]: + """Detokenize logprob information""" + if logprob_info is None: + return None + (res, res_logprob), top_tokens = logprob_info + top_tokens = list(top_tokens) + count = {} + logprob_dict = {} + # dedup duplicates + # Todo: Make sure decode can generate different tokens + for top_token, _ in top_tokens: + detokenized = tokenizer.decode(top_token) + if detokenized in count: + count[detokenized] += 1 + else: + count[detokenized] = 1 + for top_token, top_logprob in top_tokens: + detokenized = tokenizer.decode(top_token) + if count[detokenized] == 1: + logprob_dict[detokenized] = float(top_logprob) + else: + logprob_dict[f"{detokenized}_{top_token}"] = float(top_logprob) + return (str(tokenizer.decode(res)), res_logprob), logprob_dict + class StagingInferenceEngine(ScopedInferenceEngine): """ An implementation of InferenceEngine that offloads the text generation loop to another worker process, @@ -235,6 +259,7 @@ def step(self) -> InferenceStepResult: delta, finish_reason=seq_output.finish_reason, num_generated_tokens=len(gen_seq.generated_token_ids), + logprob_info=logprob_detokenize(self.tokenizer, seq_output.logprob_info), ) seq_outputs[request_id].append(output) diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index 02c4e8d5ab..7a205ac256 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -5,10 +5,10 @@ import multiprocessing import multiprocessing.synchronize from dataclasses import dataclass -from threading import Thread -from typing import Callable, Optional, Union, Any, Dict, List - +from threading import Condition, Lock, Thread +from typing import Callable, Optional, Union, Tuple, Any, Dict, Deque, List import structlog +import numpy as np from .base import FinishReason, RequestId, RequestState, ValidationError, SequenceId from .metrics import PrometheusMetrics @@ -33,7 +33,7 @@ class ShutdownCommand: @dataclass class AddRequestsCommand: - request_states: list[RequestState] + request_states: List[RequestState] @dataclass @@ -54,14 +54,15 @@ class StopRequestCommand: @dataclass class SequenceGenerationOutput: id: SequenceId - new_tokens: list[int] + new_tokens: List[int] finish_reason: Optional[FinishReason] = None error: Optional[Union[str, ValidationError]] = None + logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None @dataclass class GenerationLoopWorkerOutput: - sequences: list[SequenceGenerationOutput] + sequences: List[SequenceGenerationOutput] error: Optional[str] = None @@ -77,8 +78,8 @@ def __init__( ): 
EngineBase.__init__(self, model_module) - self.cancelled_requests = list[RequestState]() - self.stopped_requests = list[RequestState]() + self.cancelled_requests: List[RequestState] = [] + self.stopped_requests: List[RequestState] = [] self.prom_metrics = PrometheusMetrics() self.inv_kv_cache_size = 1.0 / self.cache_manager.get_kv_cache_size() @@ -167,7 +168,7 @@ def has_pending_requests(self) -> bool: def step(self) -> GenerationLoopWorkerOutput: LOG.debug("Starting new inference step.") - outputs = list[SequenceGenerationOutput]() + outputs: List[SequenceGenerationOutput] = [] result = GenerationLoopWorkerOutput(sequences=outputs) # TODO: consolidate into a single function @@ -263,7 +264,7 @@ def step(self) -> GenerationLoopWorkerOutput: gen_seq.generated_token_ids.extend(new_tokens) outputs.append( - SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens) + SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens, logprob_info=res.logprob_info) ) if is_prompt_batch: diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py index befe8f48ba..748e09c771 100644 --- a/serve/mlc_serve/engine/sync_engine.py +++ b/serve/mlc_serve/engine/sync_engine.py @@ -218,6 +218,7 @@ def step(self) -> InferenceStepResult: delta, num_generated_tokens=len(gen_seq.generated_token_ids), finish_reason=finish_reason, + logprob_info=res.logprob_info, ) ) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index f442fe402d..a32162e160 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -1,7 +1,9 @@ import math import os -from typing import List, Union, Optional from pathlib import Path +from collections import defaultdict +from typing import List, Union, Optional, Tuple +from dataclasses import dataclass import structlog import numpy as np @@ -19,6 +21,7 @@ SamplingType, MLCServeEngineConfig, SamplingParams, + TOP_LOGPROBS_NUMBER, SequenceId, PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, @@ -27,6 +30,7 @@ DecodeRequest, PrefillRequest, TextGenerationResult, + LOGPROBS_TYPE ) from ..engine.model_module import ModelModule @@ -61,7 +65,7 @@ def sample( sampling_params: List[SamplingParams], vocab_size: int, check_safety=False, -) -> Optional[np.ndarray]: +) -> Optional[Tuple[np.ndarray, Optional[LOGPROBS_TYPE]]]: def _is_safe_to_sample(prob_like): return ( torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) @@ -80,10 +84,21 @@ def _is_safe_to_sample(prob_like): logits_greedy = logits[mask_greedy] if logits_greedy.shape[0] > 0: - res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy() - + # Greedy sampling + logprobs = torch.log_softmax(logits_greedy, dim=-1) + res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1) + + top_greedy_logprob, top_greedy = torch.topk( + logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True + ) + # Convert to numpy + res_greedy_logprob = res_greedy_logprob.cpu().numpy() + res_greedy = res_greedy.cpu().numpy() + top_greedy_logprob = top_greedy_logprob.cpu().numpy() + top_greedy = top_greedy.cpu().numpy() + # Case when there's only greedy sampling if logits_greedy.shape[0] == num_seq: - return res_greedy + return res_greedy, ((res_greedy, res_greedy_logprob), (top_greedy, top_greedy_logprob)) temperatures = [] top_ps = [] @@ -114,22 +129,40 @@ def _is_safe_to_sample(prob_like): logits = _apply_top_p_top_k(logits_random, top_ps, top_ks) probs = torch.softmax(logits_random, dim=-1) + 
logprobs = torch.log_softmax(logits_random, dim=-1)
+    top_random_logprob, top_random = torch.topk(
+        logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True
+    )
+    top_random_logprob = top_random_logprob.cpu().numpy()
+    top_random = top_random.cpu().numpy()
 
     if check_safety and not _is_safe_to_sample(probs):
         return None
 
     res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0]
+    res_random_logprobs = torch.gather(logprobs, dim=-1, index=torch.tensor(res_random, dtype=torch.int64, device=logits.device).unsqueeze(-1)).squeeze(-1).cpu().numpy()
 
     if logits_random.shape[0] == num_seq:
-        return res_random
+        return res_random, ((res_random, res_random_logprobs), (top_random, top_random_logprob))
 
     res = np.empty((num_seq,), dtype=np.int32)
+    res_logprobs = np.empty((num_seq,), dtype=np.float32)
+    top = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.int32)
+    top_logprobs = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.float32)
+
     res[mask_random] = res_random
+    res_logprobs[mask_random] = res_random_logprobs
+    top[mask_random] = top_random
+    top_logprobs[mask_random] = top_random_logprob
+
 
     if logits_greedy.shape[0] > 0:
         res[mask_greedy] = res_greedy
+        res_logprobs[mask_greedy] = res_greedy_logprob
+        top[mask_greedy] = top_greedy
+        top_logprobs[mask_greedy] = top_greedy_logprob
 
-    return res
+    return res, ((res, res_logprobs), (top, top_logprobs))
 
 
 def load_disco_module(artifact_path, lib_path, num_shards):
@@ -175,6 +208,26 @@ def get_tvm_model(config, dev):
     return load_disco_module(config.model_artifact_path, lib_path, config.num_shards)
 
 
+def fetch_logprobs(
+    logprob_info: LOGPROBS_TYPE,
+    index: int,
+    sampling_param: SamplingParams,
+) -> Optional[Tuple[np.ndarray, List[Tuple[np.ndarray, np.ndarray]]]]:
+    """Fetch the logprob information for the sequence at `index`."""
+    if (
+        sampling_param.logprobs is None or
+        not sampling_param.logprobs or
+        logprob_info is None
+    ):
+        return None
+    (res, res_logprobs), (top, top_logprobs) = logprob_info
+    return (res[index], res_logprobs[index]), \
+        zip(
+            top[index][:sampling_param.top_logprobs],
+            top_logprobs[index][:sampling_param.top_logprobs]
+        )
+
+
 def _prepare_inputs(
     sequence_ids,
     all_token_ids,
@@ -412,13 +465,6 @@ def generate(
                 slot_mapping,
                 self.params,
             )
-
-            if self.disco_session:
-                logits, _ = out.debug_get_from_remote(0)
-            else:
-                logits = out[
-                    0
-                ]  # Ignore returned KV cache since it is updated in-place anyway.
         else:
             torch.cuda.nvtx.range_push(f"forward decode {input_shape}")
 
@@ -435,10 +481,12 @@
                 self.params,
             )
 
-            if self.disco_session:
-                logits, _ = out.debug_get_from_remote(0)
-            else:
-                logits = out[0]
+        if self.disco_session:
+            logits, _ = out.debug_get_from_remote(0)
+        else:
+            logits = out[
+                0
+            ]  # Ignore returned KV cache since it is updated in-place anyway.
torch.cuda.synchronize() torch.cuda.nvtx.range_pop() @@ -460,7 +508,7 @@ def generate( cache.pending_copy_from_to = [] try: - next_tokens = sample(logits, sampling_params, self.vocab_size) + next_tokens, logprob_info = sample(logits, sampling_params, self.vocab_size) assert next_tokens is not None outputs = [] @@ -474,6 +522,7 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, + logprob_info=fetch_logprobs(logprob_info, seq_id, sampling_params[seq_id]), ) ) else: @@ -482,6 +531,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], error=None, + logprob_info=fetch_logprobs(logprob_info, index, sampling_params[index]), ) ) @@ -497,7 +547,7 @@ def generate( for i, (sequence_id, logits_per_token, sampling_param) in enumerate( zip(sequence_ids, torch.from_dlpack(logits), sampling_params) ): - maybe_new_token = sample( + maybe_new_token, logprob_info = sample( torch.unsqueeze(logits_per_token, 0), [sampling_param], self.vocab_size, @@ -514,6 +564,7 @@ def generate( ), generated_tokens=[maybe_new_token[0]], # type: ignore error=None, + logprob_info=fetch_logprobs(logprob_info, i, sampling_param) ) ) else: @@ -522,6 +573,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[maybe_new_token[0]], # type: ignore error=None, + logprob_info=fetch_logprobs(logprob_info, i, sampling_param) ) ) else: @@ -534,6 +586,7 @@ def generate( ), generated_tokens=[], error=err_msg, + logprob_info=fetch_logprobs(logprob_info, i, sampling_param) ) ) else: @@ -542,6 +595,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[], error=err_msg, + logprob_info=fetch_logprobs(logprob_info, i, sampling_param) ) ) @@ -575,8 +629,8 @@ def __init__(self, model: Model): self.model = model def generate( - self, requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache - ) -> list[TextGenerationResult]: + self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache + ) -> List[TextGenerationResult]: prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)] decode_requests = [r for r in requests if isinstance(r, DecodeRequest)] diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index b600892730..c3785f839c 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -53,13 +53,15 @@ def create_engine( )) return engine -def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos): +def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos, top_logprobs): return Request( request_id = str(idx), messages = [ChatMessage(role="user", content=prompt)], sampling_params = SamplingParams( temperature=0.0, - ), + logprobs=True, + top_logprobs=top_logprobs + ), stopping_criteria = StoppingCriteria( max_tokens=max_tokens, stop_sequences=stop @@ -219,6 +221,43 @@ def _test_stop( if use_staging_engine: engine.stop() +def _test_logprobs( + model_artifact_path, + use_staging_engine, + max_num_sequences=4, + max_input_len=512, + num_requests=5, + top_logprobs=3, +): + prompt = "hi" + engine = create_engine( + model_artifact_path, + use_staging_engine, + max_num_sequences, + max_input_len, + ) + s = 113 + requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, top_logprobs=top_logprobs) for n in range(s, s+num_requests)] + engine.add(requests) + + generated = ["" for _ in range(num_requests)] + + while engine.has_pending_requests(): + 
results = engine.step()
+        for res in results.outputs:
+            assert len(res.sequences) == 1
+            seq = res.sequences[0]
+
+            assert seq.finish_reason is not None or (seq.logprob_info is not None and len(list(seq.logprob_info[1])) == top_logprobs)
+
+            if seq.is_finished:
+                assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
+                assert seq.finish_reason == FinishReason.Length
+            else:
+                generated[int(res.request_id)] += seq.delta
+
+    if use_staging_engine:
+        engine.stop()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -233,7 +272,10 @@ def _test_stop(
     _test_ignore_eos(model_artifact_path, use_staging_engine=False)
     _test_stop(model_artifact_path, use_staging_engine=False)
     _test_stop(model_artifact_path, use_staging_engine=True)
+    _test_logprobs(model_artifact_path, use_staging_engine=True)
+    _test_logprobs(model_artifact_path, use_staging_engine=False)
     # These tests are broken since we are now imposing no length limit
     # if max_tokens = None. The tests do not finish in a reasonable time.
    # _test_max_context_length(model_artifact_path, use_staging_engine=True)
    # _test_max_context_length(model_artifact_path, use_staging_engine=False)
+
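Usage note (illustrative, not part of the patch): the sketch below shows how a client might set the new `logprobs` / `top_logprobs` request fields and read the per-choice `logprobs.content` entries built in `collect_result_stream`. The endpoint URL, port, and model name are assumptions made for the example; only the field names come from the changes above.

# Hypothetical client-side sketch. Assumes the server exposes an OpenAI-style
# chat completions route; host, port, path, and model id are placeholders.
import requests

payload = {
    "model": "local-model",  # placeholder model id, not defined by this patch
    "messages": [{"role": "user", "content": "hi"}],
    "max_tokens": 8,
    "logprobs": True,        # new ChatCompletionRequest field
    "top_logprobs": 3,       # must stay within 0..TOP_LOGPROBS_NUMBER (5)
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload).json()

# Each entry mirrors the dict built in collect_result_stream():
# {"token": ..., "logprob": ..., "bytes": None, "top_logprobs": [...]}
for entry in resp["choices"][0]["logprobs"]["content"]:
    alts = ", ".join(f"{alt['token']}: {alt['logprob']:.3f}" for alt in entry["top_logprobs"])
    print(f"{entry['token']} ({entry['logprob']:.3f}) top: {alts}")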