From e034fcf431a87fca6efa46506889ad6cac53d17c Mon Sep 17 00:00:00 2001 From: yukavio <67678385+yukavio@users.noreply.github.com> Date: Tue, 24 Sep 2024 18:32:44 +0800 Subject: [PATCH] add radix cache scheduler --- .gitignore | 5 + python/sglang/bench_serving.py | 164 ++- python/sglang/bench_serving_improve.py | 1028 +++++++++++++++ python/sglang/global_config.py | 2 +- python/sglang/iter_search_best_qps.py | 1098 +++++++++++++++++ python/sglang/launch_server.py | 1 + python/sglang/search_best_qps.py | 1068 ++++++++++++++++ python/sglang/srt/managers/controller_flex.py | 312 ++++- .../sglang/srt/managers/controller_single.py | 1 + python/sglang/srt/managers/io_struct.py | 12 +- python/sglang/srt/managers/schedule_batch.py | 8 + python/sglang/srt/managers/tp_worker.py | 79 +- python/sglang/srt/mem_cache/radix_cache.py | 84 +- python/sglang/srt/server_args.py | 2 + python/sglang/utils.py | 4 +- 15 files changed, 3800 insertions(+), 68 deletions(-) create mode 100644 python/sglang/bench_serving_improve.py create mode 100644 python/sglang/iter_search_best_qps.py create mode 100644 python/sglang/search_best_qps.py diff --git a/.gitignore b/.gitignore index ca43e1ccba..b6fb08775a 100644 --- a/.gitignore +++ b/.gitignore @@ -183,3 +183,8 @@ work_dirs/ *.csv !logo.png +test.py +a.txt +b.txt +*.txt +launch_server.py.lprof diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 363683e05f..e67871b28d 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -311,15 +311,25 @@ def download_sharegpt_dataset(path): raise Exception(f"Failed to download dataset: {e}") +import pickle + + def sample_sharegpt_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int] = None, + max_seqlen: int, ) -> List[Tuple[str, int, int]]: - if fixed_output_len is not None and fixed_output_len < 4: - raise ValueError("output_len too small") - + cache_path = f"./input_cache_v2_{num_requests}" + # 尝试加载缓存的 input_requests + if os.path.isfile(cache_path): + with open(cache_path, "rb") as f: + input_requests = pickle.load(f) + print("Loaded input_requests from cache.") + return input_requests + prompts = [] + prompt_lens = [] + response_lens = [] # Download sharegpt if necessary if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): download_sharegpt_dataset(default_sharegpt_path) @@ -331,42 +341,52 @@ def sample_sharegpt_requests( # Load the dataset. with open(dataset_path) as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. - dataset = [ - (data["conversations"][0]["value"], data["conversations"][1]["value"]) - for data in dataset - ] - - # Shuffle the dataset. - random.shuffle(dataset) - - # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] - for i in range(len(dataset)): - if len(filtered_dataset) == num_requests: - break - - # Tokenize the prompts and completions. - prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids - completion = dataset[i][1] - completion_token_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_token_ids) - output_len = ( - len(completion_token_ids) if fixed_output_len is None else fixed_output_len - ) - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. 
- continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. - continue - filtered_dataset.append((prompt, prompt_len, output_len)) + datasets = json.load(f) + for data in datasets: + if len(data["conversations"]) >= 2: + prompt = data["conversations"][0]["value"] + res = data["conversations"][1]["value"] + prompt_token_ids = tokenizer(prompt).input_ids + completion_token_ids = tokenizer(res).input_ids + + if ( + len(prompt_token_ids) + len(completion_token_ids) < max_seqlen + and len(prompt_token_ids) > 0 + and len(completion_token_ids) > 0 + ): + prompts.append(prompt) + prompt_lens.append(len(prompt_token_ids)) + response_lens.append(len(completion_token_ids)) + if len(prompts) > num_requests: + break + + sampled_ids = [random.randint(0, len(prompts) - 1) for _ in range(num_requests)] + sampled_prompts = [prompts[idx] for idx in sampled_ids] + sampled_prompts_lens = [prompt_lens[idx] for idx in sampled_ids] + sampled_response_lens = [response_lens[idx] for idx in sampled_ids] + + for i, (prompt_len, gen_len) in enumerate( + zip(sampled_prompts_lens, sampled_response_lens) + ): + total = prompt_len + gen_len + if total > max_seqlen: + print(f"truncating long prompt+gen_len {prompt_len=} {gen_len=}") + gen_len = max_seqlen - prompt_len + sampled_response_lens[i] = gen_len + input_requests = list( + zip(sampled_prompts, sampled_prompts_lens, sampled_response_lens) + ) + + + with open(cache_path, "wb") as f: + pickle.dump(input_requests, f) + print(f"Saved input_requests_{num_requests} to cache.") + print(f"#Input tokens: {np.sum(sampled_prompts_lens)}") + print(f"#Output tokens: {np.sum(sampled_response_lens)}") + return input_requests - return filtered_dataset + +import pickle def sample_random_requests( @@ -377,7 +397,13 @@ def sample_random_requests( tokenizer: PreTrainedTokenizerBase, dataset_path: str, ) -> List[Tuple[str, int, int]]: - + cache_path = f"./input_cache_{num_prompts}" + # 尝试加载缓存的 input_requests + if os.path.isfile(cache_path): + with open(cache_path, "rb") as f: + input_requests = pickle.load(f) + print("Loaded input_requests from cache.") + return input_requests input_lens = np.random.randint( max(int(input_len * range_ratio), 1), input_len + 1, @@ -444,7 +470,9 @@ def sample_random_requests( ] ) input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) - + with open(cache_path, "wb") as f: + pickle.dump(input_requests, f) + print(f"Saved input_requests_{num_prompts} to cache.") print(f"#Input tokens: {np.sum(input_lens)}") print(f"#Output tokens: {np.sum(output_lens)}") return input_requests @@ -483,6 +511,9 @@ def calculate_metrics( tpots: List[float] = [] ttfts: List[float] = [] e2e_latencies: List[float] = [] + + input_lens: List[float] = [] + for i in range(len(outputs)): if outputs[i].success: output_len = outputs[i].output_len @@ -492,6 +523,9 @@ def calculate_metrics( ) retokenized_output_lens.append(retokenized_output_len) total_input += input_requests[i][1] + + input_lens.append(input_requests[i][1]) + if output_len > 1: tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) itls += outputs[i].itl @@ -510,6 +544,11 @@ def calculate_metrics( "on the benchmark arguments.", stacklevel=2, ) + + # metric_data = [input_lens, output_lens, ttfts] + # with open(f'metrics_{time.time()}.json', 'w') as f: + # json.dump(metric_data, f) + metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -557,6 +596,15 @@ async def benchmark( print("Starting initial single prompt test run...") 
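# Editorial note (not part of the patch): a minimal, self-contained sketch of the
# pickle-based caching pattern that the modified sample_sharegpt_requests() /
# sample_random_requests() above use to avoid re-sampling and re-tokenizing the
# dataset on every benchmark run. The helper name `load_or_build_requests` and its
# arguments are illustrative only, not identifiers from the patch.
import os
import pickle
from typing import Callable, List, Tuple

def load_or_build_requests(
    cache_path: str,
    build: Callable[[], List[Tuple[str, int, int]]],
) -> List[Tuple[str, int, int]]:
    # Reuse a previously sampled (prompt, prompt_len, output_len) list if it is cached on disk.
    if os.path.isfile(cache_path):
        with open(cache_path, "rb") as f:
            print("Loaded input_requests from cache.")
            return pickle.load(f)
    # Otherwise sample once and persist the result for later runs.
    input_requests = build()
    with open(cache_path, "wb") as f:
        pickle.dump(input_requests, f)
        print(f"Saved input_requests to {cache_path}.")
    return input_requests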
test_prompt, test_prompt_len, test_output_len = input_requests[0] + + words = test_prompt.split() + + # 使用random.shuffle打乱单词列表 + random.shuffle(words) + + # 将打乱后的单词列表合并回文本 + test_prompt = " ".join(words) + test_input = RequestFuncInput( model=model_id, prompt=test_prompt, @@ -723,6 +771,31 @@ async def benchmark( "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, "median_e2e_latency_ms": metrics.median_e2e_latency_ms, } + + balance_method = os.getenv("LOAD_BALANCE_METHOD") + new_item = { + "method": balance_method, + "mean_ttft": metrics.mean_ttft_ms, + "request_rate": request_rate, + "request_throughput": metrics.request_throughput, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "time": datetime.now().isoformat(), + } + file_name = f"{balance_method}_result.json" + if not os.path.exists(file_name): + with open(file_name, "w") as f: + json.dump([], f) + + with open(file_name, "r") as f: + tmp_data = json.load(f) + + tmp_data.append(new_item) + + with open(file_name, "w") as f: + json.dump(tmp_data, f, indent=4) + + print(f"add new item to {file_name}: {new_item}") return result @@ -819,7 +892,7 @@ def run_benchmark(args_: argparse.Namespace): dataset_path=args.dataset_path, num_requests=args.num_prompts, tokenizer=tokenizer, - fixed_output_len=args.sharegpt_output_len, + max_seqlen=args.sharegpt_max_seqlen, ) elif args.dataset_name == "random": input_requests = sample_random_requests( @@ -851,6 +924,7 @@ def run_benchmark(args_: argparse.Namespace): ) ) else: + return asyncio.run( benchmark( backend=backend, @@ -927,6 +1001,12 @@ def set_ulimit(target_soft_limit=65535): default=1000, help="Number of prompts to process. Default is 1000.", ) + parser.add_argument( + "--sharegpt-max-seqlen", + type=int, + default=8192, + help="Number of max request len. Default is 8192.", + ) parser.add_argument( "--sharegpt-output-len", type=int, diff --git a/python/sglang/bench_serving_improve.py b/python/sglang/bench_serving_improve.py new file mode 100644 index 0000000000..7d63bcdcf9 --- /dev/null +++ b/python/sglang/bench_serving_improve.py @@ -0,0 +1,1028 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py + +""" +Benchmark online serving. 
+ +Usage: +python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi +""" + +import argparse +import asyncio +import json +import os +import random +import resource +import sys +import time +import traceback +import warnings +from argparse import ArgumentParser +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union +from concurrent.futures import ThreadPoolExecutor, as_completed +import aiohttp +import numpy as np +import requests +from tqdm.asyncio import tqdm +from transformers import ( + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +global args + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + extra_request_body: Dict[str, Any] + + +@dataclass +class RequestFuncOutput: + generated_text: str = "" + success: bool = False + latency: float = 0.0 + ttft: float = 0.0 # Time to first token + itl: List[float] = field(default_factory=list) # List of inter-token latencies + prompt_len: int = 0 + error: str = "" + output_len: int = 0 + + +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix) :] if text.startswith(prefix) else text + + +# trt llm not support ignore_eos +# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 +async def async_request_trt_llm( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.000001, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + "min_length": request_func_input.output_len, + "end_id": 1048576, + **request_func_input.extra_request_body, + } + if args.disable_ignore_eos: + del payload["min_length"] + del payload["end_id"] + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = True + output.output_len = request_func_input.output_len + + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = 
"".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# set ignore_eos True by default +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( + "completions" + ), "OpenAI Completions API URL must end with 'completions'." + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": 1, + "max_tokens": request_func_input.output_len, + "stream": not args.disable_stream, + "ignore_eos": not args.disable_ignore_eos, + **request_func_input.extra_request_body, + } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true": + import huggingface_hub.constants + from modelscope import snapshot_download + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) + + return model_path + return pretrained_model_name_or_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) + + +ASYNC_REQUEST_FUNCS = { + "sglang": async_request_openai_completions, + "vllm": async_request_openai_completions, + "lmdeploy": async_request_openai_completions, + "trt": async_request_trt_llm, +} + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + total_output_retokenized: int + request_throughput: 
float + input_throughput: float + output_throughput: float + output_throughput_retokenized: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + p99_itl_ms: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + + +default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json" + + +def download_sharegpt_dataset(path): + url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" + + print(f"Downloading dataset from {url}") + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + with open(path, "wb") as f, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(block_size): + size = f.write(data) + progress_bar.update(size) + + print(f"Dataset downloaded and saved to {path}") + except requests.RequestException as e: + raise Exception(f"Failed to download dataset: {e}") + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset +import pickle + +def sample_random_requests( + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, + dataset_path: str, +) -> List[Tuple[str, int, int]]: + cache_path = f'./input_cache_{num_prompts}' + # 尝试加载缓存的 input_requests + if os.path.isfile(cache_path): + with open(cache_path, 'rb') as f: + input_requests = pickle.load(f) + print("Loaded input_requests from cache.") + return input_requests + input_lens = np.random.randint( + max(int(input_len * range_ratio), 1), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile( + default_sharegpt_path + ): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + def process_prompt(i): + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + prompt_len = len(prompt_token_ids) + if prompt_len > input_lens[i]: + input_ids = prompt_token_ids[: input_lens[i]] + else: + ratio = (input_lens[i] + prompt_len - 1) // prompt_len + input_ids = (prompt_token_ids * ratio)[: input_lens[i]] + prompt = tokenizer.decode(input_ids) + # print(f"sample {i} has been processed...") + return (prompt, int(input_lens[i]), int(output_lens[i])) + input_requests = [] + # Filter out sequences that are too long or too short + t1 = time.time() + with ThreadPoolExecutor(max_workers=os.cpu_count() * 5) as executor: + # 提交所有任务 + futures = [executor.submit(process_prompt, i) for i in range(num_prompts)] + + # 等待所有任务完成,并收集结果 + for future in as_completed(futures): + try: + result = future.result() + input_requests.append(result) + except Exception as e: + print(f"Task generated an exception: {e}") + t2 = time.time() + print(f"It takes {t2 - t1} seconds to prepare prompts....") + # 保存 input_requests 到缓存文件 + with open(cache_path, 'wb') as f: + pickle.dump(input_requests, f) + print(f"Saved input_requests_{num_prompts} to cache.") + print(f"#Input tokens: {np.sum(input_lens)}") + print(f"#Output tokens: {np.sum(output_lens)}") + return input_requests + + +async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
+ await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + backend: str, +) -> Tuple[BenchmarkMetrics, List[int]]: + output_lens: List[int] = [] + retokenized_output_lens: List[int] = [] + total_input = 0 + completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] + e2e_latencies: List[float] = [] + + input_lens:List[float] = [] + + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_len + output_lens.append(output_len) + retokenized_output_len = len( + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) + retokenized_output_lens.append(retokenized_output_len) + total_input += input_requests[i][1] + + input_lens.append(input_requests[i][1]) + + + if output_len > 1: + tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + + e2e_latencies.append(outputs[i].latency) + + completed += 1 + else: + output_lens.append(0) + retokenized_output_lens.append(0) + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2, + ) + + metric_data = [input_lens, output_lens, ttfts] + with open(f'metrics_{time.time()}.json', 'w') as f: + json.dump(metric_data, f) + + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(output_lens), + total_output_retokenized=sum(retokenized_output_lens), + request_throughput=completed / dur_s, + input_throughput=total_input / dur_s, + output_throughput=sum(output_lens) / dur_s, + output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + std_ttft_ms=np.std(ttfts or 0) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, + median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + ) + + return metrics, output_lens + + +async def benchmark( + backend: str, + api_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: List[Tuple[str, int, int]], + request_rate: float, + disable_tqdm: bool, + enable_multi: bool, + extra_request_body: Dict[str, Any], +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + + words = test_prompt.split() + + # 使用random.shuffle打乱单词列表 + random.shuffle(words) + + # 将打乱后的单词列表合并回文本 + test_prompt = ' '.join(words) + + + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + extra_request_body=extra_request_body, + ) + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + benchmark_start_time = 
time.perf_counter() + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate): + prompt, prompt_len, output_len = request + request_func_input = RequestFuncInput( + model=model_id, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + extra_request_body=extra_request_body, + ) + tasks.append( + asyncio.create_task( + request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + backend=backend, + ) + + print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Backend:", backend)) + print("{:<40} {:<10}".format("Traffic request rate:", request_rate)) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10}".format( + "Total generated tokens (retokenized):", metrics.total_output_retokenized + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Input token throughput (tok/s):", metrics.input_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) + print( + "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) + ) + print( + "{:<40} {:<10.2f}".format( + "Median E2E Latency (ms):", metrics.median_e2e_latency_ms + ) + ) + print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) + print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + print( + "{s:{c}^{n}}".format(s="Time per Output Token (excl. 
1st token)", n=50, c="-") + ) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("=" * 50) + + if ( + metrics.median_ttft_ms is not None + and metrics.mean_itl_ms is not None + and metrics.output_throughput is not None + ): + result = { + "backend": args.backend, + "dataset_name": args.dataset_name, + "request_rate": request_rate, + "total_input": metrics.total_input, + "total_output": metrics.total_output, + "total_output_retokenized": metrics.total_output_retokenized, + "mean_e2e_latency": metrics.mean_e2e_latency_ms, + "median_e2e_latency": metrics.median_e2e_latency_ms, + "median_ttft": metrics.median_ttft_ms, + "median_itl": metrics.median_itl_ms, + "output_token_throughput": metrics.output_throughput, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + "benchmark_duration": benchmark_duration, + } + else: + print(f"Error running benchmark for request rate: {request_rate}") + print("-" * 30) + + # Determine output file name + if args.output_file: + output_file_name = args.output_file + else: + now = datetime.now().strftime("%m%d") + if args.dataset_name == "random": + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" + else: + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl" + + # Append results to a JSONL file + with open(output_file_name, "a") as file: + file.write(json.dumps(result) + "\n") + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + "mean_ttft_ms": metrics.mean_ttft_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + } + return result + + +def parse_request_rate_range(request_rate_range): + if len(request_rate_range.split(",")) == 3: + start, stop, step = map(int, request_rate_range.split(",")) + return list(range(start, stop, step)) + else: + 
return list(map(int, request_rate_range.split(","))) + + +def check_chat_template(model_path): + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + return "chat_template" in tokenizer.init_kwargs + except Exception as e: + print(f"Fail to load tokenizer config with error={e}") + return False + + +def run_benchmark(args_: argparse.Namespace): + global args + args = args_ + + set_ulimit() + random.seed(args.seed) + np.random.seed(args.seed) + + extra_request_body = {} + if args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + + if args.port is None: + args.port = { + "sglang": 30000, + "lmdeploy": 23333, + "vllm": 8000, + "trt": 8000, + }.get(args.backend, 30000) + + api_url = ( + f"{args.base_url}/v1/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/completions" + ) + model_url = ( + f"{args.base_url}/v1/models" + if args.base_url + else f"http://{args.host}:{args.port}/v1/models" + ) + + if args.backend == "trt": + api_url = ( + f"{args.base_url}/v2/models/ensemble/generate_stream" + if args.base_url + else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream" + ) + if args.model is None: + print("Please provide a model using `--model` when using `trt` backend.") + sys.exit(1) + + if args.model is None: + try: + response = requests.get(model_url) + model_list = response.json().get("data", []) + args.model = model_list[0]["id"] if model_list else None + except Exception as e: + print(f"Failed to fetch model from {model_url}. Error: {e}") + print( + "Please specify the correct host and port using `--host` and `--port`." + ) + sys.exit(1) + + if args.model is None: + print("No model specified or found. Please provide a model using `--model`.") + sys.exit(1) + + if not check_chat_template(args.model): + print( + "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n" + "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n" + ) + + print(f"{args}\n") + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + tokenizer = get_tokenizer(tokenizer_id) + + if args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, + ) + elif args.dataset_name == "random": + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + + if args.multi: + request_rates = parse_request_rate_range(args.request_rate_range) + + for rate in request_rates: + asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + else: + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + + +# to avoid relying on SGLang's 
components +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Fail to set RLIMIT_NOFILE: {e}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", + help="Must specify a backend, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=["sharegpt", "random"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", type=str, default="", help="Path to the dataset." + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer. If not set, using the model conf.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length from the ShareGPT dataset.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range of sampled ratio of input/output length, " + "used only for random dataset.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.", + ) + parser.add_argument("--seed", type=int, default=0, help="Default is 0.") + parser.add_argument( + "--multi", + action="store_true", + help="Use request rate range rather than single value.", + ) + parser.add_argument( + "--request-rate-range", + type=str, + default="2,34,2", + help="Range of request rates in the format start,stop,step. Default is 2,34,2. 
It also supports a list of request rates, requiring the parameters to not equal three.", + ) + parser.add_argument("--output-file", type=str, help="Output JSONL file name.") + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--disable-stream", + action="store_true", + help="Disable streaming mode.", + ) + parser.add_argument( + "--disable-ignore-eos", + action="store_true", + help="Disable ignoring EOS.", + ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help="Append given JSON object to the request payload. You can use this to specify" + "additional generate params like sampling params.", + ) + args = parser.parse_args() + run_benchmark(args) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index b02ce9f81e..0689cc11d4 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -25,7 +25,7 @@ def __init__(self): self.layer_sync_threshold = 8192 # Runtime constants: others - self.num_continue_decode_steps = 10 + self.num_continue_decode_steps = 1 self.retract_decode_steps = 20 self.flashinfer_workspace_size = 192 * 1024 * 1024 diff --git a/python/sglang/iter_search_best_qps.py b/python/sglang/iter_search_best_qps.py new file mode 100644 index 0000000000..d978cd21ba --- /dev/null +++ b/python/sglang/iter_search_best_qps.py @@ -0,0 +1,1098 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py + +""" +Benchmark online serving. + +Usage: +python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi +""" + +import argparse +import asyncio +import json +import os +import random +import resource +import sys +import time +import traceback +import warnings +from argparse import ArgumentParser +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union + +import aiohttp +import numpy as np +import requests +from tqdm.asyncio import tqdm +from transformers import ( + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +global args + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + extra_request_body: Dict[str, Any] + + +@dataclass +class RequestFuncOutput: + generated_text: str = "" + success: bool = False + latency: float = 0.0 + ttft: float = 0.0 # Time to first token + itl: List[float] = field(default_factory=list) # List of inter-token latencies + prompt_len: int = 0 + error: str = "" + output_len: int = 0 + + +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix) :] if text.startswith(prefix) else text + + +# trt llm not support ignore_eos +# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 +async def async_request_trt_llm( + 
request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.000001, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + "min_length": request_func_input.output_len, + "end_id": 1048576, + **request_func_input.extra_request_body, + } + if args.disable_ignore_eos: + del payload["min_length"] + del payload["end_id"] + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = True + output.output_len = request_func_input.output_len + + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# set ignore_eos True by default +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( + "completions" + ), "OpenAI Completions API URL must end with 'completions'." 
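# Editorial note (not part of the patch): a minimal sketch of how the streaming loop
# below turns per-chunk timestamps into the TTFT and inter-token latencies (ITL) that
# RequestFuncOutput records. `arrival_times` (one time.perf_counter() reading per
# streamed token) is a hypothetical input used only for illustration.
from typing import List, Tuple

def ttft_and_itl(start: float, arrival_times: List[float]) -> Tuple[float, List[float]]:
    # TTFT is the delay from sending the request to receiving the first token;
    # every later gap between consecutive chunks is one inter-token latency.
    ttft = arrival_times[0] - start
    itl = [b - a for a, b in zip(arrival_times, arrival_times[1:])]
    return ttft, itl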
+ + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": 1, + "max_tokens": request_func_input.output_len, + "stream": not args.disable_stream, + "ignore_eos": not args.disable_ignore_eos, + **request_func_input.extra_request_body, + } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true": + import huggingface_hub.constants + from modelscope import snapshot_download + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) + + return model_path + return pretrained_model_name_or_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) + + +ASYNC_REQUEST_FUNCS = { + "sglang": async_request_openai_completions, + "vllm": async_request_openai_completions, + "lmdeploy": async_request_openai_completions, + "trt": async_request_trt_llm, +} + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + total_output_retokenized: int + request_throughput: float + input_throughput: float + output_throughput: float + output_throughput_retokenized: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + p99_itl_ms: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + 
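# Editorial note (not part of the patch): a minimal worked example of the per-request
# "time per output token" (TPOT) that calculate_metrics() below derives for each
# successful response: the time spent after the first token, averaged over the
# remaining output tokens. The function name and sample numbers are illustrative.
def time_per_output_token(latency_s: float, ttft_s: float, output_len: int) -> float:
    # Matches the formula used in the patch: (latency - ttft) / (output_len - 1),
    # which is only defined when more than one token was generated.
    if output_len <= 1:
        raise ValueError("TPOT requires at least two output tokens")
    return (latency_s - ttft_s) / (output_len - 1)

# Example: a request that finishes in 2.0 s with a 0.4 s TTFT and 17 output tokens
# gives (2.0 - 0.4) / 16 = 0.1 s, i.e. a 100 ms mean TPOT.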
+ +default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json" + + +def download_sharegpt_dataset(path): + url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" + + print(f"Downloading dataset from {url}") + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + with open(path, "wb") as f, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(block_size): + size = f.write(data) + progress_bar.update(size) + + print(f"Dataset downloaded and saved to {path}") + except requests.RequestException as e: + raise Exception(f"Failed to download dataset: {e}") + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset + + +def sample_random_requests( + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, + dataset_path: str, +) -> List[Tuple[str, int, int]]: + + input_lens = np.random.randint( + max(int(input_len * range_ratio), 1), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + + if True: + # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile( + default_sharegpt_path + ): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + input_requests: List[Tuple[str, int, int]] = [] + for i in range(num_prompts): + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + prompt_len = len(prompt_token_ids) + + if prompt_len > input_lens[i]: + input_ids = prompt_token_ids[: input_lens[i]] + else: + ratio = (input_lens[i] + prompt_len - 1) // prompt_len + input_ids = (prompt_token_ids * ratio)[: input_lens[i]] + prompt = tokenizer.decode(input_ids) + input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + else: + # Sample token ids from random integers. This can cause some NaN issues. + offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) + input_requests = [] + for i in range(num_prompts): + prompt = tokenizer.decode( + [ + (offsets[i] + i + j) % tokenizer.vocab_size + for j in range(input_lens[i]) + ] + ) + input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + + print(f"#Input tokens: {np.sum(input_lens)}") + print(f"#Output tokens: {np.sum(output_lens)}") + return input_requests + + +async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
+ await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + backend: str, +) -> Tuple[BenchmarkMetrics, List[int]]: + output_lens: List[int] = [] + retokenized_output_lens: List[int] = [] + total_input = 0 + completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] + e2e_latencies: List[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_len + output_lens.append(output_len) + retokenized_output_len = len( + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) + retokenized_output_lens.append(retokenized_output_len) + total_input += input_requests[i][1] + if output_len > 1: + tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + + e2e_latencies.append(outputs[i].latency) + + completed += 1 + else: + output_lens.append(0) + retokenized_output_lens.append(0) + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2, + ) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(output_lens), + total_output_retokenized=sum(retokenized_output_lens), + request_throughput=completed / dur_s, + input_throughput=total_input / dur_s, + output_throughput=sum(output_lens) / dur_s, + output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + std_ttft_ms=np.std(ttfts or 0) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, + median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + ) + + return metrics, output_lens + + +async def benchmark( + backend: str, + api_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: List[Tuple[str, int, int]], + request_rate: float, + disable_tqdm: bool, + enable_multi: bool, + extra_request_body: Dict[str, Any], +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + + words = test_prompt.split() + + # 使用random.shuffle打乱单词列表 + random.shuffle(words) + + # 将打乱后的单词列表合并回文本 + test_prompt = ' '.join(words) + + + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + extra_request_body=extra_request_body, + ) + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + benchmark_start_time = time.perf_counter() + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate): + prompt, prompt_len, output_len = request + request_func_input = RequestFuncInput( + model=model_id, + 
prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + extra_request_body=extra_request_body, + ) + tasks.append( + asyncio.create_task( + request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + backend=backend, + ) + + print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Backend:", backend)) + print("{:<40} {:<10}".format("Traffic request rate:", request_rate)) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10}".format( + "Total generated tokens (retokenized):", metrics.total_output_retokenized + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Input token throughput (tok/s):", metrics.input_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) + print( + "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) + ) + print( + "{:<40} {:<10.2f}".format( + "Median E2E Latency (ms):", metrics.median_e2e_latency_ms + ) + ) + print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) + print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + print( + "{s:{c}^{n}}".format(s="Time per Output Token (excl. 
1st token)", n=50, c="-") + ) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("=" * 50) + + if ( + metrics.median_ttft_ms is not None + and metrics.mean_itl_ms is not None + and metrics.output_throughput is not None + ): + result = { + "backend": args.backend, + "dataset_name": args.dataset_name, + "request_rate": request_rate, + "total_input": metrics.total_input, + "total_output": metrics.total_output, + "total_output_retokenized": metrics.total_output_retokenized, + "mean_e2e_latency": metrics.mean_e2e_latency_ms, + "median_e2e_latency": metrics.median_e2e_latency_ms, + "median_ttft": metrics.median_ttft_ms, + "median_itl": metrics.median_itl_ms, + "output_token_throughput": metrics.output_throughput, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + "benchmark_duration": benchmark_duration, + } + else: + print(f"Error running benchmark for request rate: {request_rate}") + print("-" * 30) + + # Determine output file name + if args.output_file: + output_file_name = args.output_file + else: + now = datetime.now().strftime("%m%d") + if args.dataset_name == "random": + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" + else: + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl" + + # Append results to a JSONL file + with open(output_file_name, "a") as file: + file.write(json.dumps(result) + "\n") + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + "mean_ttft_ms": metrics.mean_ttft_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + } + return result + + +def parse_request_rate_range(request_rate_range): + if len(request_rate_range.split(",")) == 3: + start, stop, step = map(int, request_rate_range.split(",")) + return list(range(start, stop, step)) + else: + 
return list(map(int, request_rate_range.split(","))) + + +def check_chat_template(model_path): + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + return "chat_template" in tokenizer.init_kwargs + except Exception as e: + print(f"Fail to load tokenizer config with error={e}") + return False + + +def run_benchmark(args_: argparse.Namespace): + global args + args = args_ + + set_ulimit() + random.seed(args.seed) + np.random.seed(args.seed) + + extra_request_body = {} + if args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + + if args.port is None: + args.port = { + "sglang": 30000, + "lmdeploy": 23333, + "vllm": 8000, + "trt": 8000, + }.get(args.backend, 30000) + + api_url = ( + f"{args.base_url}/v1/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/completions" + ) + model_url = ( + f"{args.base_url}/v1/models" + if args.base_url + else f"http://{args.host}:{args.port}/v1/models" + ) + + if args.backend == "trt": + api_url = ( + f"{args.base_url}/v2/models/ensemble/generate_stream" + if args.base_url + else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream" + ) + if args.model is None: + print("Please provide a model using `--model` when using `trt` backend.") + sys.exit(1) + + if args.model is None: + try: + response = requests.get(model_url) + model_list = response.json().get("data", []) + args.model = model_list[0]["id"] if model_list else None + except Exception as e: + print(f"Failed to fetch model from {model_url}. Error: {e}") + print( + "Please specify the correct host and port using `--host` and `--port`." + ) + sys.exit(1) + + if args.model is None: + print("No model specified or found. Please provide a model using `--model`.") + sys.exit(1) + + if not check_chat_template(args.model): + print( + "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n" + "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n" + ) + + print(f"{args}\n") + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + tokenizer = get_tokenizer(tokenizer_id) + + if args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, + ) + elif args.dataset_name == "random": + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + + if args.multi: + request_rates = parse_request_rate_range(args.request_rate_range) + + for rate in request_rates: + asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + else: + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + + +# to avoid relying on SGLang's 
components
+def set_ulimit(target_soft_limit=65535):
+    resource_type = resource.RLIMIT_NOFILE
+    current_soft, current_hard = resource.getrlimit(resource_type)
+
+    if current_soft < target_soft_limit:
+        try:
+            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+        except ValueError as e:
+            print(f"Fail to set RLIMIT_NOFILE: {e}")
+
+
+def update_qps_json_file(data_list, file_path):
+    # Load any existing results
+    try:
+        with open(file_path, 'r', encoding='utf-8') as json_file:
+            qps_result_list = json.load(json_file)
+    except FileNotFoundError:
+        qps_result_list = []
+
+    # Collect the request_rate values that are already recorded
+    existing_rates = {item['request_rate'] for item in qps_result_list}
+
+    # Append new entries whose request_rate is not recorded yet
+    new_entries_added = False
+    for new_data in data_list:
+        if new_data['request_rate'] not in existing_rates:
+            qps_result_list.append(new_data)
+            existing_rates.add(new_data['request_rate'])
+            new_entries_added = True
+        else:
+            print(f"Entry for request_rate={new_data['request_rate']} already exists; skipping it.")
+
+    # If anything new was added, sort the list and write it back
+    if new_entries_added:
+        # Sort by request_rate
+        qps_result_list.sort(key=lambda x: x['request_rate'])
+
+        # Write the updated results
+        with open(file_path, 'w', encoding='utf-8') as json_file:
+            json.dump(qps_result_list, json_file, ensure_ascii=False, indent=4)
+        print("New entries added and sorted.")
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(description="Benchmark the online serving throughput.")
+    parser.add_argument(
+        "--backend",
+        type=str,
+        choices=list(ASYNC_REQUEST_FUNCS.keys()),
+        default="sglang",
+        help="Must specify a backend, depending on the LLM Inference Engine.",
+    )
+    parser.add_argument(
+        "--base-url",
+        type=str,
+        default=None,
+        help="Server or API base url if not using http host and port.",
+    )
+    parser.add_argument(
+        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+    )
+    parser.add_argument(
+        "--dataset-name",
+        type=str,
+        default="sharegpt",
+        choices=["sharegpt", "random"],
+        help="Name of the dataset to benchmark on.",
+    )
+    parser.add_argument(
+        "--dataset-path", type=str, default="", help="Path to the dataset."
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+    )
+    parser.add_argument(
+        "--tokenizer",
+        type=str,
+        help="Name or path of the tokenizer. If not set, using the model conf.",
+    )
+    parser.add_argument(
+        "--num-prompts",
+        type=int,
+        default=1000,
+        help="Number of prompts to process. Default is 1000.",
+    )
+    parser.add_argument(
+        "--sharegpt-output-len",
+        type=int,
+        default=None,
+        help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
+    )
+    parser.add_argument(
+        "--random-input-len",
+        type=int,
+        default=1024,
+        help="Number of input tokens per request, used only for random dataset.",
+    )
+    parser.add_argument(
+        "--random-output-len",
+        type=int,
+        default=128,
+        help="Number of output tokens per request, used only for random dataset.",
+    )
+    parser.add_argument(
+        "--random-range-ratio",
+        type=float,
+        default=0.0,
+        help="Range of sampled ratio of input/output length, "
+        "used only for random dataset.",
+    )
+    parser.add_argument(
+        "--request-rate",
+        type=float,
+        default=float("inf"),
+        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
+        "Otherwise, we use a Poisson process to synthesize the request arrival times.",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
+    parser.add_argument(
+        "--multi",
+        action="store_true",
+        help="Use request rate range rather than single value.",
+    )
+    parser.add_argument(
+        "--request-rate-range",
+        type=str,
+        default="2,34,2",
+        help="Range of request rates in the format start,stop,step. Default is 2,34,2. A comma-separated list of rates is also accepted when the number of values is not exactly three.",
+    )
+    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+    parser.add_argument(
+        "--disable-tqdm",
+        action="store_true",
+        help="Specify to disable tqdm progress bar.",
+    )
+    parser.add_argument(
+        "--disable-stream",
+        action="store_true",
+        help="Disable streaming mode.",
+    )
+    parser.add_argument(
+        "--disable-ignore-eos",
+        action="store_true",
+        help="Disable ignoring EOS.",
+    )
+    parser.add_argument(
+        "--extra-request-body",
+        metavar='{"key1": "value1", "key2": "value2"}',
+        type=str,
+        help="Append given JSON object to the request payload. You can use this to specify "
+        "additional generate params like sampling params.",
+    )
+    parser.add_argument(
+        "--request-rate-list",
+        type=json.loads,
+        default="[1, 5, 10, 20, 35, 50, 70, 100]",
+        help="List of request rates (a JSON array) to sweep over.",
+    )
+
+    args = parser.parse_args()
+    qps_result_dict = []
+
+
+    qps_list = args.request_rate_list
+    for request_rate in qps_list:
+        args.request_rate = request_rate
+        result = run_benchmark(args)
+        qps_result_dict.append({
+            "request_rate": request_rate,
+            "mean_ttft_ms": float(result['mean_ttft_ms']),
+            "request_throughput": float(result["request_throughput"])
+        })
+
+        # Write the accumulated results to a JSON file
+        update_qps_json_file(qps_result_dict, 'qps_results.json')
+        # with open('qps_results.json', 'w', encoding='utf-8') as json_file:
+        #     json.dump(qps_result_dict, json_file, ensure_ascii=False, indent=4)
+
+        # print("QPS result list saved to 'qps_results.json'")
+        time.sleep(10)
+
+    import matplotlib.pyplot as plt
+
+    request_rates = [result['request_rate'] for result in qps_result_dict]
+    mean_ttft_ms = [result['mean_ttft_ms'] for result in qps_result_dict]
+    request_throughput = [result['request_throughput'] for result in qps_result_dict]
+
+    fig, ax1 = plt.subplots()
+
+    color = 'tab:red'
+    ax1.set_xlabel('Request Rate (QPS)')
+    ax1.set_ylabel('Mean Time to First Token (ms)', color=color)
+    ax1.plot(request_rates, mean_ttft_ms, marker='o', color=color)
+    ax1.tick_params(axis='y', labelcolor=color)
+
+    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
+
+    color = 'tab:blue'
+    ax2.set_ylabel('Request Throughput', color=color)  # we already handled the x-label with ax1
+    ax2.plot(request_rates, request_throughput, marker='x', color=color)
+    ax2.tick_params(axis='y', labelcolor=color)
+
+    fig.tight_layout()  # otherwise the right y-label is slightly clipped
+    plt.title('Benchmark Results')
+    plt.savefig("./QPS.png")
+    plt.close()
+
+
+
+
diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 91dc0dc4e9..0137b579c9 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -10,5 +10,6 @@
     ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)
+
     launch_server(server_args)
diff --git a/python/sglang/search_best_qps.py b/python/sglang/search_best_qps.py
new file mode 100644
index 0000000000..400254837b
--- /dev/null
+++ b/python/sglang/search_best_qps.py
@@ -0,0
+1,1068 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py + +""" +Benchmark online serving. + +Usage: +python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi +""" + +import argparse +import asyncio +import json +import os +import random +import resource +import sys +import time +import traceback +import warnings +from argparse import ArgumentParser +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union + +import aiohttp +import numpy as np +import requests +from tqdm.asyncio import tqdm +from transformers import ( + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +global args + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + extra_request_body: Dict[str, Any] + + +@dataclass +class RequestFuncOutput: + generated_text: str = "" + success: bool = False + latency: float = 0.0 + ttft: float = 0.0 # Time to first token + itl: List[float] = field(default_factory=list) # List of inter-token latencies + prompt_len: int = 0 + error: str = "" + output_len: int = 0 + + +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix) :] if text.startswith(prefix) else text + + +# trt llm not support ignore_eos +# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 +async def async_request_trt_llm( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.000001, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + "min_length": request_func_input.output_len, + "end_id": 1048576, + **request_func_input.extra_request_body, + } + if args.disable_ignore_eos: + del payload["min_length"] + del payload["end_id"] + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = 
True + output.output_len = request_func_input.output_len + + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# set ignore_eos True by default +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( + "completions" + ), "OpenAI Completions API URL must end with 'completions'." + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": 1, + "max_tokens": request_func_input.output_len, + "stream": not args.disable_stream, + "ignore_eos": not args.disable_ignore_eos, + **request_func_input.extra_request_body, + } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true": + import huggingface_hub.constants + from modelscope import snapshot_download + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) + + return model_path + return pretrained_model_name_or_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) + + +ASYNC_REQUEST_FUNCS = { + "sglang": async_request_openai_completions, + "vllm": async_request_openai_completions, + "lmdeploy": 
async_request_openai_completions, + "trt": async_request_trt_llm, +} + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + total_output_retokenized: int + request_throughput: float + input_throughput: float + output_throughput: float + output_throughput_retokenized: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + p99_itl_ms: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + + +default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json" + + +def download_sharegpt_dataset(path): + url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" + + print(f"Downloading dataset from {url}") + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + with open(path, "wb") as f, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(block_size): + size = f.write(data) + progress_bar.update(size) + + print(f"Dataset downloaded and saved to {path}") + except requests.RequestException as e: + raise Exception(f"Failed to download dataset: {e}") + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
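+            # (The 1024-prompt-token / 2048-total-token caps are carried over from
+            # the upstream vLLM serving benchmark this script is adapted from.)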
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset + + +def sample_random_requests( + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, + dataset_path: str, +) -> List[Tuple[str, int, int]]: + + input_lens = np.random.randint( + max(int(input_len * range_ratio), 1), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + + if True: + # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile( + default_sharegpt_path + ): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + input_requests: List[Tuple[str, int, int]] = [] + for i in range(num_prompts): + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + prompt_len = len(prompt_token_ids) + + if prompt_len > input_lens[i]: + input_ids = prompt_token_ids[: input_lens[i]] + else: + ratio = (input_lens[i] + prompt_len - 1) // prompt_len + input_ids = (prompt_token_ids * ratio)[: input_lens[i]] + prompt = tokenizer.decode(input_ids) + input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + else: + # Sample token ids from random integers. This can cause some NaN issues. + offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) + input_requests = [] + for i in range(num_prompts): + prompt = tokenizer.decode( + [ + (offsets[i] + i + j) % tokenizer.vocab_size + for j in range(input_lens[i]) + ] + ) + input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + + print(f"#Input tokens: {np.sum(input_lens)}") + print(f"#Output tokens: {np.sum(output_lens)}") + return input_requests + + +async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
+ await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + backend: str, +) -> Tuple[BenchmarkMetrics, List[int]]: + output_lens: List[int] = [] + retokenized_output_lens: List[int] = [] + total_input = 0 + completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] + e2e_latencies: List[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_len + output_lens.append(output_len) + retokenized_output_len = len( + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) + retokenized_output_lens.append(retokenized_output_len) + total_input += input_requests[i][1] + if output_len > 1: + tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + + e2e_latencies.append(outputs[i].latency) + + completed += 1 + else: + output_lens.append(0) + retokenized_output_lens.append(0) + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2, + ) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(output_lens), + total_output_retokenized=sum(retokenized_output_lens), + request_throughput=completed / dur_s, + input_throughput=total_input / dur_s, + output_throughput=sum(output_lens) / dur_s, + output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + std_ttft_ms=np.std(ttfts or 0) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, + median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + ) + + return metrics, output_lens + + +async def benchmark( + backend: str, + api_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: List[Tuple[str, int, int]], + request_rate: float, + disable_tqdm: bool, + enable_multi: bool, + extra_request_body: Dict[str, Any], +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + + words = test_prompt.split() + + # 使用random.shuffle打乱单词列表 + random.shuffle(words) + + # 将打乱后的单词列表合并回文本 + test_prompt = ' '.join(words) + + + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + extra_request_body=extra_request_body, + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}" + ) + else: + print("Initial test run completed. 
Starting main benchmark run...") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + benchmark_start_time = time.perf_counter() + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate): + prompt, prompt_len, output_len = request + request_func_input = RequestFuncInput( + model=model_id, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + extra_request_body=extra_request_body, + ) + tasks.append( + asyncio.create_task( + request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + backend=backend, + ) + + print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Backend:", backend)) + print("{:<40} {:<10}".format("Traffic request rate:", request_rate)) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10}".format( + "Total generated tokens (retokenized):", metrics.total_output_retokenized + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Input token throughput (tok/s):", metrics.input_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) + print( + "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) + ) + print( + "{:<40} {:<10.2f}".format( + "Median E2E Latency (ms):", metrics.median_e2e_latency_ms + ) + ) + print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) + print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + print( + "{s:{c}^{n}}".format(s="Time per Output Token (excl. 
1st token)", n=50, c="-") + ) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("=" * 50) + + if ( + metrics.median_ttft_ms is not None + and metrics.mean_itl_ms is not None + and metrics.output_throughput is not None + ): + result = { + "backend": args.backend, + "dataset_name": args.dataset_name, + "request_rate": request_rate, + "total_input": metrics.total_input, + "total_output": metrics.total_output, + "total_output_retokenized": metrics.total_output_retokenized, + "mean_e2e_latency": metrics.mean_e2e_latency_ms, + "median_e2e_latency": metrics.median_e2e_latency_ms, + "median_ttft": metrics.median_ttft_ms, + "median_itl": metrics.median_itl_ms, + "output_token_throughput": metrics.output_throughput, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + "benchmark_duration": benchmark_duration, + } + else: + print(f"Error running benchmark for request rate: {request_rate}") + print("-" * 30) + + # Determine output file name + if args.output_file: + output_file_name = args.output_file + else: + now = datetime.now().strftime("%m%d") + if args.dataset_name == "random": + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" + else: + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl" + + # Append results to a JSONL file + with open(output_file_name, "a") as file: + file.write(json.dumps(result) + "\n") + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + "mean_ttft_ms": metrics.mean_ttft_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + } + return result + + +def parse_request_rate_range(request_rate_range): + if len(request_rate_range.split(",")) == 3: + start, stop, step = map(int, request_rate_range.split(",")) + return list(range(start, stop, step)) + else: + 
return list(map(int, request_rate_range.split(","))) + + +def check_chat_template(model_path): + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + return "chat_template" in tokenizer.init_kwargs + except Exception as e: + print(f"Fail to load tokenizer config with error={e}") + return False + + +def run_benchmark(args_: argparse.Namespace): + global args + args = args_ + + set_ulimit() + random.seed(args.seed) + np.random.seed(args.seed) + + extra_request_body = {} + if args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + + if args.port is None: + args.port = { + "sglang": 30000, + "lmdeploy": 23333, + "vllm": 8000, + "trt": 8000, + }.get(args.backend, 30000) + + api_url = ( + f"{args.base_url}/v1/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/completions" + ) + model_url = ( + f"{args.base_url}/v1/models" + if args.base_url + else f"http://{args.host}:{args.port}/v1/models" + ) + + if args.backend == "trt": + api_url = ( + f"{args.base_url}/v2/models/ensemble/generate_stream" + if args.base_url + else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream" + ) + if args.model is None: + print("Please provide a model using `--model` when using `trt` backend.") + sys.exit(1) + + if args.model is None: + try: + response = requests.get(model_url) + model_list = response.json().get("data", []) + args.model = model_list[0]["id"] if model_list else None + except Exception as e: + print(f"Failed to fetch model from {model_url}. Error: {e}") + print( + "Please specify the correct host and port using `--host` and `--port`." + ) + sys.exit(1) + + if args.model is None: + print("No model specified or found. Please provide a model using `--model`.") + sys.exit(1) + + if not check_chat_template(args.model): + print( + "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n" + "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n" + ) + + print(f"{args}\n") + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + tokenizer = get_tokenizer(tokenizer_id) + + if args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, + ) + elif args.dataset_name == "random": + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + + if args.multi: + request_rates = parse_request_rate_range(args.request_rate_range) + + for rate in request_rates: + asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + else: + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + + +# to avoid relying on SGLang's 
components +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Fail to set RLIMIT_NOFILE: {e}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", + help="Must specify a backend, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=["sharegpt", "random"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", type=str, default="", help="Path to the dataset." + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer. If not set, using the model conf.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length from the ShareGPT dataset.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range of sampled ratio of input/output length, " + "used only for random dataset.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.", + ) + parser.add_argument("--seed", type=int, default=0, help="Default is 0.") + parser.add_argument( + "--multi", + action="store_true", + help="Use request rate range rather than single value.", + ) + parser.add_argument( + "--request-rate-range", + type=str, + default="2,34,2", + help="Range of request rates in the format start,stop,step. Default is 2,34,2. 
A comma-separated list of rates is also accepted when the number of values is not exactly three.",
+    )
+    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+    parser.add_argument(
+        "--disable-tqdm",
+        action="store_true",
+        help="Specify to disable tqdm progress bar.",
+    )
+    parser.add_argument(
+        "--disable-stream",
+        action="store_true",
+        help="Disable streaming mode.",
+    )
+    parser.add_argument(
+        "--disable-ignore-eos",
+        action="store_true",
+        help="Disable ignoring EOS.",
+    )
+    parser.add_argument(
+        "--extra-request-body",
+        metavar='{"key1": "value1", "key2": "value2"}',
+        type=str,
+        help="Append given JSON object to the request payload. You can use this to specify "
+        "additional generate params like sampling params.",
+    )
+
+
+    parser.add_argument(
+        "--search-qps-start",
+        type=float,
+        default=0.0,
+        help="Lower bound of the request rate for the binary search; the effective rate is multiplied by dp_size.",
+    )
+
+    parser.add_argument(
+        "--search-qps-end",
+        type=int,
+        default=0,
+        help="Upper bound of the request rate for the binary search; the effective rate is multiplied by dp_size.",
+    )
+
+    parser.add_argument(
+        "--search-qps-thread",
+        type=int,
+        default=0,
+        help="Target mean TTFT in ms; the search stops once the measured TTFT is within 500 ms of this value.",
+    )
+
+
+    args = parser.parse_args()
+
+    mean_ttft_ms = 0.0
+    search_qps_thread = args.search_qps_thread
+    search_qps_start = args.search_qps_start
+    search_qps_end = args.search_qps_end
+
+
+    qps_result_dict = []
+    while abs(mean_ttft_ms - search_qps_thread) > 500 and search_qps_start < search_qps_end:  # keep searching until the TTFT is within 0.5 s of the target
+        mid_qps = (search_qps_start + search_qps_end) / 2
+        args.request_rate = mid_qps
+        result = run_benchmark(args)
+        mean_ttft_ms = float(result['mean_ttft_ms'])
+
+        qps_result_dict.append({
+            "request_rate": mid_qps,
+            "mean_ttft_ms": mean_ttft_ms
+        })
+
+        if mean_ttft_ms < search_qps_thread:
+            # A higher request rate raises mean_ttft_ms, so the rate can still be pushed up
+            search_qps_start = mid_qps
+        else:
+            search_qps_end = mid_qps
+
+
+    print("request_rate\t\tmean_ttft_ms")
+    for item in qps_result_dict:
+        print(f"{item['request_rate']}\t\t{item['mean_ttft_ms']}")
\ No newline at end of file
diff --git a/python/sglang/srt/managers/controller_flex.py b/python/sglang/srt/managers/controller_flex.py
index 14da2449fc..ba9a6ccaa4 100644
--- a/python/sglang/srt/managers/controller_flex.py
+++ b/python/sglang/srt/managers/controller_flex.py
@@ -17,17 +17,48 @@
 A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers. 
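+
+The flex controller picks a worker for each request according to LoadBalanceMethod:
+round robin, shortest queue, resources-aware, power-of-two choice, or the
+pre-radix policy, which routes a request to the data parallel worker whose
+radix cache shares the longest prefix with the request's input_ids.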
""" - import dataclasses import logging import multiprocessing import multiprocessing.shared_memory import os +import random +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from enum import Enum, auto import numpy as np +import torch import zmq + +def _key_match(key0, key1): + i = 0 + for k0, k1 in zip(key0, key1): + if k0 != k1: + break + i += 1 + return i + + +def get_match_len(node, key, match_length: int) -> int: + if len(key) == 0: + return match_length + + if key[0] in node.children.keys(): + child = node.children[key[0]] + prefix_len = _key_match(child.key, key) + match_length += prefix_len + if prefix_len < len(child.key): + return match_length + else: + return get_match_len(child, key[prefix_len:], match_length) + else: + return match_length + + +import threading +import time + from sglang.srt.managers.controller_single import ( start_controller_process as start_controller_process_single, ) @@ -50,6 +81,8 @@ class LoadBalanceMethod(Enum): ROUND_ROBIN = auto() SHORTEST_QUEUE = auto() RESOURCES_AWARE = auto() + POWER_OF_2_CHOICE = auto() + PRE_RADIX = auto() @classmethod def from_str(cls, method: str): @@ -94,21 +127,40 @@ def __init__( self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}") + self.recv_from_tree_cache = context.socket(zmq.PULL) + self.recv_from_tree_cache.setsockopt(zmq.RCVHWM, 1000) + self.recv_from_tree_cache.bind(f"tcp://127.0.0.1:41935") + + self.pre_radix = server_args.load_balance_method == "pre_radix" + self.dp_size = server_args.dp_size + # Dispatch method self.round_robin_counter = 0 dispatch_lookup = { LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, LoadBalanceMethod.RESOURCES_AWARE: self.resources_aware_scheduler, + LoadBalanceMethod.POWER_OF_2_CHOICE: self.power_of_2_choice, + LoadBalanceMethod.PRE_RADIX: self.pre_radix_scheduler, } self.dispatching = dispatch_lookup[self.load_balance_method] + self.newest_tree_cache = {} + # Start data parallel workers self.workers = [] self.controller_info = ControllerInfo(server_args, model_overide_args) for i in range(server_args.dp_size): self.start_dp_worker(i) + self.scheduler_time = 0 + + self.cnt = 0 + + if self.pre_radix: + self.recv_tree_cache_lock = threading.Lock() + threading.Thread(target=self.loop_for_recv_tree_cache).start() + def start_dp_worker(self, dp_worker_id: int): tp_size = self.server_args.tp_size @@ -146,17 +198,218 @@ def start_dp_worker(self, dp_worker_id: int): ) ) + def compute_prefix_length(self, gpu_id, radix_cache, input_ids): + return gpu_id, get_match_len(radix_cache.root_node, input_ids, 0) + + def pre_radix_scheduler(self, input_requests): + if len(input_requests) == 0: + return + + # available_mem = [k.value for k in self.controller_info.available_kv_cache] + # num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs] + # num_reqs_running = [k.value for k in self.controller_info.running_reqs] + + # all_waitting = False + # if min(num_reqs_waiting) > 0: + # # 最小值都大于0,全部waiting + # all_waitting = True + # else: + # # 最小值都是0, 则全部waiting + # all_waitting = False + # # 选出不waiting + # no_waiting = [1 if waiting == 0 else 0 for waiting in num_reqs_waiting] + + # num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs] + + for r in input_requests: + prefix_lens = [0] * self.dp_size + + with self.recv_tree_cache_lock: + for gpu_id, radix_cache in self.newest_tree_cache.items(): + # t_1 = 
time.time() + pre_len = get_match_len(radix_cache.root_node, r.input_ids, 0) + # t_2 = time.time() + prefix_lens[gpu_id] = pre_len + + # with ThreadPoolExecutor() as executor: + # futures = [] + # for gpu_id, radix_cache in self.newest_tree_cache.items(): + # future = executor.submit( + # self.compute_prefix_length, + # gpu_id, + # radix_cache, + # r.input_ids, + # ) + # futures.append(future) + + # for future in futures: + # gpu_id, pre_len = future.result() + # prefix_lens[gpu_id] = pre_len + + # t4 = time.time() + # with open("match.log", "a+") as f: + # f.write(f"[rid={r.rid[:5]}]{prefix_lens}\n") + + # t7 = time.time() + max_len = max(prefix_lens) + max_len_indices = [i for i, x in enumerate(prefix_lens) if x == max_len] + # t8 = time.time() + + # logger.info(f"find max idx = {t8 - t7}") + + if len(max_len_indices) == 1: + # t9 = time.time() + selected_worker_index = max_len_indices[0] + self.workers[selected_worker_index].queue.put(r) + # t10 = time.time() + # logger.info(f"len one = {t10 - t9}") + # t5 = time.time() + # logger.info(f"if time = {t5 - t4}") + else: + self.resources_aware_scheduler([r]) + # t11 = time.time() + # if all_waitting: + # # 全部waiting,选最小的 + + # ratio = [ + # run / wait + # for run, wait in zip(num_reqs_running, num_reqs_waiting) + # ] + + # # run越大 认为后续释放的可能性越多,wait越少,说明后续计算能力更强 + # min_value = max(ratio) + # # 找到所有最小值的索引 + # min_indices = [i for i, x in enumerate(ratio) if x == min_value] + # # 从这些索引中随机选择一个 + # index = random.choice(min_indices) + # # 从waitting最小的找到available最大的 + # # index = max(min_indices, key=lambda i: available_mem[i]) + # # index = min(min_indices, key=lambda i: num_reqs_running[i]) + # self.workers[index].queue.put(r) + # num_reqs_waiting[index] += 1 + # available_mem[index] -= len(r.input_ids) + + # else: + # # 选出不waiting的且available mem最大的 + # # no_waiting 和available做乘法,找最大 + + # filter_result = [a * b for a, b in zip(no_waiting, available_mem)] + # index = filter_result.index(max(filter_result)) + # self.workers[index].queue.put(r) + + # # num_reqs_running[index] += 1 + # available_mem[index] -= len(r.input_ids) + # t12 = time.time() + # logger.info(f"len two = {t12 - t11}") + # t5 = time.time() + # logger.info(f"else time = {t5 - t4}") + # t6 = time.time() + # logger.info(f"real dispatch time = {t6 - t8}") + def resources_aware_scheduler(self, input_requests): if len(input_requests) == 0: return - remained_token = [k.value for k in self.controller_info.current_bs] + # remained_token = [k.value for k in self.controller_info.waiting_prefill_compute] + available_mem = [k.value for k in self.controller_info.available_kv_cache] + num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs] + num_reqs_running = [k.value for k in self.controller_info.running_reqs] + # with open('three_list.txt', 'a') as file: # 'a' 模式表示追加到文件末尾 + # print(f"available_mem={available_mem""" """}\nnum_reqs_waiting={num_reqs_waiting}\nnum_reqs_running={num_reqs_running}\n") + + # ava_resource = available_mem.copy() + # =======================method2======================= + # # 认为available + waiting为可用资源 + # for i in range(len(self.workers)): + # q = self.workers[i].queue + # qsize = q.qsize() + # for _ in range(qsize): + # req = q.get() + # ava_resource[i] += len(req.input_ids) + # q.put(req) # 将元素重新放回原队列 + + # # 选择ava最大的调度 + # for r in input_requests: + # index = ava_resource.index(max(ava_resource)) + # self.workers[index].queue.put(r) + # ava_resource[index] -= len(r.input_ids) + + # =======================method2======================= + + # 
=======================method1=======================
+
+        # Check whether every worker already has waiting requests
+        all_waitting = False
+        if min(num_reqs_waiting) > 0:
+            # The smallest waiting count is above zero, so every worker is waiting
+            all_waitting = True
+        else:
+            # At least one worker has an empty waiting queue, so not all are waiting
+            all_waitting = False
+        # Mark the workers whose waiting queue is empty
+        no_waiting = [1 if waiting == 0 else 0 for waiting in num_reqs_waiting]
+        for r in input_requests:
+            # t1 = time.time()
+            if all_waitting:
+                # Every worker is waiting, so pick the least loaded one
+
+                ratio = [
+                    run / wait for run, wait in zip(num_reqs_running, num_reqs_waiting)
+                ]
+
+                # More running requests means memory is likely to be freed soon,
+                # and fewer waiting requests means more spare compute downstream
+                best_value = max(ratio)
+                # Find every index that attains the best ratio
+                best_indices = [i for i, x in enumerate(ratio) if x == best_value]
+                # Randomly pick one of them
+                index = random.choice(best_indices)
+                # Alternative: among these, pick the worker with the most available memory
+                # index = max(best_indices, key=lambda i: available_mem[i])
+                # index = min(best_indices, key=lambda i: num_reqs_running[i])
+                self.workers[index].queue.put(r)
+                num_reqs_waiting[index] += 1
+                available_mem[index] -= len(r.input_ids)
+            else:
+                # Pick a non-waiting worker with the largest available memory:
+                # multiply no_waiting by available_mem and take the maximum
+
+                filter_result = [a * b for a, b in zip(no_waiting, available_mem)]
+                index = filter_result.index(max(filter_result))
+                self.workers[index].queue.put(r)
+
+                # num_reqs_running[index] += 1
+                available_mem[index] -= len(r.input_ids)
+            # t2 = time.time()
+            # logger.info(f"real dispatch time = {t2 - t1}")
+
+        # =======================method1=======================
+
+    def power_of_2_choice(self, input_requests):
+        if len(input_requests) == 0:
+            return
+        num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs]
+        num_reqs_running = [k.value for k in self.controller_info.running_reqs]
+        available_mem = [k.value for k in self.controller_info.available_kv_cache]
+
+        instances_len = len(self.workers)
+
+        # Compare two workers' load metrics and return the preferable one
+        def compare_metrics(ins1, ins2):
+            if num_reqs_waiting[ins1] != num_reqs_waiting[ins2]:
+                return ins1 if num_reqs_waiting[ins1] < num_reqs_waiting[ins2] else ins2
+            if num_reqs_running[ins1] != num_reqs_running[ins2]:
+                return ins1 if num_reqs_running[ins1] < num_reqs_running[ins2] else ins2
+            if available_mem[ins1] != available_mem[ins2]:
+                return ins1 if available_mem[ins1] > available_mem[ins2] else ins2
+            return ins1
+
         for r in input_requests:
-            index = remained_token.index(min(remained_token))
-            self.workers[index].queue.put(r)
-            remained_token[index] += len(r.input_ids)
-            with self.controller_info.lock:
-                for i, v in enumerate(remained_token):
-                    self.controller_info.current_bs[i].value = v
+            # Randomly sample two workers and dispatch to the better one
+            ins1, ins2 = random.sample(range(0, instances_len), 2)
+            ins_end = compare_metrics(ins1, ins2)
+            self.workers[ins_end].queue.put(r)
+            # available_mem[ins_end] -= len(r.input_ids)
+            # num_reqs_running[ins_end] += 1
+            # num_reqs_waiting[ins_end] += 1
 
     def round_robin_scheduler(self, input_requests):
         for r in input_requests:
@@ -174,7 +427,48 @@ def shortest_queue_scheduler(self, input_requests):
 
     def loop_for_forward(self):
         while True:
             recv_reqs = self.recv_requests()
-            self.dispatching(recv_reqs)
+
+            if len(recv_reqs) != 0:
+                # logger.info(f"len requests=[{len(recv_reqs)}]")
+                t1 = time.time()
+
+                # if self.pre_radix:
+                #     self.recv_tree_cache()
+
+                self.dispatching(recv_reqs)
+                t2 = time.time()
+                logger.info(f"scheduler time = {t2 - t1}")
+
+    def loop_for_recv_tree_cache(self):
+        while True:
+            self.recv_tree_cache()
+
+    def recv_tree_cache(self):
+        flag = False
+        while True:
+            try:
+                recv_radix_cache = self.recv_from_tree_cache.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
+            gpu_id = recv_radix_cache.gpu_id
+
+            if (
+                gpu_id not in 
self.newest_tree_cache + or recv_radix_cache.time > self.newest_tree_cache[gpu_id].time + ): + with self.recv_tree_cache_lock: + if gpu_id in self.newest_tree_cache: + del self.newest_tree_cache[gpu_id] + self.newest_tree_cache[gpu_id] = recv_radix_cache + flag = True + + del recv_radix_cache + # 使用日志记录器记录信息 + if flag: + # logger.info(f"latest_cache={len(self.newest_tree_cache)}") + pass + torch.cuda.empty_cache() # 清空未被引用的显存 def recv_requests(self): recv_reqs = [] diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py index 2ada4fa2aa..53b52adcfc 100644 --- a/python/sglang/srt/managers/controller_single.py +++ b/python/sglang/srt/managers/controller_single.py @@ -31,6 +31,7 @@ ) from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import kill_parent_process +from sglang.srt.managers.io_struct import ControllerInfo from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index de9c885247..7f4fe7e65d 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -20,12 +20,15 @@ import multiprocessing import uuid +import multiprocessing from dataclasses import dataclass from multiprocessing import Value from typing import Dict, List, Optional, Union +from multiprocessing import Value import numpy as np import torch +import numpy as np from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling_params import SamplingParams @@ -272,15 +275,18 @@ class DetokenizeReqInput: class ControllerInfo: def __init__(self, server_args, model_overide_args): self.available_kv_cache = [] - self.current_bs = [] + self.waiting_prefill_compute = [] + self.running_reqs = [] + self.waiting_reqs = [] self.swap_in_queue = [] self.lock = multiprocessing.Lock() for i in range(server_args.dp_size): self.available_kv_cache.append(Value("i", 0)) - self.current_bs.append(Value("i", 0)) + self.waiting_prefill_compute.append(Value("i", 0)) + self.running_reqs.append(Value("i", 0)) + self.waiting_reqs.append(Value("i", 0)) self.swap_in_queue.append(multiprocessing.Queue()) self.swap_out_queue = multiprocessing.Queue() - cache_shape = get_cache_info(server_args, model_overide_args) # TODO: Make it editable by user @kavioyu diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a461fa1812..d72604ef26 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -432,6 +432,14 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) + # num = 0 + # for req in reqs: + # num += len(req.origin_input_ids) + # if num != extend_num_tokens: + # print("*" * 100) + # print(extend_num_tokens) + # for req in reqs: + # print(len(req.origin_input_ids)) seq_lens = [] # Allocate memory diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 7d16cb59af..8a75eed871 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -78,14 +78,16 @@ def __init__( server_args: ServerArgs, nccl_port: int, model_overide_args: dict, - controller_info: ControllerInfo, - dp_worker_id: int, + controller_info: Optional[ControllerInfo] = None, 
+ dp_worker_id: Optional[int] = None, ): + suppress_other_loggers() # Copy arguments self.gpu_id = gpu_id self.tp_rank = tp_rank + self.dp_rank = dp_worker_id self.tp_size = server_args.tp_size self.dp_size = server_args.dp_size @@ -120,6 +122,9 @@ def __init__( self.swap_cache = torch.frombuffer( buffer=shm.buf, dtype=self.model_runner.dtype ).reshape(self.controller_info.cache_shape) + self.controller_info.available_kv_cache[self.dp_rank].value = ( + self.model_runner.token_to_kv_pool.available_size() + ) else: self.controller_info = None @@ -181,6 +186,8 @@ def __init__( req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, disable=server_args.disable_radix_cache, + gpu_id=gpu_id, + pre_radix=(server_args.load_balance_method == "pre_radix"), ) self.tree_cache_metrics = {"total": 0, "hit": 0} self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache) @@ -192,6 +199,7 @@ def __init__( self.running_batch: ScheduleBatch = None self.out_pyobjs = [] self.decode_forward_ct = 0 + self.forward_ct = 0 self.stream_interval = server_args.stream_interval self.num_generated_tokens = 0 self.last_stats_tic = time.time() @@ -268,8 +276,9 @@ def forward_step(self): self.forward_decode_batch(self.running_batch) # Print stats - if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: + if self.tp_rank == 0 and self.decode_forward_ct % 10 == 0: self.print_decode_stats() + pass if self.running_batch.is_empty(): self.running_batch = None @@ -297,6 +306,13 @@ def print_decode_stats(self): f"#queue-req: {len(self.waiting_queue)}" ) + with open( + f"token_usage_gpu_{self.gpu_id}.log", mode="a+", encoding="utf-8" + ) as f: + f.write( + f"{self.gpu_id}\t\t{num_used / self.max_total_num_tokens:.2f}\t\t{len(self.waiting_queue)}\n" + ) + def check_memory(self): available_size = ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() @@ -430,9 +446,13 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: adder.log_input_tokens + adder.log_hit_tokens ) / 10**9 self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 - tree_cache_hit_rate = ( - self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] - ) + + try: + tree_cache_hit_rate = ( + self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] + ) + except ZeroDivisionError: + tree_cache_hit_rate = 1.0 logger.info( f"[gpu={self.gpu_id}] Prefill batch. 
" f"#new-seq: {len(can_run_list)}, " @@ -461,11 +481,25 @@ def forward_prefill_batch(self, batch: ScheduleBatch): if self.controller_info: num = 0 - for r in batch.reqs: - num += len(r.origin_input_ids) + for req in batch.reqs: + num += len(req.origin_input_ids) with self.controller_info.lock: - self.controller_info.current_bs[self.dp_rank].value -= num + self.controller_info.waiting_prefill_compute[self.dp_rank].value -= num + self.controller_info.available_kv_cache[self.dp_rank].value = ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + ) + self.controller_info.running_reqs[ + self.dp_rank + ].value = batch.batch_size() + ( + self.running_batch.batch_size() + if self.running_batch is not None + else 0 + ) + self.controller_info.waiting_reqs[self.dp_rank].value = len( + self.waiting_queue + ) if self.model_runner.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: @@ -605,17 +639,26 @@ def forward_decode_batch(self, batch: ScheduleBatch): self.new_token_ratio = new_token_ratio logger.info( - "decode out of memory happened, " + f"[gpu{self.gpu_id}]decode out of memory happened, " f"#retracted_reqs: {len(retracted_reqs)}, " f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" ) + + num = 0 + for req in retracted_reqs: + num += len(req.fill_ids) + + if self.controller_info is not None: + with self.controller_info.lock: + self.controller_info.waiting_prefill_compute[ + self.dp_rank + ].value += num self.waiting_queue.extend(retracted_reqs) else: self.new_token_ratio = max( self.new_token_ratio - self.new_token_ratio_decay, self.min_new_token_ratio, ) - if not self.disable_regex_jump_forward: # Check for jump-forward jump_forward_reqs = batch.check_for_jump_forward(self.model_runner) @@ -627,6 +670,20 @@ def forward_decode_batch(self, batch: ScheduleBatch): self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) batch.prepare_for_decode() + if self.controller_info is not None and self.decode_forward_ct % 10 == 0: + with self.controller_info.lock: + self.controller_info.available_kv_cache[self.dp_rank].value = ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + ) + self.controller_info.running_reqs[self.dp_rank].value = ( + batch.batch_size() + ) + + self.controller_info.waiting_reqs[self.dp_rank].value = len( + self.waiting_queue + ) + # Forward and sample the next tokens output = self.model_runner.forward(batch, ForwardMode.DECODE) next_token_ids = batch.sample(output.next_token_logits) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index a1c685405a..a089ee5285 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -55,20 +55,87 @@ def _key_match(key0: List, key1: List): return i +import threading +from copy import deepcopy +from dataclasses import dataclass + +import zmq + + +@dataclass +class RadixCacheSend: + gpu_id: int + root_node: TreeNode + time: time + + class RadixCache(BasePrefixCache): def __init__( self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool, disable: bool = False, + gpu_id: int = 0, + pre_radix: bool = False, ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool = token_to_kv_pool self.disable = disable + self.pre_radix = pre_radix + self.gpu_id = gpu_id + self.send_cnt = 0 + self.change_cnt = 0 + if pre_radix: + context = zmq.Context() + self.send_radix_tree = context.socket(zmq.PUSH) + 
self.send_radix_tree.setsockopt(zmq.SNDHWM, 1000) + self.send_radix_tree.connect(f"tcp://127.0.0.1:41935") + + self.change_cnt_lock = threading.Lock() + + threading.Thread(target=self.loop_for_send_tree_cache).start() + + else: + print(f"dp[{self.gpu_id}] will not use pre_radix....") + self.reset() ##### Public API ##### + def loop_for_send_tree_cache(self): + while True: + with self.change_cnt_lock: + if self.change_cnt != 0: + self.change_cnt -= 1 + self.send_prefix_tree() + time.sleep(0.5) + + def send_prefix_tree(self): + # t1 = time.time() + if self.pre_radix: + self.send_cnt += 1 + try: + try: + node = deepcopy(self.root_node) + except Exception as e: + return + self.send_radix_tree.send_pyobj( + RadixCacheSend( + gpu_id=self.gpu_id, root_node=node, time=time.time() + ), + zmq.NOBLOCK, + ) + # if self.send_cnt % 10 == 0: + # print(f"[{self.gpu_id}] has send [{self.send_cnt}] caches") + del node + # torch.cuda.empty_cache() + except zmq.Again as e: + print( + "=======================================Radix Cache Queue is full, drop out new radix cache tree=======================================" + ) + # t2 = time.time() + # print(f"send radix time = {t2 - t1}") + def reset(self): self.root_node = TreeNode() self.root_node.key = [] @@ -76,6 +143,8 @@ def reset(self): self.root_node.lock_ref = 1 self.evictable_size_ = 0 + # self.send_prefix_tree() + def match_prefix(self, key: List, **kwargs): if self.disable: return [], self.root_node @@ -95,7 +164,12 @@ def insert(self, key: List, value=None): if value is None: value = [x for x in key] - return self._insert_helper(self.root_node, key, value) + res = self._insert_helper(self.root_node, key, value) + + if self.pre_radix: + with self.change_cnt_lock: + self.change_cnt += 1 + return res def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None): """Cache request when it finishes.""" @@ -118,6 +192,8 @@ def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None): self.req_to_token_pool.free(req.req_pool_idx) self.dec_lock_ref(req.last_node) + # self.send_prefix_tree() + def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None): """Cache request when it is unfinished.""" if self.disable: @@ -146,6 +222,8 @@ def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None): req.prefix_indices = new_indices req.last_node = new_last_node + # self.send_prefix_tree() + def pretty_print(self): self._print_helper(self.root_node, 0) print(f"#tokens: {self.total_size()}") @@ -175,6 +253,10 @@ def evict(self, num_tokens: int, evict_callback: Callable): if len(x.parent.children) == 0: heapq.heappush(leaves, x.parent) + # self.send_prefix_tree() + if self.pre_radix: + with self.change_cnt_lock: + self.change_cnt += 1 def inc_lock_ref(self, node: TreeNode): if self.disable: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6fd11d1345..0ecdd2ba9d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -349,6 +349,8 @@ def add_cli_args(parser: argparse.ArgumentParser): "round_robin", "shortest_queue", "resources_aware", + "power_of_2_choice", + "pre_radix", ], ) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index f7015d8796..a0035b29ec 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -15,6 +15,8 @@ import numpy as np import requests +from sglang.srt.model_config import ModelConfig, AttentionArch +from sglang.srt.server_args import ServerArgs from sglang.srt.model_config import 
AttentionArch, ModelConfig from sglang.srt.server_args import ServerArgs @@ -278,4 +280,4 @@ def __getattr__(self, name: str): def __call__(self, *args, **kwargs): module = self._load() - return module(*args, **kwargs) + return module(*args, **kwargs) \ No newline at end of file
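
Notes on the scheduling policies added by this patch. The selection rule implemented by resources_aware_scheduler ("method1" above) can be read as the standalone sketch below; the helper name pick_worker and the plain Python lists are illustrative only, while the patch itself reads the same quantities from the shared Value counters in ControllerInfo and mutates them under its lock.

import random
from typing import List


def pick_worker(available_mem: List[int],
                num_waiting: List[int],
                num_running: List[int],
                request_len: int) -> int:
    # Case 1: every worker already has queued requests. A high running/waiting
    # ratio is taken to mean more KV cache will be freed soon relative to the
    # backlog, so dispatch there (ties broken randomly).
    if min(num_waiting) > 0:
        ratio = [run / wait for run, wait in zip(num_running, num_waiting)]
        best = max(ratio)
        candidates = [i for i, x in enumerate(ratio) if x == best]
        index = random.choice(candidates)
        num_waiting[index] += 1
    # Case 2: at least one worker has an empty waiting queue; among those,
    # pick the one with the most available KV-cache memory.
    else:
        no_waiting = [1 if w == 0 else 0 for w in num_waiting]
        scores = [flag * mem for flag, mem in zip(no_waiting, available_mem)]
        index = scores.index(max(scores))
    # Locally account for the tokens this request will occupy.
    available_mem[index] -= request_len
    return index

Called as pick_worker(available_mem, num_waiting, num_running, len(r.input_ids)), it returns the worker index to enqueue the request on and updates the local bookkeeping the same way the patch does.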
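
The power_of_2_choice policy is the classic "power of two choices" load balancer: sample two workers uniformly at random and keep the one that looks less loaded. A minimal sketch, again over plain lists rather than the shared counters, and assuming at least two workers:

import random
from typing import List


def power_of_two_choice(num_waiting: List[int],
                        num_running: List[int],
                        available_mem: List[int]) -> int:
    # Requires at least two workers, just like random.sample in the patch.
    ins1, ins2 = random.sample(range(len(num_waiting)), 2)

    def better(a: int, b: int) -> int:
        # Fewer queued requests wins, then fewer running requests,
        # then more free KV-cache memory; otherwise keep the first pick.
        if num_waiting[a] != num_waiting[b]:
            return a if num_waiting[a] < num_waiting[b] else b
        if num_running[a] != num_running[b]:
            return a if num_running[a] < num_running[b] else b
        if available_mem[a] != available_mem[b]:
            return a if available_mem[a] > available_mem[b] else b
        return a

    return better(ins1, ins2)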
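
The pre_radix dispatch path ranks workers by how many leading tokens of the incoming prompt are already cached, via get_match_len(radix_cache.root_node, r.input_ids, 0). That helper is not shown in this excerpt; assuming a node layout like RadixCache's TreeNode, with children keyed by the first token of each edge and the edge's token list stored in key, a matcher along these lines would compute the shared-prefix length:

from typing import Dict, List


class TreeNode:
    # Minimal stand-in for the radix-tree node; only the fields the matcher
    # needs are included here.
    def __init__(self):
        self.children: Dict[int, "TreeNode"] = {}
        self.key: List[int] = []


def get_match_len(node: TreeNode, token_ids: List[int], depth: int) -> int:
    # Walk the tree along token_ids and return how many leading tokens are
    # already cached under this node.
    matched = 0
    while depth < len(token_ids):
        child = node.children.get(token_ids[depth])
        if child is None:
            break
        edge = child.key
        i = 0
        # Compare the edge label token by token against the remaining prompt.
        while i < len(edge) and depth + i < len(token_ids) and edge[i] == token_ids[depth + i]:
            i += 1
        matched += i
        depth += i
        if i < len(edge):
            break  # diverged inside this edge
        node = child
    return matched

The request is then routed to the worker with the longest match, falling back to resources_aware_scheduler when several workers tie.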
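
Snapshots of each worker's radix tree travel over ZMQ: RadixCache deep-copies its root and PUSHes a RadixCacheSend(gpu_id, root_node, time) without blocking, and the controller drains its PULL socket in recv_tree_cache, keeping only the newest tree per GPU. A simplified sketch of that receive loop (locking omitted; the patch guards newest_tree_cache with recv_tree_cache_lock):

import zmq


def drain_latest_snapshots(pull_socket: zmq.Socket, newest: dict) -> None:
    # Non-blocking drain: read everything queued on the PULL socket and keep
    # only the most recent snapshot per gpu_id.
    while True:
        try:
            snap = pull_socket.recv_pyobj(zmq.NOBLOCK)
        except zmq.ZMQError:
            break
        prev = newest.get(snap.gpu_id)
        if prev is None or snap.time > prev.time:
            newest[snap.gpu_id] = snap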