Skip to content

Commit

Permalink
Merge pull request #1 from Deelvin/vc/prefill_logprob
Browse files Browse the repository at this point in the history
Update logprob for prefill step
  • Loading branch information
zxybazh committed Dec 19, 2023
2 parents e232862 + ca91b7a commit b64db01
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 43 deletions.
21 changes: 15 additions & 6 deletions serve/mlc_serve/api/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
if request.top_p is not None:
sampling_params.top_p = request.top_p
if request.logprobs is not None:
sampling_params.top_logprobs = request.top_logprobs
sampling_params.logprobs = request.logprobs
return sampling_params

Expand Down Expand Up @@ -211,13 +212,21 @@ async def collect_result_stream(
message=ChatMessage(role="assistant", content="".join(chunks)),
finish_reason=finish_reason,
)
content = []
if logprob_infos[index] != []:
choice.logprobs={
"token_logprobs": [float(logprob_info[0][1]) for logprob_info in logprob_infos[index]],
"tokens": [str(logprob_info[0][0]) for logprob_info in logprob_infos[index]],
"offset": list(accumulate([len(str(logprob_info[0][0])) for logprob_info in logprob_infos[index]])),
"top_logprobs": [logprob_info[1] for logprob_info in logprob_infos[index]]
}
for logprob_info in logprob_infos[index]:
content.append({
"token": str(logprob_info[0][0]),
"logprob": float(logprob_info[0][1]),
# TODO(vvchernov): implement bytes bases on https://platform.openai.com/docs/api-reference/chat/object
"bytes": None,
"top_logprobs": [{
"token": top_logprob[0],
"logprob": top_logprob[1],
"bytes": None,
} for top_logprob in logprob_info[1]],
})
choice.logprobs.content = content
choices.append(choice)

usage = UsageInfo(
Expand Down
9 changes: 7 additions & 2 deletions serve/mlc_serve/api/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,18 @@ class ChatCompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
ignore_eos: Optional[bool] = False
logprobs: Optional[int] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None


class Logprobs(BaseModel):
content: Optional[List[Dict]]


class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Dict[str, Union[List, Dict]]]
logprobs: Optional[Logprobs]
finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None


Expand Down
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
MLCServeEngineConfig,
get_engine_config
)
from .sampling_params import SamplingParams, SamplingType
from .sampling_params import SamplingParams, SamplingType, TOP_LOGPROBS_NUMBER
23 changes: 15 additions & 8 deletions serve/mlc_serve/engine/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Optional

_SAMPLING_EPS = 1e-5
TOP_LOGPROBS_NUMBER = 5


class SamplingType(IntEnum):
Expand Down Expand Up @@ -38,17 +39,22 @@ class SamplingParams:
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
logprobs: Optional[Integer] that determines number of log probabilities
to return per sampled tokens, default to None meaning disabled,
otherwise minimum 0, maximum 5.
logprobs: Optional[bool] Whether to return log probabilities of the output
tokens or not. If true, returns the log probabilities of each output
token returned in the content of message.
top_logprobs: Optional[Integer] An integer between 0 and 5 specifying
the number of most likely tokens to return at each token position,
each with an associated log probability. logprobs must be set to
true if this parameter is used.
"""

presence_penalty: float = 0.0
frequency_penalty: float = 0.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
logprobs: Optional[int] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None

def __post_init__(self):
self._verify_args()
Expand Down Expand Up @@ -76,10 +82,11 @@ def _verify_args(self) -> None:
raise ValueError(
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
)
if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 5):
raise ValueError(
f"logprobs must be between 0 and 5, got {self.logprobs}."
)
if self.logprobs is not None and self.logprobs:
if (self.top_logprobs < 0 or self.top_logprobs > TOP_LOGPROBS_NUMBER):
raise ValueError(
f"top_logprobs must be between 0 and {TOP_LOGPROBS_NUMBER}, got {self.top_logprobs}."
)

def _verify_greedy_sampling(self) -> None:
if self.top_p < 1.0 - _SAMPLING_EPS:
Expand Down
51 changes: 30 additions & 21 deletions serve/mlc_serve/model/paged_cache_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@

from .base import get_model_artifact_config
from .tokenizer import HfTokenizerModule, ConversationTemplate
from ..engine import RequestId, SamplingType, MLCServeEngineConfig, SamplingParams
from ..engine import (
RequestId,
SamplingType,
MLCServeEngineConfig,
SamplingParams,
TOP_LOGPROBS_NUMBER
)
from ..engine.model_module import (
DecodeRequest,
PrefillRequest,
Expand Down Expand Up @@ -269,10 +275,12 @@ def _is_safe_to_sample(prob_like):

if logits_greedy.shape[0] > 0:
# Greedy sampling
logprobs = torch.log(torch.softmax(logits_greedy, dim=-1))
logprobs = torch.log_softmax(logits_greedy, dim=-1)
res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1)

top_greedy_logprob, top_greedy = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True)
top_greedy_logprob, top_greedy = torch.topk(
logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True
)
# Convert to numpy
res_greedy_logprob = res_greedy_logprob.cpu().numpy()
res_greedy = res_greedy.cpu().numpy()
Expand Down Expand Up @@ -311,8 +319,10 @@ def _is_safe_to_sample(prob_like):
logits = _apply_top_p_top_k(logits_random, top_ps, top_ks)

probs = torch.softmax(logits_random, dim=-1)
logprobs = torch.log(torch.softmax(logits_greedy, dim=-1))
top_random_logprob, top_random = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True)
logprobs = torch.log_softmax(logits_greedy, dim=-1)
top_random_logprob, top_random = torch.topk(
logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True
)
top_random_logprob = top_random_logprob.cpu().numpy()
top_random = top_random.cpu().numpy()

Expand All @@ -327,8 +337,8 @@ def _is_safe_to_sample(prob_like):

res = np.empty((num_seq,), dtype=np.int32)
res_logprobs = np.empty((num_seq,), dtype=np.float32)
top = np.empty((num_seq, 5), dtype=np.int32)
top_logprobs = np.empty((num_seq, 5), dtype=np.float32)
top = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.int32)
top_logprobs = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.float32)

res[mask_random] = res_random
res_logprobs[mask_random] = res_random_logprobs
Expand Down Expand Up @@ -387,13 +397,17 @@ def fetch_logprobs(
sampling_param: SamplingParams,
) -> Optional[Tuple[np.ndarray, List[Tuple[np.ndarray, np.ndarray]]]]:
"""Fetch the logprob information with index"""
if sampling_param.logprobs is None or logprob_info is None:
if (
sampling_param.logprobs is None or
not sampling_param.logprobs or
logprob_info is None
):
return None
(res, res_logprobs), (top, top_logprobs) = logprob_info
return (res[index],res_logprobs[index]), \
zip(
top[index][:sampling_param.logprobs],
top_logprobs[index][:sampling_param.logprobs]
top[index][:sampling_param.top_logprobs],
top_logprobs[index][:sampling_param.top_logprobs]
)


Expand Down Expand Up @@ -622,13 +636,6 @@ def generate(
out = self.mod["prefill"](
input_ids, positions, seq_lens, kv_cache, slot_mapping, self.params
)

if self.disco_session:
logits, _ = out.debug_get_from_remote(0)
else:
logits = out[
0
] # Ignore returned KV cache since it is updated in-place anyway.
else:
torch.cuda.nvtx.range_push(f"forward decode {input_shape}")

Expand All @@ -645,10 +652,12 @@ def generate(
self.params,
)

if self.disco_session:
logits, _ = out.debug_get_from_remote(0)
else:
logits = out[0]
if self.disco_session:
logits, _ = out.debug_get_from_remote(0)
else:
logits = out[
0
] # Ignore returned KV cache since it is updated in-place anyway.

torch.cuda.synchronize()
torch.cuda.nvtx.range_pop()
Expand Down
11 changes: 6 additions & 5 deletions serve/tests/unittest/test_engine_with_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@ def create_engine(
))
return engine

def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos, logprobs):
def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos, top_logprobs):
return Request(
request_id = str(idx),
messages = [ChatMessage(role="user", content=prompt)],
sampling_params = SamplingParams(
temperature=0.0,
logprobs=logprobs
logprobs=True,
top_logprobs=top_logprobs
),
stopping_criteria = StoppingCriteria(
max_tokens=max_tokens,
Expand Down Expand Up @@ -226,7 +227,7 @@ def test_logprobs(
max_num_sequences=4,
max_input_len=512,
num_requests=5,
logprobs=3,
top_logprobs=3,
):
prompt = "hi"
engine = create_engine(
Expand All @@ -236,7 +237,7 @@ def test_logprobs(
max_input_len,
)
s = 113
requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, logprobs=logprobs) for n in range(s, s+num_requests)]
requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, top_logprobs=top_logprobs) for n in range(s, s+num_requests)]
engine.add(requests)

generated = ["" for _ in range(num_requests)]
Expand All @@ -247,7 +248,7 @@ def test_logprobs(
assert len(res.sequences) == 1
seq = res.sequences[0]

assert seq.finish_reason is not None or len(list(seq.logprob_info[1])) == logprobs
assert seq.finish_reason is not None or len(list(seq.logprobs.content[0]["top_logprobs"])) == top_logprobs

if seq.is_finished:
assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
Expand Down

0 comments on commit b64db01

Please sign in to comment.