From e78ff2a25745f3f98bccf775dc141cf10cf2ab73 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Thu, 14 Dec 2023 19:01:47 +0000 Subject: [PATCH 1/4] wip --- serve/benchmarks/benchmark_throughput.py | 128 ++++++++++++++++------- 1 file changed, 92 insertions(+), 36 deletions(-) diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py index beb2ebcb80..7fc191e87f 100644 --- a/serve/benchmarks/benchmark_throughput.py +++ b/serve/benchmarks/benchmark_throughput.py @@ -18,6 +18,8 @@ from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args +SAMPLER_SETTING = {"ignore_eos": True, "temperature": 1, "use_beam_search": False} + def sample_requests( dataset_path: str, @@ -62,46 +64,91 @@ def sample_requests( return sampled_requests -def run_mlc( - requests: List[Tuple[str, int, int]], - engine, - num_sequences, -) -> float: - for i, (prompt, _, output_len) in enumerate(requests): +def run_mii(requests: List[Tuple[str, int, int]], model, num_shards) -> float: + from mii import pipeline + + engine = pipeline(model, tensor_parallel=num_shards) + prompts = [prompt for prompt, _, _ in requests] + # FIXME: hardcoded + output_len = 128 + start = time.perf_counter() + engine(prompts, max_new_tokens=output_len) + end = time.perf_counter() + return end - start + + +def run_vllm(requests: List[Tuple[str, int, int]], model, num_shards) -> float: + from vllm import LLM, SamplingParams + + # Fixme + llm = LLM( + model=model, + tokenizer=model, + quantization=None, + tensor_parallel_size=num_shards, + seed=0, + trust_remote_code=True, + dtype="auto", + max_model_len=2000, + ) + + # Add the requests to the engine. + for prompt, _, output_len in requests: sampling_params = SamplingParams( - temperature=1.0, - top_p=1.0, - frequency_penalty=-1, - logit_bias={1: -1, 3: 1, 2: 2}, + n=1, + use_beam_search=SAMPLER_SETTING["use_beam_search"], + temperature=SAMPLER_SETTING["temperature"], + ignore_eos=SAMPLER_SETTING["ignore_eos"], + max_tokens=output_len, + ) + # FIXME(woosuk): Do not use internal method. + llm._add_request( + prompt=prompt, + prompt_token_ids=None, + sampling_params=sampling_params, ) + start = time.perf_counter() + llm._run_engine(use_tqdm=True) + end = time.perf_counter() + return end - start + + +def run_mlc(engine, requests) -> float: + for i, (prompt, _, output_len) in enumerate(requests): engine.add( [ Request( request_id=str(i), messages=[ChatMessage(role="user", content=prompt)], - sampling_params=sampling_params, + sampling_params=SamplingParams( + temperature=SAMPLER_SETTING["temperature"] + ), stopping_criteria=StoppingCriteria( max_tokens=output_len, stop_sequences=None ), - debug_options=DebugOptions(ignore_eos=True, prompt=prompt), - num_sequences=num_sequences, + num_sequences=1, + debug_options=DebugOptions( + ignore_eos=SAMPLER_SETTING["ignore_eos"], prompt=prompt + ), ) ] ) - start = time.time() + start = time.perf_counter() while engine.has_pending_requests(): engine.step() - end = time.time() + end = time.perf_counter() + + if args.use_staging_engine: + engine.stop() + return end - start -def create_engine_and_tokenizer_module( - args: argparse.Namespace, -): +def create_mlc_engine(args: argparse.Namespace): engine_config = get_engine_config( { "use_staging_engine": args.use_staging_engine, @@ -122,7 +169,6 @@ def create_engine_and_tokenizer_module( }, ) engine.start() - tokenizer = engine.tokenizer else: engine = SynchronousInferenceEngine( PagedCacheModelModule( @@ -130,27 +176,34 @@ def create_engine_and_tokenizer_module( engine_config=engine_config, ) ) - tokenizer = engine.tokenizer - return engine, tokenizer + return engine def main(args: argparse.Namespace): print(args) - - engine, tokenizer = create_engine_and_tokenizer_module(args) - - # Sample the requests. - requests = sample_requests(args.dataset, args.num_prompts, tokenizer._tokenizer) - - elapsed_time = run_mlc( - requests, - engine, - args.num_sequences_to_sample, - ) - - if args.use_staging_engine: - engine.stop() + random.seed(args.seed) + + if args.backend == "mlc-serve": + # Create mlc engine + engine = create_mlc_engine(args) + # Sample the requests. + requests = sample_requests( + args.dataset, args.num_prompts, engine.tokenizer._tokenizer + ) + elapsed_time = run_mlc(engine, requests) + else: + from transformers import AutoTokenizer + + model = "/opt/models/mistral/Mixtral-8x7B-Instruct-v0.1" + tokenizer = AutoTokenizer.from_pretrained(model) + requests = sample_requests(args.dataset, args.num_prompts, tokenizer) + if args.backend == "mii": + num_shards = 1 + elapsed_time = run_mii(requests, model, num_shards) + elif args.backend == "vllm": + num_shards = 2 + elapsed_time = run_vllm(requests, model, num_shards) total_num_tokens = sum( prompt_len + output_len * args.num_sequences_to_sample @@ -160,7 +213,7 @@ def main(args: argparse.Namespace): tok_per_sec = total_num_tokens / elapsed_time print( - f"Throughput: {req_per_sec:.2f} requests/s, " + f"Engine Throughput: {req_per_sec:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} tokens/s" ) if args.report_path is not None: @@ -177,6 +230,9 @@ def main(args: argparse.Namespace): if __name__ == "__main__": parser = get_default_mlc_serve_argparser(description="Benchmark the throughput.") + parser.add_argument( + "--backend", type=str, default="mlc-serve", choices=["mlc-serve", "vllm", "mii"] + ) parser.add_argument( "--dataset", type=str, required=True, help="Path to the dataset." ) From 922646108e7b44b162155c04f7aef8e2dbebb092 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Thu, 14 Dec 2023 19:27:49 +0000 Subject: [PATCH 2/4] rebased --- serve/benchmarks/benchmark_throughput.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py index 7fc191e87f..b04b35702e 100644 --- a/serve/benchmarks/benchmark_throughput.py +++ b/serve/benchmarks/benchmark_throughput.py @@ -114,7 +114,7 @@ def run_vllm(requests: List[Tuple[str, int, int]], model, num_shards) -> float: return end - start -def run_mlc(engine, requests) -> float: +def run_mlc(engine, requests, num_sequences_to_sample) -> float: for i, (prompt, _, output_len) in enumerate(requests): engine.add( [ @@ -127,7 +127,7 @@ def run_mlc(engine, requests) -> float: stopping_criteria=StoppingCriteria( max_tokens=output_len, stop_sequences=None ), - num_sequences=1, + num_sequences=num_sequences_to_sample, debug_options=DebugOptions( ignore_eos=SAMPLER_SETTING["ignore_eos"], prompt=prompt ), @@ -191,7 +191,7 @@ def main(args: argparse.Namespace): requests = sample_requests( args.dataset, args.num_prompts, engine.tokenizer._tokenizer ) - elapsed_time = run_mlc(engine, requests) + elapsed_time = run_mlc(engine, requests, args.num_sequences_to_sample) else: from transformers import AutoTokenizer From b0e98515dfc2bbcc7c5701b90c28a91c73ce6b69 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Fri, 5 Jan 2024 22:18:16 +0000 Subject: [PATCH 3/4] wip --- serve/benchmarks/benchmark_throughput.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py index b04b35702e..1953512ad1 100644 --- a/serve/benchmarks/benchmark_throughput.py +++ b/serve/benchmarks/benchmark_throughput.py @@ -18,7 +18,7 @@ from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args -SAMPLER_SETTING = {"ignore_eos": True, "temperature": 1, "use_beam_search": False} +SAMPLER_SETTING = {"ignore_eos": True, "temperature": 1} def sample_requests( @@ -96,12 +96,11 @@ def run_vllm(requests: List[Tuple[str, int, int]], model, num_shards) -> float: for prompt, _, output_len in requests: sampling_params = SamplingParams( n=1, - use_beam_search=SAMPLER_SETTING["use_beam_search"], + use_beam_search=False, temperature=SAMPLER_SETTING["temperature"], ignore_eos=SAMPLER_SETTING["ignore_eos"], max_tokens=output_len, ) - # FIXME(woosuk): Do not use internal method. llm._add_request( prompt=prompt, prompt_token_ids=None, From fb29bef99c87c8c0a41a8ff2ec8e1ffa6239dac8 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Sat, 6 Jan 2024 01:09:40 +0000 Subject: [PATCH 4/4] done --- serve/benchmarks/benchmark_throughput.py | 113 +++++++++++++++-------- 1 file changed, 77 insertions(+), 36 deletions(-) diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py index 1953512ad1..08851da0b6 100644 --- a/serve/benchmarks/benchmark_throughput.py +++ b/serve/benchmarks/benchmark_throughput.py @@ -64,47 +64,47 @@ def sample_requests( return sampled_requests -def run_mii(requests: List[Tuple[str, int, int]], model, num_shards) -> float: +def run_mii(requests: List[Tuple[str, int, int]], args) -> float: from mii import pipeline - engine = pipeline(model, tensor_parallel=num_shards) + engine = pipeline(args.model, tensor_parallel=args.num_shards) prompts = [prompt for prompt, _, _ in requests] - # FIXME: hardcoded - output_len = 128 start = time.perf_counter() - engine(prompts, max_new_tokens=output_len) + engine( + prompts, + max_new_tokens=args.max_output_tokens, + ignore_eos=SAMPLER_SETTING["ignore_eos"], + temperature=SAMPLER_SETTING["temperature"], + ) end = time.perf_counter() return end - start -def run_vllm(requests: List[Tuple[str, int, int]], model, num_shards) -> float: +def run_vllm(requests: List[Tuple[str, int, int]], args) -> float: from vllm import LLM, SamplingParams - # Fixme llm = LLM( - model=model, - tokenizer=model, - quantization=None, - tensor_parallel_size=num_shards, - seed=0, + model=args.model, + tokenizer=args.model, + dtype=args.dtype, + quantization=args.quantization, + tensor_parallel_size=args.num_shards, trust_remote_code=True, - dtype="auto", - max_model_len=2000, + max_model_len=None, # derive from the model ) # Add the requests to the engine. - for prompt, _, output_len in requests: - sampling_params = SamplingParams( - n=1, - use_beam_search=False, - temperature=SAMPLER_SETTING["temperature"], - ignore_eos=SAMPLER_SETTING["ignore_eos"], - max_tokens=output_len, - ) + for prompt, _, _ in requests: llm._add_request( prompt=prompt, prompt_token_ids=None, - sampling_params=sampling_params, + sampling_params=SamplingParams( + n=args.num_sequences_to_sample, + use_beam_search=False, + temperature=SAMPLER_SETTING["temperature"], + ignore_eos=SAMPLER_SETTING["ignore_eos"], + max_tokens=args.max_output_tokens, + ), ) start = time.perf_counter() @@ -113,8 +113,8 @@ def run_vllm(requests: List[Tuple[str, int, int]], model, num_shards) -> float: return end - start -def run_mlc(engine, requests, num_sequences_to_sample) -> float: - for i, (prompt, _, output_len) in enumerate(requests): +def run_mlc(engine, requests, args) -> float: + for i, (prompt, _, _) in enumerate(requests): engine.add( [ Request( @@ -124,9 +124,9 @@ def run_mlc(engine, requests, num_sequences_to_sample) -> float: temperature=SAMPLER_SETTING["temperature"] ), stopping_criteria=StoppingCriteria( - max_tokens=output_len, stop_sequences=None + max_tokens=args.max_output_tokens, stop_sequences=None ), - num_sequences=num_sequences_to_sample, + num_sequences=args.num_sequences_to_sample, debug_options=DebugOptions( ignore_eos=SAMPLER_SETTING["ignore_eos"], prompt=prompt ), @@ -190,23 +190,27 @@ def main(args: argparse.Namespace): requests = sample_requests( args.dataset, args.num_prompts, engine.tokenizer._tokenizer ) - elapsed_time = run_mlc(engine, requests, args.num_sequences_to_sample) + elapsed_time = run_mlc(engine, requests, args) else: from transformers import AutoTokenizer - model = "/opt/models/mistral/Mixtral-8x7B-Instruct-v0.1" - tokenizer = AutoTokenizer.from_pretrained(model) + assert ( + args.model is not None + ), "Please provide model path for vllm and deepspeed mii." + assert ( + args.num_shards is not None + ), "Please provide number of gpus for vllm and deepspeed mii." + + tokenizer = AutoTokenizer.from_pretrained(args.model) requests = sample_requests(args.dataset, args.num_prompts, tokenizer) if args.backend == "mii": - num_shards = 1 - elapsed_time = run_mii(requests, model, num_shards) + elapsed_time = run_mii(requests, args) elif args.backend == "vllm": - num_shards = 2 - elapsed_time = run_vllm(requests, model, num_shards) + elapsed_time = run_vllm(requests, args) total_num_tokens = sum( - prompt_len + output_len * args.num_sequences_to_sample - for _, prompt_len, output_len in requests + prompt_len + args.max_output_tokens * args.num_sequences_to_sample + for _, prompt_len, _ in requests ) req_per_sec = len(requests) / elapsed_time tok_per_sec = total_num_tokens / elapsed_time @@ -238,12 +242,49 @@ def main(args: argparse.Namespace): parser.add_argument( "--num-prompts", type=int, default=1000, help="Number of prompts to process." ) + parser.add_argument( + "--max-output-tokens", + type=int, + default=128, + help="Maximum number of generation tokens.", + ) parser.add_argument( "--report-path", type=str, default=None, help="Append the current result to the given path if provided.", ) + # flags for vllm and deepspeed mii + parser.add_argument( + "--model", + type=str, + default=None, + help="Model path. This is for vLLM and Deepspeed MII.", + ) + parser.add_argument( + "--num-shards", + type=int, + default=None, + help="Number of GPUs. This is for vLLM and Deepspeed MII.", + ) + # flags for vllm + parser.add_argument( + "--quantization", + "-q", + choices=["awq", "gptq", "squeezellm", None], + default=None, + ) + parser.add_argument( + "--dtype", + type=str, + default="auto", + choices=["auto", "half", "float16", "bfloat16", "float", "float32"], + help="data type for model weights and activations. " + 'The "auto" option will use FP16 precision ' + "for FP32 and FP16 models, and BF16 precision " + "for BF16 models.", + ) + args = parser.parse_args() args = postproc_mlc_serve_args(args)