diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index f60a894e3b..f51aeede82 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -316,6 +316,7 @@ class RequestState: request_id: RequestId prompt_token_ids: list[int] + prompt_mask: Optional[list[bool]] sampling_params: SamplingParams generation_sequences: list[GenerationSequence] stopping_criteria: StoppingCriteria diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 1ddf6fa7fa..c434781994 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -2,11 +2,11 @@ Common utilites for engine classes. """ -import torch import time -from typing import Tuple, Deque, Dict, Optional, Union, Callable, List +from typing import Tuple, Deque, Dict, Optional, Callable, List from collections import deque from threading import Condition, Lock +import numpy as np import structlog @@ -38,8 +38,12 @@ LOG = structlog.stdlib.get_logger(__name__) +# TODO: check if we can drop vocab_size here def get_new_request_state( - request: Request, conversation_template: ConversationTemplate, tokenizer: TokenizerP + request: Request, + conversation_template: ConversationTemplate, + tokenizer: TokenizerP, + vocab_size: int, ) -> RequestState: if request.debug_options.prompt_token_ids is not None: prompt_token_ids = request.debug_options.prompt_token_ids @@ -51,6 +55,14 @@ def get_new_request_state( prompt_token_ids = tokenizer.encode(prompt) + # TODO: Currently, always create this. But we only need this for the requests with repetition penalty + # Follow-up and optimize when it has been stabilized. + # Create prompt mask for repetition penalty + tokens = np.array([prompt_token_ids], dtype=np.int64) + prompt_mask = np.zeros((vocab_size + 1,), dtype=bool) + prompt_mask[tokens] = True + prompt_mask = list(prompt_mask[:vocab_size]) + validation_err = None if request.validate_tokens is not None: validation_err = request.validate_tokens(request, prompt_token_ids) @@ -68,6 +80,7 @@ def get_new_request_state( return RequestState( request_id=request.request_id, prompt_token_ids=prompt_token_ids, + prompt_mask=prompt_mask, generation_sequences=gen_seqs, sampling_params=request.sampling_params, stopping_criteria=request.stopping_criteria, @@ -111,9 +124,7 @@ def detokenize_incrementally( prefix_end_offset = max(len(output_tokens) - 1, 0) else: # Put new_token_id in a list so skip_special_tokens is respected - new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id] - ) + new_tokens = tokenizer.convert_ids_to_tokens([new_token_id]) output_tokens = generation_sequence.prev_tokens + new_tokens prefix_begin_offset = generation_sequence.prefix_begin_offset @@ -241,18 +252,6 @@ def prepare_output( return delta, out_logprob_info -def set_mask_prompt_to(state: RequestState): - # Prompt tokens - tokens=torch.tensor(state.prompt_token_ids, dtype=torch.long) - vocab_size = state.sampling_params.vocab_size - bin_counts = torch.zeros((vocab_size + 1,), - dtype=torch.long, - device=tokens.device) - bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens)) - bin_counts = bin_counts[:vocab_size] - state.sampling_params.mask_prompt = bin_counts > 0 - - def get_requests_to_process( current_states: list[RequestState], cache_manager: KVCacheManager, @@ -277,13 +276,11 @@ def get_requests_to_process( if is_prompt_batch: for state in current_states: if is_evicted_parallel_sampling_request(state): - # TODO(vvchernov): we still need mask if 
apply_penalty = True - # if state.sampling_params.repetition_penalty != 1.0: - # set_mask_prompt_to(state) requests.append( PrefillRequest( request_id=state.request_id, token_ids=state.prompt_token_ids, + prompt_mask=state.prompt_mask, num_sequence=state.num_sequences, sampling_params=state.sampling_params, ) @@ -327,13 +324,11 @@ def get_requests_to_process( else: token_ids = state.prompt_token_ids - # TODO(vvchernov): we still need mask if apply_penalty = True - # if state.sampling_params.repetition_penalty != 1.0: - set_mask_prompt_to(state) requests.append( PrefillRequest( request_id=state.request_id, token_ids=token_ids, + prompt_mask=state.prompt_mask, num_sequence=state.num_sequences, sampling_params=state.sampling_params, ) @@ -355,6 +350,7 @@ def get_requests_to_process( DecodeRequest( sequence_id=gen_seq.seq_id, prompt_token_counts=prompt_counts, + prompt_mask=state.prompt_mask, token_ids=gen_seq.generated_token_ids, sampling_params=state.sampling_params, ) diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index d542d06440..a7581f8f1a 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -21,6 +21,7 @@ class PrefillRequest: request_id: RequestId # `token_ids` contains prompt token ids token_ids: List[int] + prompt_mask: Optional[List[bool]] # Number of sequences to generate num_sequence: int sampling_params: SamplingParams @@ -36,6 +37,7 @@ class PrefillRequest: class DecodeRequest: sequence_id: SequenceId prompt_token_counts: int + prompt_mask: Optional[List[bool]] # Decoded tokens for this sequence token_ids: List[int] sampling_params: SamplingParams diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index fe81729a39..f2638c7fa4 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -7,7 +7,6 @@ from enum import IntEnum from functools import cached_property from typing import Dict, Optional, Any -import torch _SAMPLING_EPS = 1e-5 LOGPROB_TOP_K_MAX = 5 @@ -76,7 +75,6 @@ class SamplingParams: vocab_size: int = 32000 json_schema: Optional[Dict[str, Any]] = None logits_processor: Optional[Any] = None - mask_prompt: Optional[torch.Tensor] = None def __post_init__(self): if self.logit_bias: diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index 03cebe963c..52eafd9831 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -120,7 +120,10 @@ def add(self, requests: list[Request]): # If the request violates the tokenization, this returns None, so skip. 
state = get_new_request_state( - req, self.conversation_template, self.tokenizer + req, + self.conversation_template, + self.tokenizer, + self.model_artifact_config.vocab_size, ) new_request_states.append(state) diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py index a0e7194b22..de0157be4b 100644 --- a/serve/mlc_serve/engine/sync_engine.py +++ b/serve/mlc_serve/engine/sync_engine.py @@ -64,7 +64,10 @@ def add(self, requests: list[Request]): assert isinstance(req.stopping_criteria.stop_sequences, list) state = get_new_request_state( - req, self.conversation_template, self.tokenizer + req, + self.conversation_template, + self.tokenizer, + self.model_artifact_config.vocab_size, ) new_request_states.append(state) self.num_sequences_per_requests[state.request_id] = state.num_sequences diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 54414f45da..b6154dca59 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -89,6 +89,7 @@ def sample_from_logits( torch_dtype: torch.dtype, torch_dev: str, past_decode_tokens: List[List[int]], + prompt_masks: List[List[bool]], ) -> List[TextGenerationResult]: batch_size = logits.shape[0] assert batch_size == len(requests) @@ -144,6 +145,7 @@ def sample_from_logits( logits_per_token = logits[i] sampling_param = sampling_state.sampling_params[i] past_decode_tokens_per_request = past_decode_tokens[i] + prompt_mask = prompt_masks[i] # NOTE: Rerun the preparation for simplicity. # Assume this code path is taken rarely and the recomputation overhead is # marginal. @@ -151,6 +153,7 @@ def sample_from_logits( new_sampling_state = SamplingState.from_sampling_params( [sampling_param], [past_decode_tokens_per_request], + [prompt_mask], torch_dtype, torch_dev, vocab_size, diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index 2f1b9d3c72..b75a24d36b 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -129,7 +129,11 @@ def from_lists( ) # `mask_top_logprob` will be on cpu mask_top_logprob = torch.from_numpy(list_mask_top_logprob) - mask_prompt = torch.stack(list_mask_prompt) + mask_prompt = torch.tensor( + list_mask_prompt, + dtype=torch.bool, + device="cpu", + ) temp = torch.tensor( list_temperatures, dtype=dtype, @@ -252,12 +256,12 @@ def from_sampling_params( cls, sampling_params: List[SamplingParams], list_past_output_tokens: List[List[int]], + list_mask_prompt: List[List[bool]], dtype: torch.dtype, dev: str, vocab_size: int, ): list_mask_random = [] - list_mask_prompt = [] list_temperatures = [] list_top_ps = [] list_top_ks = [] @@ -293,10 +297,8 @@ def from_sampling_params( if param.sampling_type == SamplingType.RANDOM: list_mask_random.append(True) idx_random += 1 - list_top_ps.append(param.top_p) - list_top_ks.append(param.top_k if param.top_k != -1 else vocab_size) - do_top_p |= list_top_ps[-1] < 1.0 - SAMPLING_EPS - do_top_k |= list_top_ks[-1] != vocab_size + do_top_p |= param.top_p < 1.0 - SAMPLING_EPS + do_top_k |= param.top_k != vocab_size else: list_mask_random.append(False) idx_greedy += 1 @@ -312,10 +314,12 @@ def from_sampling_params( or abs(param.frequency_penalty) >= SAMPLING_EPS or abs(param.repetition_penalty - 1.0) >= SAMPLING_EPS ) + + list_top_ps.append(param.top_p) + list_top_ks.append(param.top_k if param.top_k != -1 else vocab_size) list_frequency_penalties.append(param.frequency_penalty) list_presence_penalties.append(param.presence_penalty) 
list_repetition_penalties.append(param.repetition_penalty) - list_mask_prompt.append(param.mask_prompt) if param.logit_bias_index: assert param.logit_bias_value @@ -325,6 +329,7 @@ def from_sampling_params( list_logit_bias_values.append(param.logit_bias_value) else: list_logit_bias_indices.append([]) + list_logit_bias_values.append([]) num_random_samples = idx_random + 1 num_greedy_samples = idx_greedy + 1 @@ -387,9 +392,9 @@ def get_bin_counts_and_mask( vocab_size: int, num_seqs: int, ) -> Tuple[torch.Tensor, torch.Tensor]: - bin_counts = torch.zeros((num_seqs, vocab_size + 1), - dtype=torch.long, - device=tokens.device) + bin_counts = torch.zeros( + (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device + ) bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) bin_counts = bin_counts[:, :vocab_size] mask = bin_counts > 0 @@ -397,10 +402,7 @@ def get_bin_counts_and_mask( return bin_counts, mask -def adjust_logits( - logits: torch.Tensor, - sampling_state: SamplingState, - vocab_size: int): +def adjust_logits(logits: torch.Tensor, sampling_state: SamplingState, vocab_size: int): batch_size = logits.shape[0] ( apply_top_p_top_k, diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 5e117836ca..831480746c 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -263,6 +263,7 @@ def generate_multi_query( last_query_offsets: List[int] = [] sampling_params = [] past_decode_tokens = [] + prompt_masks: List[List[bool]] = [] for request in requests: assert not isinstance(request.queries, DraftTokens) sequence_ids.append(request.sequence_id) @@ -273,6 +274,8 @@ def generate_multi_query( last_query_offsets[-1] + request.queries.num_tokens ) sampling_params.append(request.sampling_params) + # TODO: Empty mask for now. 
This is for repetion penalty + prompt_masks.append([False]) # Use `vocab_size` as a padding past_decode_tokens.append([self.vocab_size, *request.queries.token_ids]) @@ -282,6 +285,7 @@ def generate_multi_query( sampling_state = SamplingState.from_sampling_params( sampling_params, past_decode_tokens, + prompt_masks, self.torch_dtype, self.torch_dev, self.vocab_size, @@ -344,6 +348,7 @@ def generate_multi_query( self.torch_dtype, self.torch_dev, past_decode_tokens, + prompt_masks, ) def generate( @@ -371,6 +376,7 @@ def generate( num_decode_query_tokens = 1 sampling_params = [] past_decode_tokens = [] + prompt_masks = [] for request in requests: if isinstance(request, PrefillRequest): @@ -392,6 +398,7 @@ def generate( raise Exception("`EvalMultiQueryRequest` should not reach here.") past_decode_tokens.append(request_past_decode_tokens) + prompt_masks.append(request.prompt_mask) sequence_ids.append(seq_id) assert not isinstance(request, EvalMultiQueryRequest) @@ -404,6 +411,7 @@ def generate( sampling_state = SamplingState.from_sampling_params( sampling_params, past_decode_tokens, + prompt_masks, self.torch_dtype, self.torch_dev, self.vocab_size, @@ -528,6 +536,7 @@ def generate( self.torch_dtype, self.torch_dev, past_decode_tokens, + prompt_masks, ) diff --git a/serve/tests/unittest/test_sampler.py b/serve/tests/unittest/test_sampler.py index 5017220617..d8af6d345a 100644 --- a/serve/tests/unittest/test_sampler.py +++ b/serve/tests/unittest/test_sampler.py @@ -2,21 +2,25 @@ import pytest from mlc_serve.model.sampler import SamplingState, adjust_logits from mlc_serve.engine import SamplingParams, SAMPLING_EPS +import random dtype = torch.float32 dev = "cuda" vocab_size = 32000 -def get_sampling_state(sampling_params, past_output_tokens=None): +def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=None): batch_size = len(sampling_params) if past_output_tokens is None: past_output_tokens = [[] for _ in range(batch_size)] + if prompt_masks is None: + prompt_masks = [[] for _ in range(batch_size)] _copy_stream: torch.cuda.Stream = torch.cuda.Stream() with torch.cuda.stream(_copy_stream): sampling_state = SamplingState.from_sampling_params( sampling_params, list_past_output_tokens=past_output_tokens, + list_mask_prompt=prompt_masks, dtype=dtype, dev=dev, vocab_size=vocab_size, @@ -28,9 +32,7 @@ def get_sampling_state(sampling_params, past_output_tokens=None): def _test_temperature(temp=0, batch_size=1): shape = (batch_size, vocab_size) logits = torch.rand(shape, dtype=dtype, device=dev) - sampling_param = SamplingParams( - temperature=temp, - ) + sampling_param = SamplingParams(temperature=temp) sampling_state = get_sampling_state([sampling_param]) @@ -149,6 +151,7 @@ def _test_penalties(): presence_penalties = [0.8] frequency_penalties = [0.3] past_output_tokens = [[2, 2, 2, 3]] + prompt_masks = [[False] * vocab_size] * batch_size def prepare_metadata(past_output_tokens): count_map = [] @@ -188,7 +191,7 @@ def get_expected_result( ) ] sampling_state = get_sampling_state( - sampling_param, past_output_tokens=past_output_tokens + sampling_param, past_output_tokens=past_output_tokens, prompt_masks=prompt_masks ) new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) @@ -199,6 +202,7 @@ def get_expected_result( presence_penalties = [0.8, 0.7, -0.8] frequency_penalties = [-0.3, 2.0, 1.2] past_output_tokens = [[2, 2, 2, 3, 5], [3, 1, 2, 4], [3, 3, 1]] + prompt_masks = [[False] * vocab_size] * batch_size count_map, mask = 
prepare_metadata(past_output_tokens) expected = get_expected_result( @@ -213,7 +217,9 @@ def get_expected_result( for i in range(batch_size) ] sampling_state = get_sampling_state( - sampling_params, past_output_tokens=past_output_tokens + sampling_params, + past_output_tokens=past_output_tokens, + prompt_masks=prompt_masks, ) new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) @@ -314,6 +320,26 @@ def get_expected_result(logits, top_pks, filter_value=-float("Inf")): assert torch.allclose(expected, new_logits) +def _test_mixture_of_requests(): + # Mixed greedy & top_p/top_ks + batch_size = 6 + shape = (batch_size, vocab_size) + logits = torch.rand(shape, dtype=dtype, device=dev) + top_pks = [(0.7, 3), (1.0, -1), (1.0, -1), (0.5, 2), (1.0, -1), (0.8, 5)] + temps = [0.8, 0.8, 0.0, 0.0, 0.0, 0.7] + sampling_params = [ + SamplingParams(temperature=temps[i], top_p=top_p, top_k=top_k) + for i, (top_p, top_k) in enumerate(top_pks) + ] + sampling_state = get_sampling_state(sampling_params) + new_logits = adjust_logits(logits, sampling_state, vocab_size) + + # TODO(team): please follow-up. correctness check + # expected = logits.clone() + # expected = get_expected_result(expected, top_pks) + # assert torch.allclose(expected, new_logits) + + if __name__ == "__main__": _test_temperature() _test_logit_bias_checker() @@ -322,3 +348,4 @@ def get_expected_result(logits, top_pks, filter_value=-float("Inf")): _test_penalties() _test_top_p_top_k_checker() _test_top_p_top_k() + _test_mixture_of_requests()
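
A minimal, self-contained sketch (not part of the patch) of how the `prompt_mask` wired through above is meant to be consumed. It mirrors the numpy mask construction added to `get_new_request_state()` and shows one conventional way a repetition penalty can combine that mask with the mask of generated tokens before adjusting logits. The helper names `build_prompt_mask` and `apply_repetition_penalty` are illustrative assumptions; the real logic lives in `SamplingState.from_sampling_params` and `adjust_logits` in `serve/mlc_serve/model/sampler.py`.

    # Illustrative sketch only -- not the patch's actual sampler code.
    from typing import List

    import numpy as np
    import torch


    def build_prompt_mask(prompt_token_ids: List[int], vocab_size: int) -> List[bool]:
        # Same trick as the patch: allocate vocab_size + 1 slots so a padding id
        # equal to vocab_size can be scattered safely, then truncate the result.
        tokens = np.array([prompt_token_ids], dtype=np.int64)
        mask = np.zeros((vocab_size + 1,), dtype=bool)
        mask[tokens] = True
        return list(mask[:vocab_size])


    def apply_repetition_penalty(
        logits: torch.Tensor,                 # [batch, vocab_size]
        prompt_mask: torch.Tensor,            # [batch, vocab_size], bool
        output_mask: torch.Tensor,            # [batch, vocab_size], bool
        repetition_penalties: torch.Tensor,   # [batch]
    ) -> torch.Tensor:
        # Penalize every token id seen in the prompt or in the generated output:
        # positive logits are divided by the penalty, negative ones multiplied.
        seen = prompt_mask | output_mask
        rp = repetition_penalties.unsqueeze(1).expand_as(logits)
        penalized = torch.where(logits > 0, logits / rp, logits * rp)
        return torch.where(seen, penalized, logits)


    if __name__ == "__main__":
        vocab_size = 8
        prompt_mask = torch.tensor([build_prompt_mask([1, 3, 3], vocab_size)])
        output_mask = torch.zeros((1, vocab_size), dtype=torch.bool)
        output_mask[0, 5] = True
        logits = torch.randn(1, vocab_size)
        new_logits = apply_repetition_penalty(
            logits, prompt_mask, output_mask, torch.tensor([1.3])
        )
        print(new_logits)

The point of the change in this diff is that the prompt-side mask is now computed once per request on the engine side (as a plain `list[bool]`) and carried on `PrefillRequest`/`DecodeRequest`, instead of being stashed on `SamplingParams` as a torch tensor; the sketch above only illustrates how such a mask is typically consumed downstream.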