Commit

add radix cache scheduler
yukavio committed Sep 24, 2024
1 parent a2b9a62 commit e034fcf
Showing 15 changed files with 3,800 additions and 68 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -183,3 +183,8 @@ work_dirs/
*.csv

!logo.png
test.py
a.txt
b.txt
*.txt
launch_server.py.lprof
164 changes: 122 additions & 42 deletions python/sglang/bench_serving.py
@@ -311,15 +311,25 @@ def download_sharegpt_dataset(path):
raise Exception(f"Failed to download dataset: {e}")


import pickle


def sample_sharegpt_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
max_seqlen: int,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

cache_path = f"./input_cache_v2_{num_requests}"
    # Try to load cached input_requests
if os.path.isfile(cache_path):
with open(cache_path, "rb") as f:
input_requests = pickle.load(f)
print("Loaded input_requests from cache.")
return input_requests
prompts = []
prompt_lens = []
response_lens = []
# Download sharegpt if necessary
if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
download_sharegpt_dataset(default_sharegpt_path)
@@ -331,42 +341,52 @@ def sample_sharegpt_requests(

# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]

# Shuffle the dataset.
random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break

# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = (
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
datasets = json.load(f)
for data in datasets:
if len(data["conversations"]) >= 2:
prompt = data["conversations"][0]["value"]
res = data["conversations"][1]["value"]
prompt_token_ids = tokenizer(prompt).input_ids
completion_token_ids = tokenizer(res).input_ids

if (
len(prompt_token_ids) + len(completion_token_ids) < max_seqlen
and len(prompt_token_ids) > 0
and len(completion_token_ids) > 0
):
prompts.append(prompt)
prompt_lens.append(len(prompt_token_ids))
response_lens.append(len(completion_token_ids))
if len(prompts) > num_requests:
break

sampled_ids = [random.randint(0, len(prompts) - 1) for _ in range(num_requests)]
sampled_prompts = [prompts[idx] for idx in sampled_ids]
sampled_prompts_lens = [prompt_lens[idx] for idx in sampled_ids]
sampled_response_lens = [response_lens[idx] for idx in sampled_ids]

for i, (prompt_len, gen_len) in enumerate(
zip(sampled_prompts_lens, sampled_response_lens)
):
total = prompt_len + gen_len
if total > max_seqlen:
print(f"truncating long prompt+gen_len {prompt_len=} {gen_len=}")
gen_len = max_seqlen - prompt_len
sampled_response_lens[i] = gen_len
input_requests = list(
zip(sampled_prompts, sampled_prompts_lens, sampled_response_lens)
)


with open(cache_path, "wb") as f:
pickle.dump(input_requests, f)
print(f"Saved input_requests_{num_requests} to cache.")
print(f"#Input tokens: {np.sum(sampled_prompts_lens)}")
print(f"#Output tokens: {np.sum(sampled_response_lens)}")
return input_requests

return filtered_dataset

import pickle


def sample_random_requests(
@@ -377,7 +397,13 @@ def sample_random_requests(
tokenizer: PreTrainedTokenizerBase,
dataset_path: str,
) -> List[Tuple[str, int, int]]:

cache_path = f"./input_cache_{num_prompts}"
    # Try to load cached input_requests
if os.path.isfile(cache_path):
with open(cache_path, "rb") as f:
input_requests = pickle.load(f)
print("Loaded input_requests from cache.")
return input_requests
input_lens = np.random.randint(
max(int(input_len * range_ratio), 1),
input_len + 1,
@@ -444,7 +470,9 @@ def sample_random_requests(
]
)
input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

with open(cache_path, "wb") as f:
pickle.dump(input_requests, f)
print(f"Saved input_requests_{num_prompts} to cache.")
print(f"#Input tokens: {np.sum(input_lens)}")
print(f"#Output tokens: {np.sum(output_lens)}")
return input_requests
@@ -483,6 +511,9 @@ def calculate_metrics(
tpots: List[float] = []
ttfts: List[float] = []
e2e_latencies: List[float] = []

input_lens: List[float] = []

for i in range(len(outputs)):
if outputs[i].success:
output_len = outputs[i].output_len
@@ -492,6 +523,9 @@
)
retokenized_output_lens.append(retokenized_output_len)
total_input += input_requests[i][1]

input_lens.append(input_requests[i][1])

if output_len > 1:
tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
itls += outputs[i].itl
@@ -510,6 +544,11 @@
"on the benchmark arguments.",
stacklevel=2,
)

# metric_data = [input_lens, output_lens, ttfts]
# with open(f'metrics_{time.time()}.json', 'w') as f:
# json.dump(metric_data, f)

metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@@ -557,6 +596,15 @@ async def benchmark(

print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len = input_requests[0]

words = test_prompt.split()

    # Shuffle the word list with random.shuffle
random.shuffle(words)

    # Join the shuffled words back into the prompt text
test_prompt = " ".join(words)

test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
@@ -723,6 +771,31 @@ async def benchmark(
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
}

balance_method = os.getenv("LOAD_BALANCE_METHOD")
new_item = {
"method": balance_method,
"mean_ttft": metrics.mean_ttft_ms,
"request_rate": request_rate,
"request_throughput": metrics.request_throughput,
"p99_ttft_ms": metrics.p99_ttft_ms,
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
"time": datetime.now().isoformat(),
}
file_name = f"{balance_method}_result.json"
if not os.path.exists(file_name):
with open(file_name, "w") as f:
json.dump([], f)

with open(file_name, "r") as f:
tmp_data = json.load(f)

tmp_data.append(new_item)

with open(file_name, "w") as f:
json.dump(tmp_data, f, indent=4)

print(f"add new item to {file_name}: {new_item}")
return result


@@ -819,7 +892,7 @@ def run_benchmark(args_: argparse.Namespace):
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
max_seqlen=args.sharegpt_max_seqlen,
)
elif args.dataset_name == "random":
input_requests = sample_random_requests(
@@ -851,6 +924,7 @@
)
)
else:

return asyncio.run(
benchmark(
backend=backend,
@@ -927,6 +1001,12 @@ def set_ulimit(target_soft_limit=65535):
default=1000,
help="Number of prompts to process. Default is 1000.",
)
parser.add_argument(
"--sharegpt-max-seqlen",
type=int,
default=8192,
help="Number of max request len. Default is 8192.",
)
parser.add_argument(
"--sharegpt-output-len",
type=int,
(Diff for the remaining changed files not shown.)

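For reference, a minimal sketch (not part of the commit) of how the reworked ShareGPT sampler above might be driven on its own, mirroring the keyword arguments passed from run_benchmark. The tokenizer name and dataset filename are placeholders, not values taken from the commit; any HuggingFace tokenizer and local ShareGPT JSON dump should work, assuming sglang is installed from this branch.

# Sketch under the assumptions above: exercise sample_sharegpt_requests directly.
import random

from transformers import AutoTokenizer

from sglang.bench_serving import sample_sharegpt_requests

random.seed(0)  # the sampler draws random indices, so seed for repeatability
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model
requests = sample_sharegpt_requests(
    dataset_path="ShareGPT_V3_unfiltered_cleaned_split.json",  # placeholder path
    num_requests=100,
    tokenizer=tokenizer,
    max_seqlen=8192,  # replaces the old fixed_output_len knob in this commit
)
for prompt, prompt_len, output_len in requests[:3]:
    print(prompt_len, output_len, prompt[:60])

Note that with this commit the sampled requests are also pickled to ./input_cache_v2_<num_requests> and reloaded from that cache on subsequent runs, so delete the cache file when changing sampling parameters.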