Merge branch 'batch-serving' into parallel-sampling-eviction
masahi committed Jan 13, 2024
2 parents 9fb9261 + 66a2e53 commit bc3dc83
Showing 9 changed files with 848 additions and 918 deletions.
70 changes: 40 additions & 30 deletions serve/benchmarks/benchmark_latency.py
@@ -1,5 +1,7 @@
"""Benchmark latency offline."""
import argparse
import cuda
import cuda.cudart
import time, numpy as np
from mlc_serve.engine import (
DebugOptions,
@@ -15,47 +17,55 @@
from utils import add_sampling_flags, postproc_sampling_args


request_counter = 0


def create_request(args):
global request_counter
request_counter += 1
return Request(
request_id=str(request_counter),
messages=None, # Provide prompt as `DebugOption` to bypass the conv template
sampling_params=SamplingParams(
temperature=args.temperature,
top_p=(1 if args.temperature == 0.0 else args.sampling_setting["top_p"]),
top_k=(-1 if args.temperature == 0.0 else args.sampling_setting["top_k"]),
repetition_penalty=args.sampling_setting["repetition_penalty"],
frequency_penalty=args.sampling_setting["frequency_penalty"],
presence_penalty=args.sampling_setting["presence_penalty"],
logit_bias=args.sampling_setting["logit_bias"],
),
stopping_criteria=StoppingCriteria(
max_tokens=args.num_output_tokens, stop_sequences=None
),
debug_options=DebugOptions(
ignore_eos=True, prompt_token_ids=[3] * args.num_input_tokens
),
num_sequences=args.num_sequences_to_sample,
)


def main(args: argparse.Namespace):
print(args)

engine = create_mlc_engine(args)
engine.add(
[
Request(
request_id="0",
messages=None, # Provide prompt as `DebugOption` to bypass the conv template
sampling_params=SamplingParams(
temperature=args.temperature,
top_p=(
1 if args.temperature == 0.0 else args.sampling_setting["top_p"]
),
top_k=(
-1
if args.temperature == 0.0
else args.sampling_setting["top_k"]
),
repetition_penalty=args.sampling_setting["repetition_penalty"],
frequency_penalty=args.sampling_setting["frequency_penalty"],
presence_penalty=args.sampling_setting["presence_penalty"],
logit_bias=args.sampling_setting["logit_bias"],
),
stopping_criteria=StoppingCriteria(
max_tokens=args.num_output_tokens, stop_sequences=None
),
debug_options=DebugOptions(
ignore_eos=True, prompt_token_ids=[3] * args.num_input_tokens
),
num_sequences=args.num_sequences_to_sample,
)
]
)

# warm up
engine.add([create_request(args)])

while engine.has_pending_requests():
engine.step()

latencies = []
engine.add([create_request(args)])

cuda.cudart.cudaProfilerStart()
while engine.has_pending_requests():
t0 = time.perf_counter()
engine.step()
t1 = time.perf_counter()
latencies.append(t1 - t0)
cuda.cudart.cudaProfilerStop()

if args.use_staging_engine:
engine.stop()
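The benchmark now builds its request through a shared create_request helper, runs one warm-up request to completion before measuring, and brackets the timed loop with cudaProfilerStart/cudaProfilerStop so an external profiler (e.g. Nsight Systems with a capture range tied to the profiler API) records only the steady-state steps. A minimal sketch, not part of this commit, of how the collected per-step latencies might be summarized (the script already imports numpy as np):

# Sketch only: summarize the per-engine.step() latencies collected above.
import numpy as np

def summarize_latencies(latencies: list[float]) -> None:
    lat_ms = np.asarray(latencies) * 1e3
    print(f"steps: {len(lat_ms)}, mean: {lat_ms.mean():.2f} ms")
    for p in (50, 90, 99):
        # The first timed step typically covers prefill, the rest decode.
        print(f"p{p}: {np.percentile(lat_ms, p):.2f} ms")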
6 changes: 3 additions & 3 deletions serve/mlc_serve/engine/model_module.py
@@ -2,7 +2,7 @@
Required interfaces for the actual inference capability in InferenceEngine.
"""
from dataclasses import dataclass
from typing import Optional, Protocol, Union, List
from typing import Optional, Protocol, Union, List, Sequence

from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId
from ..model.base import ModelArtifactConfig
@@ -143,8 +143,8 @@ class TextGenerator(Protocol):

def generate(
self,
requests: List[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]],
kv_cache: KVCache,
requests: Sequence[Union[PrefillRequest, DecodeRequest]],
kv_cache,
) -> List[TextGenerationResult]:
"""
A unified entrypoint for text generation.
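TextGenerator.generate now takes Sequence[Union[PrefillRequest, DecodeRequest]] instead of a List that also allowed MultiQueryDecodeRequest, and drops the KVCache annotation on kv_cache. Switching List to Sequence matters for type checking because List is invariant while Sequence is covariant; a small illustration with hypothetical stand-in classes (not from the repository):

# Why Sequence rather than List in the protocol signature (sketch).
from typing import List, Sequence, Union

class PrefillRequest: ...
class DecodeRequest: ...

def takes_list(requests: List[Union[PrefillRequest, DecodeRequest]]) -> None: ...
def takes_sequence(requests: Sequence[Union[PrefillRequest, DecodeRequest]]) -> None: ...

prefill_batch: List[PrefillRequest] = [PrefillRequest()]
takes_sequence(prefill_batch)  # accepted: Sequence is covariant in its element type
takes_list(prefill_batch)      # flagged by a type checker: List is invariant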
22 changes: 18 additions & 4 deletions serve/mlc_serve/model/dummy_model.py
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, List

from mlc_serve.engine import (
ChatMessage,
@@ -16,16 +16,29 @@
)

class DummyTokenizer:
is_fast = True
@property
def eos_token_id(self):
return 2
return 1000

def encode(self, text: str, **kwargs) -> list[int]:
return [1] * len(text.split())

def decode(self, tokens: list[int], **kwargs) -> str:
return "test " * len(tokens)

def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]:
r = []
for i in token_ids:
r.append(f"_test{i}")
return r

def convert_tokens_to_string(self, tokens: List[str]) -> str:
ret = ""
for a in tokens:
ret += a + " "
return ret


class DummyConversationTemplate:
def apply(self, messages: list[ChatMessage]) -> str:
@@ -45,7 +58,7 @@ def __init__(self, max_cached_tokens: int):
def get_cache(self) -> KVCache:
return self.cache

def allocate(self, request_id: RequestId, num_tokens: int) -> bool:
def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int) -> bool:
seq_id = SequenceId(request_id, 0)
self.cache.cached_requests[seq_id] = num_tokens
if self.get_free_space() < 0:
@@ -107,7 +120,8 @@ def generate(
request_id=request_id,
sequence_index=0,
),
generated_tokens=[1],
generated_tokens=[req.token_ids[-1] + 1],
# generated_tokens=[1],
error=None,
)
)
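DummyTokenizer gains is_fast, a larger eos_token_id, and the convert_ids_to_tokens / convert_tokens_to_string pair used by incremental detokenization; the dummy cache manager's allocate now also receives num_sequences for parallel sampling, and the dummy generator returns req.token_ids[-1] + 1 instead of a constant token. A quick usage sketch of the two new tokenizer methods (values follow the dummy logic above):

# Sketch: composing the new DummyTokenizer methods.
from mlc_serve.model.dummy_model import DummyTokenizer

tokenizer = DummyTokenizer()
ids = tokenizer.encode("one two three")            # -> [1, 1, 1]
pieces = tokenizer.convert_ids_to_tokens(ids)      # -> ["_test1", "_test1", "_test1"]
text = tokenizer.convert_tokens_to_string(pieces)  # -> "_test1 _test1 _test1 "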
(Diffs for the remaining 6 changed files are not shown.)