Flex scheduler #1142

Open · wants to merge 5 commits into base: main
172 changes: 122 additions & 50 deletions python/sglang/bench_serving.py
@@ -311,15 +311,25 @@ def sample_sharegpt_requests(
raise Exception(f"Failed to download dataset: {e}")


import pickle


def sample_sharegpt_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
max_seqlen: int,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

cache_path = f"./input_cache_v2_{num_requests}"
# Try to load cached input_requests
if os.path.isfile(cache_path):
with open(cache_path, "rb") as f:
input_requests = pickle.load(f)
print("Loaded input_requests from cache.")
return input_requests
prompts = []
prompt_lens = []
response_lens = []
# Download sharegpt if necessary
if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
download_sharegpt_dataset(default_sharegpt_path)
@@ -331,42 +341,52 @@ def sample_sharegpt_requests(

# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]

# Shuffle the dataset.
random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break

# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = (
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
datasets = json.load(f)
for data in datasets:
if len(data["conversations"]) >= 2:
prompt = data["conversations"][0]["value"]
res = data["conversations"][1]["value"]
prompt_token_ids = tokenizer(prompt).input_ids
completion_token_ids = tokenizer(res).input_ids

if (
len(prompt_token_ids) + len(completion_token_ids) < max_seqlen
and len(prompt_token_ids) > 0
and len(completion_token_ids) > 0
):
prompts.append(prompt)
prompt_lens.append(len(prompt_token_ids))
response_lens.append(len(completion_token_ids))
if len(prompts) > num_requests:
break

sampled_ids = [random.randint(0, len(prompts) - 1) for _ in range(num_requests)]
sampled_prompts = [prompts[idx] for idx in sampled_ids]
sampled_prompts_lens = [prompt_lens[idx] for idx in sampled_ids]
sampled_response_lens = [response_lens[idx] for idx in sampled_ids]

for i, (prompt_len, gen_len) in enumerate(
zip(sampled_prompts_lens, sampled_response_lens)
):
total = prompt_len + gen_len
if total > max_seqlen:
print(f"truncating long prompt+gen_len {prompt_len=} {gen_len=}")
gen_len = max_seqlen - prompt_len
sampled_response_lens[i] = gen_len
input_requests = list(
zip(sampled_prompts, sampled_prompts_lens, sampled_response_lens)
)


with open(cache_path, "wb") as f:
pickle.dump(input_requests, f)
print(f"Saved input_requests_{num_requests} to cache.")
print(f"#Input tokens: {np.sum(sampled_prompts_lens)}")
print(f"#Output tokens: {np.sum(sampled_response_lens)}")
return input_requests


return filtered_dataset
import pickle


def sample_random_requests(
@@ -377,7 +397,13 @@ def sample_random_requests(
tokenizer: PreTrainedTokenizerBase,
dataset_path: str,
) -> List[Tuple[str, int, int]]:

cache_path = f"./input_cache_{num_prompts}"
# Try to load cached input_requests
if os.path.isfile(cache_path):
with open(cache_path, "rb") as f:
input_requests = pickle.load(f)
print("Loaded input_requests from cache.")
return input_requests
input_lens = np.random.randint(
max(int(input_len * range_ratio), 1),
input_len + 1,
@@ -444,7 +470,9 @@ def sample_random_requests(
]
)
input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

with open(cache_path, "wb") as f:
pickle.dump(input_requests, f)
print(f"Saved input_requests_{num_prompts} to cache.")
print(f"#Input tokens: {np.sum(input_lens)}")
print(f"#Output tokens: {np.sum(output_lens)}")
return input_requests
@@ -483,6 +511,9 @@ def calculate_metrics(
tpots: List[float] = []
ttfts: List[float] = []
e2e_latencies: List[float] = []

input_lens: List[float] = []

for i in range(len(outputs)):
if outputs[i].success:
output_len = outputs[i].output_len
@@ -492,6 +523,9 @@
)
retokenized_output_lens.append(retokenized_output_len)
total_input += input_requests[i][1]

input_lens.append(input_requests[i][1])

if output_len > 1:
tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
itls += outputs[i].itl
@@ -510,6 +544,11 @@
"on the benchmark arguments.",
stacklevel=2,
)

# metric_data = [input_lens, output_lens, ttfts]
# with open(f'metrics_{time.time()}.json', 'w') as f:
# json.dump(metric_data, f)

metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@@ -557,6 +596,15 @@ async def benchmark(

print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len = input_requests[0]

words = test_prompt.split()

# Shuffle the word list with random.shuffle
random.shuffle(words)

# Join the shuffled words back into the prompt text
test_prompt = " ".join(words)

test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
@@ -565,14 +613,6 @@
output_len=test_output_len,
extra_request_body=extra_request_body,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}"
)
else:
print("Initial test run completed. Starting main benchmark run...")

pbar = None if disable_tqdm else tqdm(total=len(input_requests))

@@ -731,6 +771,31 @@ async def benchmark(
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
}

balance_method = os.getenv("LOAD_BALANCE_METHOD")
new_item = {
"method": balance_method,
"mean_ttft": metrics.mean_ttft_ms,
"request_rate": request_rate,
"request_throughput": metrics.request_throughput,
"p99_ttft_ms": metrics.p99_ttft_ms,
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
"time": datetime.now().isoformat(),
}
file_name = f"{balance_method}_result.json"
if not os.path.exists(file_name):
with open(file_name, "w") as f:
json.dump([], f)

with open(file_name, "r") as f:
tmp_data = json.load(f)

tmp_data.append(new_item)

with open(file_name, "w") as f:
json.dump(tmp_data, f, indent=4)

print(f"add new item to {file_name}: {new_item}")
return result
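
For context, the block above appends one record per benchmark run to a JSON file named after the LOAD_BALANCE_METHOD environment variable. A minimal sketch of reading such a file back for comparison follows; the summarize_results helper and the placeholder method name are illustrative assumptions, while the file-name pattern and record keys are taken from the diff above.

import json
import os


def summarize_results(method: str) -> None:
    # Read the per-method result file written by the benchmark run above.
    file_name = f"{method}_result.json"
    if not os.path.exists(file_name):
        print(f"No results recorded yet for {method}")
        return
    with open(file_name) as f:
        records = json.load(f)
    # Print one line per recorded run.
    for r in records:
        print(
            f"{r['time']}  rate={r['request_rate']}  "
            f"mean_ttft={r['mean_ttft']:.1f} ms  "
            f"p99_ttft={r['p99_ttft_ms']:.1f} ms  "
            f"throughput={r['request_throughput']:.2f} req/s"
        )


summarize_results("round_robin")  # placeholder method name, not from this PR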


@@ -827,7 +892,7 @@ def run_benchmark(args_: argparse.Namespace):
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
max_seqlen=args.sharegpt_max_seqlen,
)
elif args.dataset_name == "random":
input_requests = sample_random_requests(
@@ -859,6 +924,7 @@
)
)
else:

return asyncio.run(
benchmark(
backend=backend,
@@ -935,6 +1001,12 @@ def set_ulimit(target_soft_limit=65535):
default=1000,
help="Number of prompts to process. Default is 1000.",
)
parser.add_argument(
"--sharegpt-max-seqlen",
type=int,
default=8192,
help="Number of max request len. Default is 8192.",
)
parser.add_argument(
"--sharegpt-output-len",
type=int,
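
Note on the new option: --sharegpt-max-seqlen is forwarded to sample_sharegpt_requests as max_seqlen (see the run_benchmark hunk above), bounding prompt plus output tokens per sampled request. A minimal sketch of calling the sampler directly with that bound, assuming the module is importable as sglang.bench_serving; the tokenizer name and dataset path are placeholders rather than anything defined in this PR.

from transformers import AutoTokenizer

from sglang.bench_serving import sample_sharegpt_requests

# Placeholder tokenizer and dataset path, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
requests = sample_sharegpt_requests(
    dataset_path="ShareGPT_V3_unfiltered_cleaned_split.json",
    num_requests=1000,
    tokenizer=tokenizer,
    max_seqlen=8192,  # matches the new CLI default
)
# Each entry is a (prompt, prompt_len, output_len) tuple. Note that results
# are cached to ./input_cache_v2_<num_requests>, as in the diff above.
print(f"Sampled {len(requests)} requests")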