diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index e3a2ad0a2c..e67871b28d 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -311,15 +311,25 @@ def download_sharegpt_dataset(path):
         raise Exception(f"Failed to download dataset: {e}")
 
 
+import pickle
+
+
 def sample_sharegpt_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
-    fixed_output_len: Optional[int] = None,
+    max_seqlen: int,
 ) -> List[Tuple[str, int, int]]:
-    if fixed_output_len is not None and fixed_output_len < 4:
-        raise ValueError("output_len too small")
-
+    cache_path = f"./input_cache_v2_{num_requests}"
+    # Try to load cached input_requests
+    if os.path.isfile(cache_path):
+        with open(cache_path, "rb") as f:
+            input_requests = pickle.load(f)
+        print("Loaded input_requests from cache.")
+        return input_requests
+    prompts = []
+    prompt_lens = []
+    response_lens = []
     # Download sharegpt if necessary
     if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
         download_sharegpt_dataset(default_sharegpt_path)
@@ -331,42 +341,52 @@ def sample_sharegpt_requests(
 
     # Load the dataset.
     with open(dataset_path) as f:
-        dataset = json.load(f)
-    # Filter out the conversations with less than 2 turns.
-    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [
-        (data["conversations"][0]["value"], data["conversations"][1]["value"])
-        for data in dataset
-    ]
-
-    # Shuffle the dataset.
-    random.shuffle(dataset)
-
-    # Filter out sequences that are too long or too short
-    filtered_dataset: List[Tuple[str, int, int]] = []
-    for i in range(len(dataset)):
-        if len(filtered_dataset) == num_requests:
-            break
-
-        # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
-        prompt_token_ids = tokenizer(prompt).input_ids
-        completion = dataset[i][1]
-        completion_token_ids = tokenizer(completion).input_ids
-        prompt_len = len(prompt_token_ids)
-        output_len = (
-            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
-        )
-        if prompt_len < 4 or output_len < 4:
-            # Prune too short sequences.
-            continue
-        if prompt_len > 1024 or prompt_len + output_len > 2048:
-            # Prune too long sequences.
-            continue
-        filtered_dataset.append((prompt, prompt_len, output_len))
+        datasets = json.load(f)
+        for data in datasets:
+            if len(data["conversations"]) >= 2:
+                prompt = data["conversations"][0]["value"]
+                res = data["conversations"][1]["value"]
+                prompt_token_ids = tokenizer(prompt).input_ids
+                completion_token_ids = tokenizer(res).input_ids
+
+                if (
+                    len(prompt_token_ids) + len(completion_token_ids) < max_seqlen
+                    and len(prompt_token_ids) > 0
+                    and len(completion_token_ids) > 0
+                ):
+                    prompts.append(prompt)
+                    prompt_lens.append(len(prompt_token_ids))
+                    response_lens.append(len(completion_token_ids))
+            if len(prompts) > num_requests:
+                break
+
+    sampled_ids = [random.randint(0, len(prompts) - 1) for _ in range(num_requests)]
+    sampled_prompts = [prompts[idx] for idx in sampled_ids]
+    sampled_prompts_lens = [prompt_lens[idx] for idx in sampled_ids]
+    sampled_response_lens = [response_lens[idx] for idx in sampled_ids]
+
+    for i, (prompt_len, gen_len) in enumerate(
+        zip(sampled_prompts_lens, sampled_response_lens)
+    ):
+        total = prompt_len + gen_len
+        if total > max_seqlen:
+            print(f"truncating long prompt+gen_len {prompt_len=} {gen_len=}")
+            gen_len = max_seqlen - prompt_len
+            sampled_response_lens[i] = gen_len
+    input_requests = list(
+        zip(sampled_prompts, sampled_prompts_lens, sampled_response_lens)
+    )
+
+
+    with open(cache_path, "wb") as f:
+        pickle.dump(input_requests, f)
+    print(f"Saved input_requests_{num_requests} to cache.")
+    print(f"#Input tokens: {np.sum(sampled_prompts_lens)}")
+    print(f"#Output tokens: {np.sum(sampled_response_lens)}")
+    return input_requests
+
-    return filtered_dataset
+import pickle
 
 
 def sample_random_requests(
@@ -377,7 +397,13 @@ def sample_random_requests(
     tokenizer: PreTrainedTokenizerBase,
     dataset_path: str,
 ) -> List[Tuple[str, int, int]]:
-
+    cache_path = f"./input_cache_{num_prompts}"
+    # Try to load cached input_requests
+    if os.path.isfile(cache_path):
+        with open(cache_path, "rb") as f:
+            input_requests = pickle.load(f)
+        print("Loaded input_requests from cache.")
+        return input_requests
     input_lens = np.random.randint(
         max(int(input_len * range_ratio), 1),
         input_len + 1,
@@ -444,7 +470,9 @@ def sample_random_requests(
             ]
         )
         input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
-
+    with open(cache_path, "wb") as f:
+        pickle.dump(input_requests, f)
+    print(f"Saved input_requests_{num_prompts} to cache.")
     print(f"#Input tokens: {np.sum(input_lens)}")
     print(f"#Output tokens: {np.sum(output_lens)}")
     return input_requests
@@ -483,6 +511,9 @@ def calculate_metrics(
     tpots: List[float] = []
     ttfts: List[float] = []
     e2e_latencies: List[float] = []
+
+    input_lens: List[float] = []
+
     for i in range(len(outputs)):
         if outputs[i].success:
             output_len = outputs[i].output_len
@@ -492,6 +523,9 @@ def calculate_metrics(
             )
             retokenized_output_lens.append(retokenized_output_len)
             total_input += input_requests[i][1]
+
+            input_lens.append(input_requests[i][1])
+
             if output_len > 1:
                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
             itls += outputs[i].itl
@@ -510,6 +544,11 @@
         "on the benchmark arguments.",
         stacklevel=2,
     )
+
+    # metric_data = [input_lens, output_lens, ttfts]
+    # with open(f'metrics_{time.time()}.json', 'w') as f:
+    #     json.dump(metric_data, f)
+
     metrics = BenchmarkMetrics(
         completed=completed,
         total_input=total_input,
@@ -557,6 +596,15 @@ async def benchmark(
 
     print("Starting initial single prompt test run...")
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
+
+    words = test_prompt.split()
+
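+    # Note: after this patch the warm-up request built below is only constructed,
+    # never sent; the initial test call that consumed test_input is removed further down.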
+    # Shuffle the word list with random.shuffle
+    random.shuffle(words)
+
+    # Join the shuffled words back into a single prompt string
+    test_prompt = " ".join(words)
+
+    test_input = RequestFuncInput(
+        model=model_id,
+        prompt=test_prompt,
+        api_url=api_url,
+        prompt_len=test_prompt_len,
+        output_len=test_output_len,
+        extra_request_body=extra_request_body,
+    )
-    test_output = await request_func(request_func_input=test_input)
-    if not test_output.success:
-        raise ValueError(
-            "Initial test run failed - Please make sure benchmark arguments "
-            f"are correctly specified. Error: {test_output.error}"
-        )
-    else:
-        print("Initial test run completed. Starting main benchmark run...")
 
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -731,6 +771,31 @@ async def benchmark(
         "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
         "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
     }
+
+    balance_method = os.getenv("LOAD_BALANCE_METHOD")
+    new_item = {
+        "method": balance_method,
+        "mean_ttft": metrics.mean_ttft_ms,
+        "request_rate": request_rate,
+        "request_throughput": metrics.request_throughput,
+        "p99_ttft_ms": metrics.p99_ttft_ms,
+        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
+        "time": datetime.now().isoformat(),
+    }
+    file_name = f"{balance_method}_result.json"
+    if not os.path.exists(file_name):
+        with open(file_name, "w") as f:
+            json.dump([], f)
+
+    with open(file_name, "r") as f:
+        tmp_data = json.load(f)
+
+    tmp_data.append(new_item)
+
+    with open(file_name, "w") as f:
+        json.dump(tmp_data, f, indent=4)
+
+    print(f"add new item to {file_name}: {new_item}")
 
     return result
 
@@ -827,7 +892,7 @@ def run_benchmark(args_: argparse.Namespace):
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
-            fixed_output_len=args.sharegpt_output_len,
+            max_seqlen=args.sharegpt_max_seqlen,
         )
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(
@@ -859,6 +924,7 @@
             )
         )
     else:
+
         return asyncio.run(
             benchmark(
                 backend=backend,
@@ -935,6 +1001,12 @@ def set_ulimit(target_soft_limit=65535):
         default=1000,
         help="Number of prompts to process. Default is 1000.",
     )
+    parser.add_argument(
+        "--sharegpt-max-seqlen",
+        type=int,
+        default=8192,
+        help="Maximum request length in tokens (prompt + output) for the ShareGPT dataset. Default is 8192.",
+    )
     parser.add_argument(
         "--sharegpt-output-len",
         type=int,
diff --git a/python/sglang/bench_serving_improve.py b/python/sglang/bench_serving_improve.py
new file mode 100644
index 0000000000..7d63bcdcf9
--- /dev/null
+++ b/python/sglang/bench_serving_improve.py
@@ -0,0 +1,1028 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
+# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
+
+"""
+Benchmark online serving.
+ +Usage: +python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi +""" + +import argparse +import asyncio +import json +import os +import random +import resource +import sys +import time +import traceback +import warnings +from argparse import ArgumentParser +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union +from concurrent.futures import ThreadPoolExecutor, as_completed +import aiohttp +import numpy as np +import requests +from tqdm.asyncio import tqdm +from transformers import ( + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +global args + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + extra_request_body: Dict[str, Any] + + +@dataclass +class RequestFuncOutput: + generated_text: str = "" + success: bool = False + latency: float = 0.0 + ttft: float = 0.0 # Time to first token + itl: List[float] = field(default_factory=list) # List of inter-token latencies + prompt_len: int = 0 + error: str = "" + output_len: int = 0 + + +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix) :] if text.startswith(prefix) else text + + +# trt llm not support ignore_eos +# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 +async def async_request_trt_llm( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.000001, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + "min_length": request_func_input.output_len, + "end_id": 1048576, + **request_func_input.extra_request_body, + } + if args.disable_ignore_eos: + del payload["min_length"] + del payload["end_id"] + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = True + output.output_len = request_func_input.output_len + + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = 
"".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# set ignore_eos True by default +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( + "completions" + ), "OpenAI Completions API URL must end with 'completions'." + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": 1, + "max_tokens": request_func_input.output_len, + "stream": not args.disable_stream, + "ignore_eos": not args.disable_ignore_eos, + **request_func_input.extra_request_body, + } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true": + import huggingface_hub.constants + from modelscope import snapshot_download + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) + + return model_path + return pretrained_model_name_or_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) + + +ASYNC_REQUEST_FUNCS = { + "sglang": async_request_openai_completions, + "vllm": async_request_openai_completions, + "lmdeploy": async_request_openai_completions, + "trt": async_request_trt_llm, +} + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + total_output_retokenized: int + request_throughput: 
float + input_throughput: float + output_throughput: float + output_throughput_retokenized: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + p99_itl_ms: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + + +default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json" + + +def download_sharegpt_dataset(path): + url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" + + print(f"Downloading dataset from {url}") + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + with open(path, "wb") as f, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(block_size): + size = f.write(data) + progress_bar.update(size) + + print(f"Dataset downloaded and saved to {path}") + except requests.RequestException as e: + raise Exception(f"Failed to download dataset: {e}") + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset +import pickle + +def sample_random_requests( + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, + dataset_path: str, +) -> List[Tuple[str, int, int]]: + cache_path = f'./input_cache_{num_prompts}' + # 尝试加载缓存的 input_requests + if os.path.isfile(cache_path): + with open(cache_path, 'rb') as f: + input_requests = pickle.load(f) + print("Loaded input_requests from cache.") + return input_requests + input_lens = np.random.randint( + max(int(input_len * range_ratio), 1), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile( + default_sharegpt_path + ): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + def process_prompt(i): + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + prompt_len = len(prompt_token_ids) + if prompt_len > input_lens[i]: + input_ids = prompt_token_ids[: input_lens[i]] + else: + ratio = (input_lens[i] + prompt_len - 1) // prompt_len + input_ids = (prompt_token_ids * ratio)[: input_lens[i]] + prompt = tokenizer.decode(input_ids) + # print(f"sample {i} has been processed...") + return (prompt, int(input_lens[i]), int(output_lens[i])) + input_requests = [] + # Filter out sequences that are too long or too short + t1 = time.time() + with ThreadPoolExecutor(max_workers=os.cpu_count() * 5) as executor: + # 提交所有任务 + futures = [executor.submit(process_prompt, i) for i in range(num_prompts)] + + # 等待所有任务完成,并收集结果 + for future in as_completed(futures): + try: + result = future.result() + input_requests.append(result) + except Exception as e: + print(f"Task generated an exception: {e}") + t2 = time.time() + print(f"It takes {t2 - t1} seconds to prepare prompts....") + # 保存 input_requests 到缓存文件 + with open(cache_path, 'wb') as f: + pickle.dump(input_requests, f) + print(f"Saved input_requests_{num_prompts} to cache.") + print(f"#Input tokens: {np.sum(input_lens)}") + print(f"#Output tokens: {np.sum(output_lens)}") + return input_requests + + +async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
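+        # Exponentially distributed gaps with mean 1 / request_rate make the
+        # overall arrival stream a Poisson process at the configured rate.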
+ await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + backend: str, +) -> Tuple[BenchmarkMetrics, List[int]]: + output_lens: List[int] = [] + retokenized_output_lens: List[int] = [] + total_input = 0 + completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] + e2e_latencies: List[float] = [] + + input_lens:List[float] = [] + + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_len + output_lens.append(output_len) + retokenized_output_len = len( + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) + retokenized_output_lens.append(retokenized_output_len) + total_input += input_requests[i][1] + + input_lens.append(input_requests[i][1]) + + + if output_len > 1: + tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + + e2e_latencies.append(outputs[i].latency) + + completed += 1 + else: + output_lens.append(0) + retokenized_output_lens.append(0) + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2, + ) + + metric_data = [input_lens, output_lens, ttfts] + with open(f'metrics_{time.time()}.json', 'w') as f: + json.dump(metric_data, f) + + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(output_lens), + total_output_retokenized=sum(retokenized_output_lens), + request_throughput=completed / dur_s, + input_throughput=total_input / dur_s, + output_throughput=sum(output_lens) / dur_s, + output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + std_ttft_ms=np.std(ttfts or 0) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, + median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + ) + + return metrics, output_lens + + +async def benchmark( + backend: str, + api_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: List[Tuple[str, int, int]], + request_rate: float, + disable_tqdm: bool, + enable_multi: bool, + extra_request_body: Dict[str, Any], +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + + words = test_prompt.split() + + # 使用random.shuffle打乱单词列表 + random.shuffle(words) + + # 将打乱后的单词列表合并回文本 + test_prompt = ' '.join(words) + + + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + extra_request_body=extra_request_body, + ) + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + benchmark_start_time = 
time.perf_counter() + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate): + prompt, prompt_len, output_len = request + request_func_input = RequestFuncInput( + model=model_id, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + extra_request_body=extra_request_body, + ) + tasks.append( + asyncio.create_task( + request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + backend=backend, + ) + + print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Backend:", backend)) + print("{:<40} {:<10}".format("Traffic request rate:", request_rate)) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10}".format( + "Total generated tokens (retokenized):", metrics.total_output_retokenized + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Input token throughput (tok/s):", metrics.input_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) + print( + "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) + ) + print( + "{:<40} {:<10.2f}".format( + "Median E2E Latency (ms):", metrics.median_e2e_latency_ms + ) + ) + print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) + print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + print( + "{s:{c}^{n}}".format(s="Time per Output Token (excl. 
1st token)", n=50, c="-") + ) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("=" * 50) + + if ( + metrics.median_ttft_ms is not None + and metrics.mean_itl_ms is not None + and metrics.output_throughput is not None + ): + result = { + "backend": args.backend, + "dataset_name": args.dataset_name, + "request_rate": request_rate, + "total_input": metrics.total_input, + "total_output": metrics.total_output, + "total_output_retokenized": metrics.total_output_retokenized, + "mean_e2e_latency": metrics.mean_e2e_latency_ms, + "median_e2e_latency": metrics.median_e2e_latency_ms, + "median_ttft": metrics.median_ttft_ms, + "median_itl": metrics.median_itl_ms, + "output_token_throughput": metrics.output_throughput, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + "benchmark_duration": benchmark_duration, + } + else: + print(f"Error running benchmark for request rate: {request_rate}") + print("-" * 30) + + # Determine output file name + if args.output_file: + output_file_name = args.output_file + else: + now = datetime.now().strftime("%m%d") + if args.dataset_name == "random": + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" + else: + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl" + + # Append results to a JSONL file + with open(output_file_name, "a") as file: + file.write(json.dumps(result) + "\n") + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + "mean_ttft_ms": metrics.mean_ttft_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + } + return result + + +def parse_request_rate_range(request_rate_range): + if len(request_rate_range.split(",")) == 3: + start, stop, step = map(int, request_rate_range.split(",")) + return list(range(start, stop, step)) + else: + 
return list(map(int, request_rate_range.split(","))) + + +def check_chat_template(model_path): + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + return "chat_template" in tokenizer.init_kwargs + except Exception as e: + print(f"Fail to load tokenizer config with error={e}") + return False + + +def run_benchmark(args_: argparse.Namespace): + global args + args = args_ + + set_ulimit() + random.seed(args.seed) + np.random.seed(args.seed) + + extra_request_body = {} + if args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + + if args.port is None: + args.port = { + "sglang": 30000, + "lmdeploy": 23333, + "vllm": 8000, + "trt": 8000, + }.get(args.backend, 30000) + + api_url = ( + f"{args.base_url}/v1/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/completions" + ) + model_url = ( + f"{args.base_url}/v1/models" + if args.base_url + else f"http://{args.host}:{args.port}/v1/models" + ) + + if args.backend == "trt": + api_url = ( + f"{args.base_url}/v2/models/ensemble/generate_stream" + if args.base_url + else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream" + ) + if args.model is None: + print("Please provide a model using `--model` when using `trt` backend.") + sys.exit(1) + + if args.model is None: + try: + response = requests.get(model_url) + model_list = response.json().get("data", []) + args.model = model_list[0]["id"] if model_list else None + except Exception as e: + print(f"Failed to fetch model from {model_url}. Error: {e}") + print( + "Please specify the correct host and port using `--host` and `--port`." + ) + sys.exit(1) + + if args.model is None: + print("No model specified or found. Please provide a model using `--model`.") + sys.exit(1) + + if not check_chat_template(args.model): + print( + "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n" + "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n" + ) + + print(f"{args}\n") + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + tokenizer = get_tokenizer(tokenizer_id) + + if args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, + ) + elif args.dataset_name == "random": + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + + if args.multi: + request_rates = parse_request_rate_range(args.request_rate_range) + + for rate in request_rates: + asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + else: + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + + +# to avoid relying on SGLang's 
components +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Fail to set RLIMIT_NOFILE: {e}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", + help="Must specify a backend, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=["sharegpt", "random"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", type=str, default="", help="Path to the dataset." + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer. If not set, using the model conf.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length from the ShareGPT dataset.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range of sampled ratio of input/output length, " + "used only for random dataset.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.", + ) + parser.add_argument("--seed", type=int, default=0, help="Default is 0.") + parser.add_argument( + "--multi", + action="store_true", + help="Use request rate range rather than single value.", + ) + parser.add_argument( + "--request-rate-range", + type=str, + default="2,34,2", + help="Range of request rates in the format start,stop,step. Default is 2,34,2. 
It also supports a list of request rates, requiring the parameters to not equal three.", + ) + parser.add_argument("--output-file", type=str, help="Output JSONL file name.") + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--disable-stream", + action="store_true", + help="Disable streaming mode.", + ) + parser.add_argument( + "--disable-ignore-eos", + action="store_true", + help="Disable ignoring EOS.", + ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help="Append given JSON object to the request payload. You can use this to specify" + "additional generate params like sampling params.", + ) + args = parser.parse_args() + run_benchmark(args) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index b02ce9f81e..0689cc11d4 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -25,7 +25,7 @@ def __init__(self): self.layer_sync_threshold = 8192 # Runtime constants: others - self.num_continue_decode_steps = 10 + self.num_continue_decode_steps = 1 self.retract_decode_steps = 20 self.flashinfer_workspace_size = 192 * 1024 * 1024 diff --git a/python/sglang/iter_search_best_qps.py b/python/sglang/iter_search_best_qps.py new file mode 100644 index 0000000000..d978cd21ba --- /dev/null +++ b/python/sglang/iter_search_best_qps.py @@ -0,0 +1,1098 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py + +""" +Benchmark online serving. + +Usage: +python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi +""" + +import argparse +import asyncio +import json +import os +import random +import resource +import sys +import time +import traceback +import warnings +from argparse import ArgumentParser +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union + +import aiohttp +import numpy as np +import requests +from tqdm.asyncio import tqdm +from transformers import ( + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +global args + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + extra_request_body: Dict[str, Any] + + +@dataclass +class RequestFuncOutput: + generated_text: str = "" + success: bool = False + latency: float = 0.0 + ttft: float = 0.0 # Time to first token + itl: List[float] = field(default_factory=list) # List of inter-token latencies + prompt_len: int = 0 + error: str = "" + output_len: int = 0 + + +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix) :] if text.startswith(prefix) else text + + +# trt llm not support ignore_eos +# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 +async def async_request_trt_llm( + 
request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.000001, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + "min_length": request_func_input.output_len, + "end_id": 1048576, + **request_func_input.extra_request_body, + } + if args.disable_ignore_eos: + del payload["min_length"] + del payload["end_id"] + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = True + output.output_len = request_func_input.output_len + + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# set ignore_eos True by default +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( + "completions" + ), "OpenAI Completions API URL must end with 'completions'." 
+ + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": 1, + "max_tokens": request_func_input.output_len, + "stream": not args.disable_stream, + "ignore_eos": not args.disable_ignore_eos, + **request_func_input.extra_request_body, + } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true": + import huggingface_hub.constants + from modelscope import snapshot_download + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) + + return model_path + return pretrained_model_name_or_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) + + +ASYNC_REQUEST_FUNCS = { + "sglang": async_request_openai_completions, + "vllm": async_request_openai_completions, + "lmdeploy": async_request_openai_completions, + "trt": async_request_trt_llm, +} + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + total_output_retokenized: int + request_throughput: float + input_throughput: float + output_throughput: float + output_throughput_retokenized: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + p99_itl_ms: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + 
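+# All *_ms fields above are in milliseconds; the throughput fields are
+# per-second rates computed over the full benchmark duration in
+# calculate_metrics() below.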
+ +default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json" + + +def download_sharegpt_dataset(path): + url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" + + print(f"Downloading dataset from {url}") + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + with open(path, "wb") as f, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(block_size): + size = f.write(data) + progress_bar.update(size) + + print(f"Dataset downloaded and saved to {path}") + except requests.RequestException as e: + raise Exception(f"Failed to download dataset: {e}") + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset + + +def sample_random_requests( + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, + dataset_path: str, +) -> List[Tuple[str, int, int]]: + + input_lens = np.random.randint( + max(int(input_len * range_ratio), 1), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + + if True: + # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile( + default_sharegpt_path + ): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + input_requests: List[Tuple[str, int, int]] = [] + for i in range(num_prompts): + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + prompt_len = len(prompt_token_ids) + + if prompt_len > input_lens[i]: + input_ids = prompt_token_ids[: input_lens[i]] + else: + ratio = (input_lens[i] + prompt_len - 1) // prompt_len + input_ids = (prompt_token_ids * ratio)[: input_lens[i]] + prompt = tokenizer.decode(input_ids) + input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + else: + # Sample token ids from random integers. This can cause some NaN issues. + offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) + input_requests = [] + for i in range(num_prompts): + prompt = tokenizer.decode( + [ + (offsets[i] + i + j) % tokenizer.vocab_size + for j in range(input_lens[i]) + ] + ) + input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + + print(f"#Input tokens: {np.sum(input_lens)}") + print(f"#Output tokens: {np.sum(output_lens)}") + return input_requests + + +async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
+ await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + backend: str, +) -> Tuple[BenchmarkMetrics, List[int]]: + output_lens: List[int] = [] + retokenized_output_lens: List[int] = [] + total_input = 0 + completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] + e2e_latencies: List[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_len + output_lens.append(output_len) + retokenized_output_len = len( + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) + retokenized_output_lens.append(retokenized_output_len) + total_input += input_requests[i][1] + if output_len > 1: + tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + + e2e_latencies.append(outputs[i].latency) + + completed += 1 + else: + output_lens.append(0) + retokenized_output_lens.append(0) + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2, + ) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(output_lens), + total_output_retokenized=sum(retokenized_output_lens), + request_throughput=completed / dur_s, + input_throughput=total_input / dur_s, + output_throughput=sum(output_lens) / dur_s, + output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + std_ttft_ms=np.std(ttfts or 0) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, + median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + ) + + return metrics, output_lens + + +async def benchmark( + backend: str, + api_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: List[Tuple[str, int, int]], + request_rate: float, + disable_tqdm: bool, + enable_multi: bool, + extra_request_body: Dict[str, Any], +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + + words = test_prompt.split() + + # 使用random.shuffle打乱单词列表 + random.shuffle(words) + + # 将打乱后的单词列表合并回文本 + test_prompt = ' '.join(words) + + + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + extra_request_body=extra_request_body, + ) + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + benchmark_start_time = time.perf_counter() + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate): + prompt, prompt_len, output_len = request + request_func_input = RequestFuncInput( + model=model_id, + 
prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + extra_request_body=extra_request_body, + ) + tasks.append( + asyncio.create_task( + request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + backend=backend, + ) + + print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Backend:", backend)) + print("{:<40} {:<10}".format("Traffic request rate:", request_rate)) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10}".format( + "Total generated tokens (retokenized):", metrics.total_output_retokenized + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Input token throughput (tok/s):", metrics.input_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) + print( + "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) + ) + print( + "{:<40} {:<10.2f}".format( + "Median E2E Latency (ms):", metrics.median_e2e_latency_ms + ) + ) + print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) + print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + print( + "{s:{c}^{n}}".format(s="Time per Output Token (excl. 
1st token)", n=50, c="-") + ) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("=" * 50) + + if ( + metrics.median_ttft_ms is not None + and metrics.mean_itl_ms is not None + and metrics.output_throughput is not None + ): + result = { + "backend": args.backend, + "dataset_name": args.dataset_name, + "request_rate": request_rate, + "total_input": metrics.total_input, + "total_output": metrics.total_output, + "total_output_retokenized": metrics.total_output_retokenized, + "mean_e2e_latency": metrics.mean_e2e_latency_ms, + "median_e2e_latency": metrics.median_e2e_latency_ms, + "median_ttft": metrics.median_ttft_ms, + "median_itl": metrics.median_itl_ms, + "output_token_throughput": metrics.output_throughput, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + "benchmark_duration": benchmark_duration, + } + else: + print(f"Error running benchmark for request rate: {request_rate}") + print("-" * 30) + + # Determine output file name + if args.output_file: + output_file_name = args.output_file + else: + now = datetime.now().strftime("%m%d") + if args.dataset_name == "random": + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" + else: + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl" + + # Append results to a JSONL file + with open(output_file_name, "a") as file: + file.write(json.dumps(result) + "\n") + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + "mean_ttft_ms": metrics.mean_ttft_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + } + return result + + +def parse_request_rate_range(request_rate_range): + if len(request_rate_range.split(",")) == 3: + start, stop, step = map(int, request_rate_range.split(",")) + return list(range(start, stop, step)) + else: + 
return list(map(int, request_rate_range.split(","))) + + +def check_chat_template(model_path): + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + return "chat_template" in tokenizer.init_kwargs + except Exception as e: + print(f"Fail to load tokenizer config with error={e}") + return False + + +def run_benchmark(args_: argparse.Namespace): + global args + args = args_ + + set_ulimit() + random.seed(args.seed) + np.random.seed(args.seed) + + extra_request_body = {} + if args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + + if args.port is None: + args.port = { + "sglang": 30000, + "lmdeploy": 23333, + "vllm": 8000, + "trt": 8000, + }.get(args.backend, 30000) + + api_url = ( + f"{args.base_url}/v1/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/completions" + ) + model_url = ( + f"{args.base_url}/v1/models" + if args.base_url + else f"http://{args.host}:{args.port}/v1/models" + ) + + if args.backend == "trt": + api_url = ( + f"{args.base_url}/v2/models/ensemble/generate_stream" + if args.base_url + else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream" + ) + if args.model is None: + print("Please provide a model using `--model` when using `trt` backend.") + sys.exit(1) + + if args.model is None: + try: + response = requests.get(model_url) + model_list = response.json().get("data", []) + args.model = model_list[0]["id"] if model_list else None + except Exception as e: + print(f"Failed to fetch model from {model_url}. Error: {e}") + print( + "Please specify the correct host and port using `--host` and `--port`." + ) + sys.exit(1) + + if args.model is None: + print("No model specified or found. Please provide a model using `--model`.") + sys.exit(1) + + if not check_chat_template(args.model): + print( + "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n" + "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n" + ) + + print(f"{args}\n") + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + tokenizer = get_tokenizer(tokenizer_id) + + if args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, + ) + elif args.dataset_name == "random": + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + + if args.multi: + request_rates = parse_request_rate_range(args.request_rate_range) + + for rate in request_rates: + asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + else: + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + + +# to avoid relying on SGLang's 
components +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Fail to set RLIMIT_NOFILE: {e}") + + +import json + +def update_qps_json_file(data_list, file_path): + # Load existing results, if any + try: + with open(file_path, 'r', encoding='utf-8') as json_file: + qps_result_list = json.load(json_file) + except FileNotFoundError: + qps_result_list = [] + + # Collect the request_rate values that are already recorded + existing_rates = {item['request_rate'] for item in qps_result_list} + + # Append new entries whose request_rate is not recorded yet + new_entries_added = False + for new_data in data_list: + if new_data['request_rate'] not in existing_rates: + qps_result_list.append(new_data) + existing_rates.add(new_data['request_rate']) + new_entries_added = True + else: + print(f"An entry with request_rate={new_data['request_rate']} already exists; skipping it.") + + # If anything new was added, sort and write the file back + if new_entries_added: + # Sort by request_rate + qps_result_list.sort(key=lambda x: x['request_rate']) + + # Write the updated results + with open(file_path, 'w', encoding='utf-8') as json_file: + json.dump(qps_result_list, json_file, ensure_ascii=False, indent=4) + print("New entries added and sorted.") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", + help="Must specify a backend, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=["sharegpt", "random"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", type=str, default="", help="Path to the dataset." + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer. If not set, using the model conf.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length from the ShareGPT dataset.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range of sampled ratio of input/output length, " + "used only for random dataset.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, then all the requests are sent at time 0. 
" + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.", + ) + parser.add_argument("--seed", type=int, default=0, help="Default is 0.") + parser.add_argument( + "--multi", + action="store_true", + help="Use request rate range rather than single value.", + ) + parser.add_argument( + "--request-rate-range", + type=str, + default="2,34,2", + help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also accepts an explicit comma-separated list of request rates, as long as the list does not contain exactly three values.", + ) + parser.add_argument("--output-file", type=str, help="Output JSONL file name.") + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--disable-stream", + action="store_true", + help="Disable streaming mode.", + ) + parser.add_argument( + "--disable-ignore-eos", + action="store_true", + help="Disable ignoring EOS.", + ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help="Append given JSON object to the request payload. You can use this to specify " + "additional generate params like sampling params.", + ) + parser.add_argument( + "--request-rate-list", + type=json.loads, + default="[1, 5, 10, 20, 35, 50, 70, 100]", + help="List of request rates to sweep over.", + ) + + args = parser.parse_args() + qps_result_dict = [] + + + qps_list = args.request_rate_list + for request_rate in qps_list: + args.request_rate = request_rate + result = run_benchmark(args) + qps_result_dict.append({ + "request_rate": request_rate, + "mean_ttft_ms": float(result['mean_ttft_ms']), + "request_throughput": float(result["request_throughput"]) + }) + + # Write the collected results to a JSON file + update_qps_json_file(qps_result_dict, 'qps_results.json') + # with open('qps_results.json', 'w', encoding='utf-8') as json_file: + # json.dump(qps_result_dict, json_file, ensure_ascii=False, indent=4) + + # print(f"QPS results saved to 'qps_results.json'") + time.sleep(10) + + import matplotlib.pyplot as plt + + request_rates = [result['request_rate'] for result in qps_result_dict] + mean_ttft_ms = [result['mean_ttft_ms'] for result in qps_result_dict] + request_throughput = [result['request_throughput'] for result in qps_result_dict] + + fig, ax1 = plt.subplots() + + color = 'tab:red' + ax1.set_xlabel('Request Rate (QPS)') + ax1.set_ylabel('Mean Time to First Token (ms)', color=color) + ax1.plot(request_rates, mean_ttft_ms, marker='o', color=color) + ax1.tick_params(axis='y', labelcolor=color) + + ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis + + color = 'tab:blue' + ax2.set_ylabel('Request Throughput', color=color) # we already handled the x-label with ax1 + ax2.plot(request_rates, request_throughput, marker='x', color=color) + ax2.tick_params(axis='y', labelcolor=color) + + fig.tight_layout() # otherwise the right y-label is slightly clipped + plt.title('Benchmark Results') + plt.savefig("./QPS.png") + plt.close() + + + + diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 91dc0dc4e9..0137b579c9 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -10,5 +10,6 @@ ServerArgs.add_cli_args(parser) args = parser.parse_args() server_args = ServerArgs.from_cli_args(args) + launch_server(server_args) diff --git a/python/sglang/search_best_qps.py b/python/sglang/search_best_qps.py new file mode 100644 index 0000000000..400254837b --- /dev/null +++ b/python/sglang/search_best_qps.py @@ -0,0 
+1,1068 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py +# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py + +""" +Benchmark online serving. + +Usage: +python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + +python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 +python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi +""" + +import argparse +import asyncio +import json +import os +import random +import resource +import sys +import time +import traceback +import warnings +from argparse import ArgumentParser +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union + +import aiohttp +import numpy as np +import requests +from tqdm.asyncio import tqdm +from transformers import ( + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +global args + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + extra_request_body: Dict[str, Any] + + +@dataclass +class RequestFuncOutput: + generated_text: str = "" + success: bool = False + latency: float = 0.0 + ttft: float = 0.0 # Time to first token + itl: List[float] = field(default_factory=list) # List of inter-token latencies + prompt_len: int = 0 + error: str = "" + output_len: int = 0 + + +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix) :] if text.startswith(prefix) else text + + +# trt llm not support ignore_eos +# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 +async def async_request_trt_llm( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.000001, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + "min_length": request_func_input.output_len, + "end_id": 1048576, + **request_func_input.extra_request_body, + } + if args.disable_ignore_eos: + del payload["min_length"] + del payload["end_id"] + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = 
True + output.output_len = request_func_input.output_len + + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# set ignore_eos True by default +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( + "completions" + ), "OpenAI Completions API URL must end with 'completions'." + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": 1, + "max_tokens": request_func_input.output_len, + "stream": not args.disable_stream, + "ignore_eos": not args.disable_ignore_eos, + **request_func_input.extra_request_body, + } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true": + import huggingface_hub.constants + from modelscope import snapshot_download + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) + + return model_path + return pretrained_model_name_or_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True + ) + + +ASYNC_REQUEST_FUNCS = { + "sglang": async_request_openai_completions, + "vllm": async_request_openai_completions, + "lmdeploy": 
async_request_openai_completions, + "trt": async_request_trt_llm, +} + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + total_output_retokenized: int + request_throughput: float + input_throughput: float + output_throughput: float + output_throughput_retokenized: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + p99_tpot_ms: float + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + p99_itl_ms: float + mean_e2e_latency_ms: float + median_e2e_latency_ms: float + + +default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json" + + +def download_sharegpt_dataset(path): + url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" + + print(f"Downloading dataset from {url}") + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + with open(path, "wb") as f, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(block_size): + size = f.write(data) + progress_bar.update(size) + + print(f"Dataset downloaded and saved to {path}") + except requests.RequestException as e: + raise Exception(f"Failed to download dataset: {e}") + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset + + +def sample_random_requests( + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, + dataset_path: str, +) -> List[Tuple[str, int, int]]: + + input_lens = np.random.randint( + max(int(input_len * range_ratio), 1), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + + if True: + # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path) and not os.path.isfile( + default_sharegpt_path + ): + download_sharegpt_dataset(default_sharegpt_path) + dataset_path = default_sharegpt_path + else: + dataset_path = ( + dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path + ) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + input_requests: List[Tuple[str, int, int]] = [] + for i in range(num_prompts): + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + prompt_len = len(prompt_token_ids) + + if prompt_len > input_lens[i]: + input_ids = prompt_token_ids[: input_lens[i]] + else: + ratio = (input_lens[i] + prompt_len - 1) // prompt_len + input_ids = (prompt_token_ids * ratio)[: input_lens[i]] + prompt = tokenizer.decode(input_ids) + input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + else: + # Sample token ids from random integers. This can cause some NaN issues. + offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) + input_requests = [] + for i in range(num_prompts): + prompt = tokenizer.decode( + [ + (offsets[i] + i + j) % tokenizer.vocab_size + for j in range(input_lens[i]) + ] + ) + input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + + print(f"#Input tokens: {np.sum(input_lens)}") + print(f"#Output tokens: {np.sum(output_lens)}") + return input_requests + + +async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
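+        # With rate r, exponential inter-arrival times (mean 1/r) make the request
+        # arrivals a Poisson process, as described in the --request-rate help text.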
+ await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + backend: str, +) -> Tuple[BenchmarkMetrics, List[int]]: + output_lens: List[int] = [] + retokenized_output_lens: List[int] = [] + total_input = 0 + completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + ttfts: List[float] = [] + e2e_latencies: List[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_len + output_lens.append(output_len) + retokenized_output_len = len( + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) + retokenized_output_lens.append(retokenized_output_len) + total_input += input_requests[i][1] + if output_len > 1: + tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + + e2e_latencies.append(outputs[i].latency) + + completed += 1 + else: + output_lens.append(0) + retokenized_output_lens.append(0) + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2, + ) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(output_lens), + total_output_retokenized=sum(retokenized_output_lens), + request_throughput=completed / dur_s, + input_throughput=total_input / dur_s, + output_throughput=sum(output_lens) / dur_s, + output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend + median_ttft_ms=np.median(ttfts or 0) * 1000, + std_ttft_ms=np.std(ttfts or 0) * 1000, + p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + mean_itl_ms=np.mean(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, + median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + ) + + return metrics, output_lens + + +async def benchmark( + backend: str, + api_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: List[Tuple[str, int, int]], + request_rate: float, + disable_tqdm: bool, + enable_multi: bool, + extra_request_body: Dict[str, Any], +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + + words = test_prompt.split() + + # 使用random.shuffle打乱单词列表 + random.shuffle(words) + + # 将打乱后的单词列表合并回文本 + test_prompt = ' '.join(words) + + + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + extra_request_body=extra_request_body, + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}" + ) + else: + print("Initial test run completed. 
Starting main benchmark run...") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + benchmark_start_time = time.perf_counter() + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate): + prompt, prompt_len, output_len = request + request_func_input = RequestFuncInput( + model=model_id, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + extra_request_body=extra_request_body, + ) + tasks.append( + asyncio.create_task( + request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + backend=backend, + ) + + print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Backend:", backend)) + print("{:<40} {:<10}".format("Traffic request rate:", request_rate)) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10}".format( + "Total generated tokens (retokenized):", metrics.total_output_retokenized + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Input token throughput (tok/s):", metrics.input_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) + print( + "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) + ) + print( + "{:<40} {:<10.2f}".format( + "Median E2E Latency (ms):", metrics.median_e2e_latency_ms + ) + ) + print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) + print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) + print( + "{s:{c}^{n}}".format(s="Time per Output Token (excl. 
1st token)", n=50, c="-") + ) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("=" * 50) + + if ( + metrics.median_ttft_ms is not None + and metrics.mean_itl_ms is not None + and metrics.output_throughput is not None + ): + result = { + "backend": args.backend, + "dataset_name": args.dataset_name, + "request_rate": request_rate, + "total_input": metrics.total_input, + "total_output": metrics.total_output, + "total_output_retokenized": metrics.total_output_retokenized, + "mean_e2e_latency": metrics.mean_e2e_latency_ms, + "median_e2e_latency": metrics.median_e2e_latency_ms, + "median_ttft": metrics.median_ttft_ms, + "median_itl": metrics.median_itl_ms, + "output_token_throughput": metrics.output_throughput, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + "benchmark_duration": benchmark_duration, + } + else: + print(f"Error running benchmark for request rate: {request_rate}") + print("-" * 30) + + # Determine output file name + if args.output_file: + output_file_name = args.output_file + else: + now = datetime.now().strftime("%m%d") + if args.dataset_name == "random": + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" + else: + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl" + + # Append results to a JSONL file + with open(output_file_name, "a") as file: + file.write(json.dumps(result) + "\n") + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + "mean_ttft_ms": metrics.mean_ttft_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + } + return result + + +def parse_request_rate_range(request_rate_range): + if len(request_rate_range.split(",")) == 3: + start, stop, step = map(int, request_rate_range.split(",")) + return list(range(start, stop, step)) + else: + 
return list(map(int, request_rate_range.split(","))) + + +def check_chat_template(model_path): + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + return "chat_template" in tokenizer.init_kwargs + except Exception as e: + print(f"Fail to load tokenizer config with error={e}") + return False + + +def run_benchmark(args_: argparse.Namespace): + global args + args = args_ + + set_ulimit() + random.seed(args.seed) + np.random.seed(args.seed) + + extra_request_body = {} + if args.extra_request_body: + extra_request_body = json.loads(args.extra_request_body) + + if args.port is None: + args.port = { + "sglang": 30000, + "lmdeploy": 23333, + "vllm": 8000, + "trt": 8000, + }.get(args.backend, 30000) + + api_url = ( + f"{args.base_url}/v1/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/completions" + ) + model_url = ( + f"{args.base_url}/v1/models" + if args.base_url + else f"http://{args.host}:{args.port}/v1/models" + ) + + if args.backend == "trt": + api_url = ( + f"{args.base_url}/v2/models/ensemble/generate_stream" + if args.base_url + else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream" + ) + if args.model is None: + print("Please provide a model using `--model` when using `trt` backend.") + sys.exit(1) + + if args.model is None: + try: + response = requests.get(model_url) + model_list = response.json().get("data", []) + args.model = model_list[0]["id"] if model_list else None + except Exception as e: + print(f"Failed to fetch model from {model_url}. Error: {e}") + print( + "Please specify the correct host and port using `--host` and `--port`." + ) + sys.exit(1) + + if args.model is None: + print("No model specified or found. Please provide a model using `--model`.") + sys.exit(1) + + if not check_chat_template(args.model): + print( + "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n" + "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n" + ) + + print(f"{args}\n") + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + tokenizer = get_tokenizer(tokenizer_id) + + if args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, + ) + elif args.dataset_name == "random": + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + + if args.multi: + request_rates = parse_request_rate_range(args.request_rate_range) + + for rate in request_rates: + asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + else: + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + disable_tqdm=args.disable_tqdm, + enable_multi=args.multi, + extra_request_body=extra_request_body, + ) + ) + + +# to avoid relying on SGLang's 
components +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Fail to set RLIMIT_NOFILE: {e}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", + help="Must specify a backend, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=["sharegpt", "random"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", type=str, default="", help="Path to the dataset." + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer. If not set, using the model conf.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length from the ShareGPT dataset.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range of sampled ratio of input/output length, " + "used only for random dataset.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.", + ) + parser.add_argument("--seed", type=int, default=0, help="Default is 0.") + parser.add_argument( + "--multi", + action="store_true", + help="Use request rate range rather than single value.", + ) + parser.add_argument( + "--request-rate-range", + type=str, + default="2,34,2", + help="Range of request rates in the format start,stop,step. Default is 2,34,2. 
It also accepts an explicit comma-separated list of request rates, as long as the list does not contain exactly three values.", + ) + parser.add_argument("--output-file", type=str, help="Output JSONL file name.") + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--disable-stream", + action="store_true", + help="Disable streaming mode.", + ) + parser.add_argument( + "--disable-ignore-eos", + action="store_true", + help="Disable ignoring EOS.", + ) + parser.add_argument( + "--extra-request-body", + metavar='{"key1": "value1", "key2": "value2"}', + type=str, + help="Append given JSON object to the request payload. You can use this to specify " + "additional generate params like sampling params.", + ) + + + parser.add_argument( + "--search-qps-start", + type=float, + default=0.0, + help="Lower bound of the request_rate for the binary search; the effective request rate is multiplied by dp_size.", + ) + + parser.add_argument( + "--search-qps-end", + type=int, + default=0, + help="Upper bound of the request_rate for the binary search; the effective request rate is multiplied by dp_size.", + ) + + parser.add_argument( + "--search-qps-thread", + type=int, + default=0, + help="Target mean TTFT threshold (ms) that the binary search tries to hit.", + ) + + + args = parser.parse_args() + + mean_ttft_ms = 0.0 + search_qps_thread = args.search_qps_thread + search_qps_start = args.search_qps_start + search_qps_end = args.search_qps_end + + + qps_result_dict = [] + while abs(mean_ttft_ms - search_qps_thread) >= 500 and search_qps_start < search_qps_end: # stop once the measured TTFT is within 0.5 s of the target + mid_qps = (search_qps_start + search_qps_end) / 2 + args.request_rate = mid_qps + result = run_benchmark(args) + mean_ttft_ms = float(result['mean_ttft_ms']) + + qps_result_dict.append({ + "request_rate": mid_qps, + "mean_ttft_ms": mean_ttft_ms + }) + + if mean_ttft_ms < search_qps_thread: + # TTFT grows with request_rate, so a TTFT below the target means the rate can be pushed higher + search_qps_start = mid_qps + else: + search_qps_end = mid_qps + + + print(f"request_rate\t\tmean_ttft_ms") + for item in qps_result_dict: + print(f"{item['request_rate']}\t\t{item['mean_ttft_ms']}") \ No newline at end of file diff --git a/python/sglang/srt/managers/controller_flex.py b/python/sglang/srt/managers/controller_flex.py new file mode 100644 index 0000000000..ba9a6ccaa4 --- /dev/null +++ b/python/sglang/srt/managers/controller_flex.py @@ -0,0 +1,533 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +A controller that manages multiple data parallel workers. +Each data parallel worker can manage multiple tensor parallel workers. 
+""" +import dataclasses +import logging +import multiprocessing +import multiprocessing.shared_memory +import os +import random +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from enum import Enum, auto + +import numpy as np +import torch +import zmq + + +def _key_match(key0, key1): + i = 0 + for k0, k1 in zip(key0, key1): + if k0 != k1: + break + i += 1 + return i + + +def get_match_len(node, key, match_length: int) -> int: + if len(key) == 0: + return match_length + + if key[0] in node.children.keys(): + child = node.children[key[0]] + prefix_len = _key_match(child.key, key) + match_length += prefix_len + if prefix_len < len(child.key): + return match_length + else: + return get_match_len(child, key[prefix_len:], match_length) + else: + return match_length + + +import threading +import time + +from sglang.srt.managers.controller_single import ( + start_controller_process as start_controller_process_single, +) +from sglang.srt.managers.io_struct import ( + AbortReq, + ControllerInfo, + FlushCacheReq, + TokenizedGenerateReqInput, +) +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import kill_parent_process +from sglang.utils import get_cache_info, get_exception_traceback + +logger = logging.getLogger(__name__) + + +class LoadBalanceMethod(Enum): + """Load balance method.""" + + ROUND_ROBIN = auto() + SHORTEST_QUEUE = auto() + RESOURCES_AWARE = auto() + POWER_OF_2_CHOICE = auto() + PRE_RADIX = auto() + + @classmethod + def from_str(cls, method: str): + method = method.upper() + try: + return cls[method] + except KeyError as exc: + raise ValueError(f"Invalid load balance method: {method}") from exc + + +@dataclasses.dataclass +class WorkerHandle: + """Store the handle of a data parallel worker.""" + + proc: multiprocessing.Process + queue: multiprocessing.Queue + + +# class FlexScheduler: +# """A scheduler which dispatch """ + + +class ControllerMultiFlex: + """A controller that manages multiple data parallel workers.""" + + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + model_overide_args, + ): + # Parse args + self.server_args = server_args + self.port_args = port_args + self.model_overide_args = model_overide_args + self.load_balance_method = LoadBalanceMethod.from_str( + server_args.load_balance_method + ) + + # Init communication + context = zmq.Context() + self.recv_from_tokenizer = context.socket(zmq.PULL) + self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}") + + self.recv_from_tree_cache = context.socket(zmq.PULL) + self.recv_from_tree_cache.setsockopt(zmq.RCVHWM, 1000) + self.recv_from_tree_cache.bind(f"tcp://127.0.0.1:41935") + + self.pre_radix = server_args.load_balance_method == "pre_radix" + self.dp_size = server_args.dp_size + + # Dispatch method + self.round_robin_counter = 0 + dispatch_lookup = { + LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, + LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, + LoadBalanceMethod.RESOURCES_AWARE: self.resources_aware_scheduler, + LoadBalanceMethod.POWER_OF_2_CHOICE: self.power_of_2_choice, + LoadBalanceMethod.PRE_RADIX: self.pre_radix_scheduler, + } + self.dispatching = dispatch_lookup[self.load_balance_method] + + self.newest_tree_cache = {} + + # Start data parallel workers + self.workers = [] + self.controller_info = ControllerInfo(server_args, model_overide_args) + for i in range(server_args.dp_size): + self.start_dp_worker(i) + + self.scheduler_time = 0 + + self.cnt = 0 + + if self.pre_radix: + 
self.recv_tree_cache_lock = threading.Lock() + threading.Thread(target=self.loop_for_recv_tree_cache).start() + + def start_dp_worker(self, dp_worker_id: int): + tp_size = self.server_args.tp_size + + pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe( + duplex=False + ) + + gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size)) + queue = multiprocessing.Queue() + proc = multiprocessing.Process( + target=start_controller_process_single, + args=( + self.server_args, + self.port_args, + pipe_controller_writer, + self.model_overide_args, + True, + gpu_ids, + dp_worker_id, + queue, + self.controller_info, + ), + ) + proc.start() + + controller_init_state = pipe_controller_reader.recv() + if controller_init_state != "init ok": + raise RuntimeError( + f"Initialization failed. controller_init_state: {controller_init_state}" + ) + self.workers.append( + WorkerHandle( + proc=proc, + queue=queue, + ) + ) + + def compute_prefix_length(self, gpu_id, radix_cache, input_ids): + return gpu_id, get_match_len(radix_cache.root_node, input_ids, 0) + + def pre_radix_scheduler(self, input_requests): + if len(input_requests) == 0: + return + + # available_mem = [k.value for k in self.controller_info.available_kv_cache] + # num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs] + # num_reqs_running = [k.value for k in self.controller_info.running_reqs] + + # all_waitting = False + # if min(num_reqs_waiting) > 0: + # # 最小值都大于0,全部waiting + # all_waitting = True + # else: + # # 最小值都是0, 则全部waiting + # all_waitting = False + # # 选出不waiting + # no_waiting = [1 if waiting == 0 else 0 for waiting in num_reqs_waiting] + + # num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs] + + for r in input_requests: + prefix_lens = [0] * self.dp_size + + with self.recv_tree_cache_lock: + for gpu_id, radix_cache in self.newest_tree_cache.items(): + # t_1 = time.time() + pre_len = get_match_len(radix_cache.root_node, r.input_ids, 0) + # t_2 = time.time() + prefix_lens[gpu_id] = pre_len + + # with ThreadPoolExecutor() as executor: + # futures = [] + # for gpu_id, radix_cache in self.newest_tree_cache.items(): + # future = executor.submit( + # self.compute_prefix_length, + # gpu_id, + # radix_cache, + # r.input_ids, + # ) + # futures.append(future) + + # for future in futures: + # gpu_id, pre_len = future.result() + # prefix_lens[gpu_id] = pre_len + + # t4 = time.time() + # with open("match.log", "a+") as f: + # f.write(f"[rid={r.rid[:5]}]{prefix_lens}\n") + + # t7 = time.time() + max_len = max(prefix_lens) + max_len_indices = [i for i, x in enumerate(prefix_lens) if x == max_len] + # t8 = time.time() + + # logger.info(f"find max idx = {t8 - t7}") + + if len(max_len_indices) == 1: + # t9 = time.time() + selected_worker_index = max_len_indices[0] + self.workers[selected_worker_index].queue.put(r) + # t10 = time.time() + # logger.info(f"len one = {t10 - t9}") + # t5 = time.time() + # logger.info(f"if time = {t5 - t4}") + else: + self.resources_aware_scheduler([r]) + # t11 = time.time() + # if all_waitting: + # # 全部waiting,选最小的 + + # ratio = [ + # run / wait + # for run, wait in zip(num_reqs_running, num_reqs_waiting) + # ] + + # # run越大 认为后续释放的可能性越多,wait越少,说明后续计算能力更强 + # min_value = max(ratio) + # # 找到所有最小值的索引 + # min_indices = [i for i, x in enumerate(ratio) if x == min_value] + # # 从这些索引中随机选择一个 + # index = random.choice(min_indices) + # # 从waitting最小的找到available最大的 + # # index = max(min_indices, key=lambda i: available_mem[i]) + # # index = min(min_indices, 
key=lambda i: num_reqs_running[i]) + # self.workers[index].queue.put(r) + # num_reqs_waiting[index] += 1 + # available_mem[index] -= len(r.input_ids) + + # else: + # # Pick a non-waiting worker with the most available memory + # # Multiply the no_waiting mask with available_mem and take the max + + # filter_result = [a * b for a, b in zip(no_waiting, available_mem)] + # index = filter_result.index(max(filter_result)) + # self.workers[index].queue.put(r) + + # # num_reqs_running[index] += 1 + # available_mem[index] -= len(r.input_ids) + # t12 = time.time() + # logger.info(f"len two = {t12 - t11}") + # t5 = time.time() + # logger.info(f"else time = {t5 - t4}") + # t6 = time.time() + # logger.info(f"real dispatch time = {t6 - t8}") + + def resources_aware_scheduler(self, input_requests): + if len(input_requests) == 0: + return + # remained_token = [k.value for k in self.controller_info.waiting_prefill_compute] + available_mem = [k.value for k in self.controller_info.available_kv_cache] + num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs] + num_reqs_running = [k.value for k in self.controller_info.running_reqs] + # with open('three_list.txt', 'a') as file: # 'a' mode appends to the end of the file + # print(f"available_mem={available_mem}\nnum_reqs_waiting={num_reqs_waiting}\nnum_reqs_running={num_reqs_running}\n") + + # ava_resource = available_mem.copy() + # =======================method2======================= + # # Treat available memory plus the tokens already queued as the usable resource + # for i in range(len(self.workers)): + # q = self.workers[i].queue + # qsize = q.qsize() + # for _ in range(qsize): + # req = q.get() + # ava_resource[i] += len(req.input_ids) + # q.put(req) # put the element back into its original queue + + # # Dispatch to the worker with the largest estimated resource + # for r in input_requests: + # index = ava_resource.index(max(ava_resource)) + # self.workers[index].queue.put(r) + # ava_resource[index] -= len(r.input_ids) + + # =======================method2======================= + + # =======================method1======================= + + # Check whether every worker already has waiting requests + all_waitting = False + if min(num_reqs_waiting) > 0: + # The minimum is above zero, so every worker is waiting + all_waitting = True + else: + # The minimum is zero, so at least one worker has an empty waiting queue + all_waitting = False + # Mark the workers that have no waiting requests + no_waiting = [1 if waiting == 0 else 0 for waiting in num_reqs_waiting] + for r in input_requests: + # t1 = time.time() + if all_waitting: + # Every worker is waiting; pick the least loaded one + + ratio = [ + run / wait for run, wait in zip(num_reqs_running, num_reqs_waiting) + ] + + # A larger running count means more capacity is likely to be freed soon, and fewer waiting requests means more spare compute + min_value = max(ratio) + # Find all indices that attain this best ratio + min_indices = [i for i, x in enumerate(ratio) if x == min_value] + # Randomly pick one of those indices + index = random.choice(min_indices) + # Alternative: among the least-waiting workers, pick the one with the most available memory + # index = max(min_indices, key=lambda i: available_mem[i]) + # index = min(min_indices, key=lambda i: num_reqs_running[i]) + self.workers[index].queue.put(r) + num_reqs_waiting[index] += 1 + available_mem[index] -= len(r.input_ids) + else: + # Pick a non-waiting worker with the most available memory + # Multiply the no_waiting mask with available_mem and take the max + + filter_result = [a * b for a, b in zip(no_waiting, available_mem)] + index = filter_result.index(max(filter_result)) + self.workers[index].queue.put(r) + + # num_reqs_running[index] += 1 + available_mem[index] -= len(r.input_ids) + # t2 = time.time() + # logger.info(f"real dispatch time = {t2 - t1}") + + # =======================method1======================= + + def power_of_2_choice(self, input_requests): + if len(input_requests) == 0: + return + num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs] + num_reqs_running = [k.value for k in self.controller_info.running_reqs] + available_mem = [k.value for k in 
+    def power_of_2_choice(self, input_requests):
+        if len(input_requests) == 0:
+            return
+        num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs]
+        num_reqs_running = [k.value for k in self.controller_info.running_reqs]
+        available_mem = [k.value for k in self.controller_info.available_kv_cache]
+
+        instances_len = len(self.workers)
+
+        # Compare the load metrics of two workers and return the better one.
+        def compare_metrics(ins1, ins2):
+            if num_reqs_waiting[ins1] != num_reqs_waiting[ins2]:
+                return ins1 if num_reqs_waiting[ins1] < num_reqs_waiting[ins2] else ins2
+            if num_reqs_running[ins1] != num_reqs_running[ins2]:
+                return ins1 if num_reqs_running[ins1] < num_reqs_running[ins2] else ins2
+            if available_mem[ins1] != available_mem[ins2]:
+                return ins1 if available_mem[ins1] > available_mem[ins2] else ins2
+            return ins1
+
+        for r in input_requests:
+            # Randomly sample two workers and keep the better one.
+            ins1, ins2 = random.sample(range(0, instances_len), 2)
+            ins_end = compare_metrics(ins1, ins2)
+            self.workers[ins_end].queue.put(r)
+            # available_mem[ins_end] -= len(r.input_ids)
+            # num_reqs_running[ins_end] += 1
+            # num_reqs_waiting[ins_end] += 1
+
+    def round_robin_scheduler(self, input_requests):
+        for r in input_requests:
+            self.workers[self.round_robin_counter].queue.put(r)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
+
+    def shortest_queue_scheduler(self, input_requests):
+        for r in input_requests:
+            queue_sizes = [worker.queue.qsize() for worker in self.workers]
+            wid = np.argmin(queue_sizes)
+            self.workers[wid].queue.put(r)
+
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+
+            if len(recv_reqs) != 0:
+                # logger.info(f"len requests=[{len(recv_reqs)}]")
+                t1 = time.time()
+
+                # if self.pre_radix:
+                #     self.recv_tree_cache()
+
+                self.dispatching(recv_reqs)
+                t2 = time.time()
+                logger.info(f"scheduler time = {t2 - t1}")
+
+    def loop_for_recv_tree_cache(self):
+        while True:
+            self.recv_tree_cache()
+
+    def recv_tree_cache(self):
+        flag = False
+        while True:
+            try:
+                recv_radix_cache = self.recv_from_tree_cache.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
+            gpu_id = recv_radix_cache.gpu_id
+
+            if (
+                gpu_id not in self.newest_tree_cache
+                or recv_radix_cache.time > self.newest_tree_cache[gpu_id].time
+            ):
+                with self.recv_tree_cache_lock:
+                    if gpu_id in self.newest_tree_cache:
+                        del self.newest_tree_cache[gpu_id]
+                    self.newest_tree_cache[gpu_id] = recv_radix_cache
+                flag = True
+
+            del recv_radix_cache
+        # Optionally log how many tree caches are currently tracked.
+        if flag:
+            # logger.info(f"latest_cache={len(self.newest_tree_cache)}")
+            pass
+        torch.cuda.empty_cache()  # Release GPU memory that is no longer referenced.
+
+    def recv_requests(self):
+        recv_reqs = []
+
+        while True:
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
+            if isinstance(recv_req, FlushCacheReq):
+                # TODO(lsyin): apply more specific flushCacheReq
+                for worker in self.workers:
+                    worker.queue.put(recv_req)
+            elif isinstance(recv_req, AbortReq):
+                in_queue = False
+                for i, req in enumerate(recv_reqs):
+                    if req.rid == recv_req.rid:
+                        recv_reqs[i] = recv_req
+                        in_queue = True
+                        break
+                if not in_queue:
+                    # Send abort req to all TP groups
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                recv_reqs.append(recv_req)
+            else:
+                logger.error(f"Invalid object: {recv_req}")
+
+        return recv_reqs
+
+
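recv_from_tree_cache is created in ControllerMultiFlex.__init__, outside this hunk. Since RadixCache below PUSHes snapshots to tcp://127.0.0.1:41935, the receiving side is presumably a bound PULL socket along these lines (a sketch under that assumption; the socket option and port are inferred, not shown in the patch):

# --- illustrative sketch, not part of the patch ---
import zmq

context = zmq.Context(2)
recv_from_tree_cache = context.socket(zmq.PULL)
recv_from_tree_cache.setsockopt(zmq.RCVHWM, 1000)   # mirror of the sender's SNDHWM; assumed
recv_from_tree_cache.bind("tcp://127.0.0.1:41935")  # port taken from the PUSH side below
# loop_for_recv_tree_cache() then drains it with recv_pyobj(zmq.NOBLOCK).
# --- end sketch ---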
+def start_controller_process(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer,
+    model_overide_args: dict,
+):
+    """Start a controller process."""
+
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        controller = ControllerMultiFlex(server_args, port_args, model_overide_args)
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+
+    pipe_writer.send("init ok")
+
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerMultiFlex:\n" + get_exception_traceback())
+    finally:
+        for w in controller.workers:
+            os.kill(w.proc.pid, 9)
+        kill_parent_process()
diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py
index 415325b131..53b52adcfc 100644
--- a/python/sglang/srt/managers/controller_single.py
+++ b/python/sglang/srt/managers/controller_single.py
@@ -20,8 +20,10 @@
 import os
 from typing import List
 
+import numpy as np
 import zmq
 
+from sglang.srt.managers.io_struct import ControllerInfo
 from sglang.srt.managers.tp_worker import (
     ModelTpServer,
     broadcast_recv_input,
@@ -29,6 +31,7 @@
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
+from sglang.srt.managers.io_struct import ControllerInfo
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -46,12 +49,15 @@ def __init__(
         is_data_parallel_worker: bool,
         dp_worker_id: int,
         mp_queue: multiprocessing.Queue,
+        controller_info: ControllerInfo = None,
     ):
         # Parse args
         self.tp_size = server_args.tp_size
         self.is_dp_worker = is_data_parallel_worker
         self.dp_worker_id = dp_worker_id
         self.mp_queue = mp_queue
+        # Needed by multi-flex inference
+        self.controller_info = controller_info
 
         # Init communication
         context = zmq.Context(2)
@@ -87,6 +93,8 @@ def __init__(
             server_args,
             port_args.nccl_ports[dp_worker_id],
             model_overide_args,
+            controller_info,
+            dp_worker_id,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
 
@@ -132,6 +140,7 @@ def start_controller_process(
     gpu_ids: List[int] = None,
    dp_worker_id: int = None,
    queue: multiprocessing.connection.Connection = None,
+    controller_info: ControllerInfo = None,
 ):
     """Start a controller process."""
 
@@ -155,6 +164,7 @@ def start_controller_process(
             is_data_parallel_worker,
             dp_worker_id,
             queue,
+            controller_info,
         )
     except Exception:
         pipe_writer.send(get_exception_traceback())
""" +import multiprocessing import uuid +import multiprocessing from dataclasses import dataclass +from multiprocessing import Value from typing import Dict, List, Optional, Union +from multiprocessing import Value +import numpy as np import torch +import numpy as np from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling_params import SamplingParams +from sglang.utils import get_cache_info @dataclass @@ -263,3 +270,31 @@ class AbortReq: @dataclass class DetokenizeReqInput: input_ids: List[int] + + +class ControllerInfo: + def __init__(self, server_args, model_overide_args): + self.available_kv_cache = [] + self.waiting_prefill_compute = [] + self.running_reqs = [] + self.waiting_reqs = [] + self.swap_in_queue = [] + self.lock = multiprocessing.Lock() + for i in range(server_args.dp_size): + self.available_kv_cache.append(Value("i", 0)) + self.waiting_prefill_compute.append(Value("i", 0)) + self.running_reqs.append(Value("i", 0)) + self.waiting_reqs.append(Value("i", 0)) + self.swap_in_queue.append(multiprocessing.Queue()) + self.swap_out_queue = multiprocessing.Queue() + cache_shape = get_cache_info(server_args, model_overide_args) + + # TODO: Make it editable by user @kavioyu + cpu_cache_num = 10240 + self.cache_shape = (10240,) + cache_shape + dtype_size = 2 # support float16 or bfloat16 + cache_size = np.product(self.cache_shape) * dtype_size + + shm = multiprocessing.shared_memory.SharedMemory(create=True, size=cache_size) + self.cpu_kv_cache = shm.name + del shm diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a461fa1812..d72604ef26 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -432,6 +432,14 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) + # num = 0 + # for req in reqs: + # num += len(req.origin_input_ids) + # if num != extend_num_tokens: + # print("*" * 100) + # print(extend_num_tokens) + # for req in reqs: + # print(len(req.origin_input_ids)) seq_lens = [] # Allocate memory diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 4c757737ec..8a75eed871 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -21,6 +21,7 @@ import pickle import time import warnings +from multiprocessing import shared_memory from typing import Any, List, Optional, Union import torch @@ -36,6 +37,7 @@ AbortReq, BatchEmbeddingOut, BatchTokenIDOut, + ControllerInfo, FlushCacheReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -76,12 +78,17 @@ def __init__( server_args: ServerArgs, nccl_port: int, model_overide_args: dict, + controller_info: Optional[ControllerInfo] = None, + dp_worker_id: Optional[int] = None, ): + suppress_other_loggers() # Copy arguments self.gpu_id = gpu_id self.tp_rank = tp_rank + + self.dp_rank = dp_worker_id self.tp_size = server_args.tp_size self.dp_size = server_args.dp_size self.schedule_policy = server_args.schedule_policy @@ -107,6 +114,20 @@ def __init__( nccl_port=nccl_port, server_args=server_args, ) + + # Flex DP inference + if controller_info: + self.controller_info = controller_info + shm = shared_memory.SharedMemory(self.controller_info.cpu_kv_cache) + self.swap_cache = torch.frombuffer( + buffer=shm.buf, dtype=self.model_runner.dtype + 
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index a461fa1812..d72604ef26 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -432,6 +432,14 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
+        # num = 0
+        # for req in reqs:
+        #     num += len(req.origin_input_ids)
+        # if num != extend_num_tokens:
+        #     print("*" * 100)
+        #     print(extend_num_tokens)
+        #     for req in reqs:
+        #         print(len(req.origin_input_ids))
         seq_lens = []
 
         # Allocate memory
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 4c757737ec..8a75eed871 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -21,6 +21,7 @@
 import pickle
 import time
 import warnings
+from multiprocessing import shared_memory
 from typing import Any, List, Optional, Union
 
 import torch
@@ -36,6 +37,7 @@
     AbortReq,
     BatchEmbeddingOut,
     BatchTokenIDOut,
+    ControllerInfo,
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -76,12 +78,17 @@ def __init__(
         server_args: ServerArgs,
         nccl_port: int,
         model_overide_args: dict,
+        controller_info: Optional[ControllerInfo] = None,
+        dp_worker_id: Optional[int] = None,
     ):
+        suppress_other_loggers()
         # Copy arguments
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
+
+        self.dp_rank = dp_worker_id
         self.tp_size = server_args.tp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
@@ -107,6 +114,20 @@ def __init__(
             nccl_port=nccl_port,
             server_args=server_args,
         )
+
+        # Flex DP inference
+        if controller_info:
+            self.controller_info = controller_info
+            shm = shared_memory.SharedMemory(self.controller_info.cpu_kv_cache)
+            self.swap_cache = torch.frombuffer(
+                buffer=shm.buf, dtype=self.model_runner.dtype
+            ).reshape(self.controller_info.cache_shape)
+            self.controller_info.available_kv_cache[self.dp_rank].value = (
+                self.model_runner.token_to_kv_pool.available_size()
+            )
+        else:
+            self.controller_info = None
+
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
@@ -165,6 +186,8 @@ def __init__(
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             disable=server_args.disable_radix_cache,
+            gpu_id=gpu_id,
+            pre_radix=(server_args.load_balance_method == "pre_radix"),
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
@@ -176,6 +199,7 @@ def __init__(
         self.running_batch: ScheduleBatch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
+        self.forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
@@ -252,8 +276,9 @@ def forward_step(self):
                     self.forward_decode_batch(self.running_batch)
 
                     # Print stats
-                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                    if self.tp_rank == 0 and self.decode_forward_ct % 10 == 0:
                         self.print_decode_stats()
+                        pass
 
                     if self.running_batch.is_empty():
                         self.running_batch = None
@@ -281,6 +306,13 @@ def print_decode_stats(self):
             f"#queue-req: {len(self.waiting_queue)}"
         )
 
+        with open(
+            f"token_usage_gpu_{self.gpu_id}.log", mode="a+", encoding="utf-8"
+        ) as f:
+            f.write(
+                f"{self.gpu_id}\t\t{num_used / self.max_total_num_tokens:.2f}\t\t{len(self.waiting_queue)}\n"
+            )
+
     def check_memory(self):
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
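The per-GPU log written in print_decode_stats above is tab-separated: gpu_id, token-usage ratio, and waiting-queue length. A small helper for reading it back could look like this (hypothetical, not part of the patch; the file name follows the pattern used above):

# --- illustrative sketch, not part of the patch ---
def read_token_usage(path="token_usage_gpu_0.log"):
    """Return (gpu_id, token_usage, queue_len) tuples from the per-GPU log."""
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            gpu_id, usage, queue_len = line.split()
            rows.append((int(gpu_id), float(usage), int(queue_len)))
    return rows
# --- end sketch ---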
" f"#new-seq: {len(can_run_list)}, " @@ -444,6 +479,27 @@ def forward_prefill_batch(self, batch: ScheduleBatch): self.model_config.vocab_size, self.int_token_logit_bias ) + if self.controller_info: + num = 0 + for req in batch.reqs: + num += len(req.origin_input_ids) + with self.controller_info.lock: + self.controller_info.waiting_prefill_compute[self.dp_rank].value -= num + self.controller_info.available_kv_cache[self.dp_rank].value = ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + ) + self.controller_info.running_reqs[ + self.dp_rank + ].value = batch.batch_size() + ( + self.running_batch.batch_size() + if self.running_batch is not None + else 0 + ) + + self.controller_info.waiting_reqs[self.dp_rank].value = len( + self.waiting_queue + ) if self.model_runner.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: @@ -583,17 +639,26 @@ def forward_decode_batch(self, batch: ScheduleBatch): self.new_token_ratio = new_token_ratio logger.info( - "decode out of memory happened, " + f"[gpu{self.gpu_id}]decode out of memory happened, " f"#retracted_reqs: {len(retracted_reqs)}, " f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" ) + + num = 0 + for req in retracted_reqs: + num += len(req.fill_ids) + + if self.controller_info is not None: + with self.controller_info.lock: + self.controller_info.waiting_prefill_compute[ + self.dp_rank + ].value += num self.waiting_queue.extend(retracted_reqs) else: self.new_token_ratio = max( self.new_token_ratio - self.new_token_ratio_decay, self.min_new_token_ratio, ) - if not self.disable_regex_jump_forward: # Check for jump-forward jump_forward_reqs = batch.check_for_jump_forward(self.model_runner) @@ -605,6 +670,20 @@ def forward_decode_batch(self, batch: ScheduleBatch): self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) batch.prepare_for_decode() + if self.controller_info is not None and self.decode_forward_ct % 10 == 0: + with self.controller_info.lock: + self.controller_info.available_kv_cache[self.dp_rank].value = ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + ) + self.controller_info.running_reqs[self.dp_rank].value = ( + batch.batch_size() + ) + + self.controller_info.waiting_reqs[self.dp_rank].value = len( + self.waiting_queue + ) + # Forward and sample the next tokens output = self.model_runner.forward(batch, ForwardMode.DECODE) next_token_ids = batch.sample(output.next_token_logits) @@ -636,6 +715,12 @@ def forward_decode_batch(self, batch: ScheduleBatch): self.handle_finished_requests(batch) + def swap_in_decode_request(self, req: Req): + pass + + def swap_out_decode_request(self, req: Req): + pass + def handle_finished_requests(self, batch: ScheduleBatch): output_rids = [] output_meta_info = [] diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index a1c685405a..a089ee5285 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -55,20 +55,87 @@ def _key_match(key0: List, key1: List): return i +import threading +from copy import deepcopy +from dataclasses import dataclass + +import zmq + + +@dataclass +class RadixCacheSend: + gpu_id: int + root_node: TreeNode + time: time + + class RadixCache(BasePrefixCache): def __init__( self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool, disable: bool = False, + gpu_id: int = 0, + pre_radix: bool = False, ): self.req_to_token_pool = req_to_token_pool 
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index a1c685405a..a089ee5285 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -55,20 +55,87 @@ def _key_match(key0: List, key1: List):
     return i
 
 
+import threading
+from copy import deepcopy
+from dataclasses import dataclass
+
+import zmq
+
+
+@dataclass
+class RadixCacheSend:
+    gpu_id: int
+    root_node: TreeNode
+    time: float
+
+
 class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool: BaseTokenToKVPool,
         disable: bool = False,
+        gpu_id: int = 0,
+        pre_radix: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
+        self.pre_radix = pre_radix
+        self.gpu_id = gpu_id
+        self.send_cnt = 0
+        self.change_cnt = 0
+        if pre_radix:
+            context = zmq.Context()
+            self.send_radix_tree = context.socket(zmq.PUSH)
+            self.send_radix_tree.setsockopt(zmq.SNDHWM, 1000)
+            self.send_radix_tree.connect("tcp://127.0.0.1:41935")
+
+            self.change_cnt_lock = threading.Lock()
+
+            threading.Thread(target=self.loop_for_send_tree_cache).start()
+
+        else:
+            print(f"dp[{self.gpu_id}] will not use pre_radix.")
+
         self.reset()
 
     ##### Public API #####
 
+    def loop_for_send_tree_cache(self):
+        while True:
+            with self.change_cnt_lock:
+                if self.change_cnt != 0:
+                    self.change_cnt -= 1
+                    self.send_prefix_tree()
+            time.sleep(0.5)
+
+    def send_prefix_tree(self):
+        # t1 = time.time()
+        if self.pre_radix:
+            self.send_cnt += 1
+            try:
+                try:
+                    node = deepcopy(self.root_node)
+                except Exception as e:
+                    return
+                self.send_radix_tree.send_pyobj(
+                    RadixCacheSend(
+                        gpu_id=self.gpu_id, root_node=node, time=time.time()
+                    ),
+                    zmq.NOBLOCK,
+                )
+                # if self.send_cnt % 10 == 0:
+                #     print(f"[{self.gpu_id}] has sent [{self.send_cnt}] caches")
+                del node
+                # torch.cuda.empty_cache()
+            except zmq.Again as e:
+                print(
+                    "Radix cache send queue is full; dropping this radix tree snapshot."
+                )
        # t2 = time.time()
        # print(f"send radix time = {t2 - t1}")
+
     def reset(self):
         self.root_node = TreeNode()
         self.root_node.key = []
@@ -76,6 +143,8 @@ def reset(self):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
+        # self.send_prefix_tree()
+
     def match_prefix(self, key: List, **kwargs):
         if self.disable:
             return [], self.root_node
@@ -95,7 +164,12 @@ def insert(self, key: List, value=None):
         if value is None:
             value = [x for x in key]
-        return self._insert_helper(self.root_node, key, value)
+        res = self._insert_helper(self.root_node, key, value)
+
+        if self.pre_radix:
+            with self.change_cnt_lock:
+                self.change_cnt += 1
+        return res
 
     def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
@@ -118,6 +192,8 @@ def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)
 
+        # self.send_prefix_tree()
+
     def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it is unfinished."""
         if self.disable:
@@ -146,6 +222,8 @@ def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         req.prefix_indices = new_indices
         req.last_node = new_last_node
 
+        # self.send_prefix_tree()
+
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
         print(f"#tokens: {self.total_size()}")
@@ -175,6 +253,10 @@ def evict(self, num_tokens: int, evict_callback: Callable):
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)
 
+        # self.send_prefix_tree()
+        if self.pre_radix:
+            with self.change_cnt_lock:
+                self.change_cnt += 1
 
     def inc_lock_ref(self, node: TreeNode):
         if self.disable:
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 8b67663357..2ba4a245a0 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -44,6 +44,9 @@
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller_flex import (
+    start_controller_process as start_controller_process_flex,
+)
 from sglang.srt.managers.controller_multi import (
     start_controller_process as start_controller_process_multi,
 )
@@ -294,7 +297,7 @@ def launch_server(
     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
-        start_process = start_controller_process_multi
+        start_process = start_controller_process_flex
     proc_controller = mp.Process(
         target=start_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 474c80b256..0ecdd2ba9d 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -348,6 +348,9 @@ def add_cli_args(parser: argparse.ArgumentParser):
             choices=[
                 "round_robin",
                 "shortest_queue",
+                "resources_aware",
+                "power_of_2_choice",
+                "pre_radix",
             ],
         )
 
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index c880d259d5..a0035b29ec 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -15,6 +15,11 @@
 import numpy as np
 import requests
 
+from sglang.srt.model_config import ModelConfig, AttentionArch
+from sglang.srt.server_args import ServerArgs
+
+from sglang.srt.model_config import AttentionArch, ModelConfig
+from sglang.srt.server_args import ServerArgs
 
 logger = logging.getLogger(__name__)
 
@@ -228,6 +233,22 @@ def find_printable_text(text: str):
         return text[: text.rfind(" ") + 1]
 
 
+def get_cache_info(server_args: ServerArgs, model_overide_args):
+    """Extract the KV cache information from ServerArgs."""
+    model_config = ModelConfig(
+        server_args.model_path,
+        server_args.trust_remote_code,
+        context_length=server_args.context_length,
+        model_overide_args=model_overide_args,
+    )
+    assert (
+        model_config.attention_arch == AttentionArch.MHA
+    ), "FlexController only supports MHA currently"
+    tp_size = server_args.tp_size
+    shape = (model_config.get_num_kv_heads(tp_size), model_config.head_dim)
+    return shape
+
+
 def graceful_registry(sub_module_name: str):
     def graceful_shutdown(signum, frame):
         logger.info(
@@ -259,4 +280,4 @@ def __getattr__(self, name: str):
 
     def __call__(self, *args, **kwargs):
         module = self._load()
-        return module(*args, **kwargs)
+        return module(*args, **kwargs)
\ No newline at end of file
diff --git a/server.sh b/server.sh
new file mode 100644
index 0000000000..5545bd5eae
--- /dev/null
+++ b/server.sh
@@ -0,0 +1 @@
+python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-1.8B-Chat --host 0.0.0.0 --port 8080 --mem-fraction-static 0.6 --chunked-prefill-size 512
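With the new choices registered in server_args.py, a data-parallel server can opt into one of the added policies. The flag is assumed to be the existing --load-balance-method option that this choices list extends; the same can presumably be done programmatically, since Runtime forwards its keyword arguments to ServerArgs (a sketch, not verified against this branch):

# --- illustrative example, not part of the patch ---
#   python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-1.8B-Chat \
#       --dp-size 2 --load-balance-method resources_aware
import sglang as sgl

runtime = sgl.Runtime(
    model_path="Qwen/Qwen1.5-1.8B-Chat",
    dp_size=2,
    load_balance_method="pre_radix",  # or "resources_aware" / "power_of_2_choice"
)
runtime.shutdown()
# --- end example ---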