Commit fb29bef

done
sunggg committed Jan 6, 2024
1 parent b0e9851 commit fb29bef
Showing 1 changed file with 77 additions and 36 deletions.
serve/benchmarks/benchmark_throughput.py (113 changes: 77 additions & 36 deletions)

@@ -64,47 +64,47 @@ def sample_requests(
     return sampled_requests


-def run_mii(requests: List[Tuple[str, int, int]], model, num_shards) -> float:
+def run_mii(requests: List[Tuple[str, int, int]], args) -> float:
     from mii import pipeline

-    engine = pipeline(model, tensor_parallel=num_shards)
+    engine = pipeline(args.model, tensor_parallel=args.num_shards)
     prompts = [prompt for prompt, _, _ in requests]
-    # FIXME: hardcoded
-    output_len = 128
     start = time.perf_counter()
-    engine(prompts, max_new_tokens=output_len)
+    engine(
+        prompts,
+        max_new_tokens=args.max_output_tokens,
+        ignore_eos=SAMPLER_SETTING["ignore_eos"],
+        temperature=SAMPLER_SETTING["temperature"],
+    )
     end = time.perf_counter()
     return end - start


-def run_vllm(requests: List[Tuple[str, int, int]], model, num_shards) -> float:
+def run_vllm(requests: List[Tuple[str, int, int]], args) -> float:
     from vllm import LLM, SamplingParams

-    # Fixme
     llm = LLM(
-        model=model,
-        tokenizer=model,
-        quantization=None,
-        tensor_parallel_size=num_shards,
-        seed=0,
+        model=args.model,
+        tokenizer=args.model,
+        dtype=args.dtype,
+        quantization=args.quantization,
+        tensor_parallel_size=args.num_shards,
         trust_remote_code=True,
-        dtype="auto",
-        max_model_len=2000,
+        max_model_len=None,  # derive from the model
     )

     # Add the requests to the engine.
-    for prompt, _, output_len in requests:
-        sampling_params = SamplingParams(
-            n=1,
-            use_beam_search=False,
-            temperature=SAMPLER_SETTING["temperature"],
-            ignore_eos=SAMPLER_SETTING["ignore_eos"],
-            max_tokens=output_len,
-        )
+    for prompt, _, _ in requests:
         llm._add_request(
             prompt=prompt,
             prompt_token_ids=None,
-            sampling_params=sampling_params,
+            sampling_params=SamplingParams(
+                n=args.num_sequences_to_sample,
+                use_beam_search=False,
+                temperature=SAMPLER_SETTING["temperature"],
+                ignore_eos=SAMPLER_SETTING["ignore_eos"],
+                max_tokens=args.max_output_tokens,
+            ),
         )

     start = time.perf_counter()
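
Note: all three backends index into SAMPLER_SETTING, a constant defined earlier in this file and outside the visible hunks. A minimal sketch of a definition consistent with that usage; only the two keys are confirmed by the code above, and the values here are assumptions, not taken from the commit:

    # Hypothetical values; the real definition lives elsewhere in benchmark_throughput.py.
    SAMPLER_SETTING = {
        "ignore_eos": True,  # assumed: force every sequence to generate max_tokens
        "temperature": 1.0,  # assumed sampling temperature
    }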
@@ -113,8 +113,8 @@ def run_vllm(requests: List[Tuple[str, int, int]], model, num_shards) -> float:
     return end - start


-def run_mlc(engine, requests, num_sequences_to_sample) -> float:
-    for i, (prompt, _, output_len) in enumerate(requests):
+def run_mlc(engine, requests, args) -> float:
+    for i, (prompt, _, _) in enumerate(requests):
         engine.add(
             [
                 Request(
@@ -124,9 +124,9 @@ def run_mlc(engine, requests, num_sequences_to_sample) -> float:
                         temperature=SAMPLER_SETTING["temperature"]
                     ),
                     stopping_criteria=StoppingCriteria(
-                        max_tokens=output_len, stop_sequences=None
+                        max_tokens=args.max_output_tokens, stop_sequences=None
                     ),
-                    num_sequences=num_sequences_to_sample,
+                    num_sequences=args.num_sequences_to_sample,
                     debug_options=DebugOptions(
                         ignore_eos=SAMPLER_SETTING["ignore_eos"], prompt=prompt
                     ),
@@ -190,23 +190,27 @@ def main(args: argparse.Namespace):
         requests = sample_requests(
             args.dataset, args.num_prompts, engine.tokenizer._tokenizer
         )
-        elapsed_time = run_mlc(engine, requests, args.num_sequences_to_sample)
+        elapsed_time = run_mlc(engine, requests, args)
     else:
         from transformers import AutoTokenizer

-        model = "/opt/models/mistral/Mixtral-8x7B-Instruct-v0.1"
-        tokenizer = AutoTokenizer.from_pretrained(model)
+        assert (
+            args.model is not None
+        ), "Please provide model path for vllm and deepspeed mii."
+        assert (
+            args.num_shards is not None
+        ), "Please provide number of gpus for vllm and deepspeed mii."
+
+        tokenizer = AutoTokenizer.from_pretrained(args.model)
         requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
         if args.backend == "mii":
-            num_shards = 1
-            elapsed_time = run_mii(requests, model, num_shards)
+            elapsed_time = run_mii(requests, args)
         elif args.backend == "vllm":
-            num_shards = 2
-            elapsed_time = run_vllm(requests, model, num_shards)
+            elapsed_time = run_vllm(requests, args)

     total_num_tokens = sum(
-        prompt_len + output_len * args.num_sequences_to_sample
-        for _, prompt_len, output_len in requests
+        prompt_len + args.max_output_tokens * args.num_sequences_to_sample
+        for _, prompt_len, _ in requests
     )
     req_per_sec = len(requests) / elapsed_time
     tok_per_sec = total_num_tokens / elapsed_time
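
To make the new token accounting concrete, a small worked example with made-up numbers (not from the commit). Prompt tokens are counted once per request, while output tokens are counted once per sampled sequence; since the samplers set ignore_eos, every sequence should generate exactly max_output_tokens tokens, so the estimate is exact:

    # Hypothetical: two requests with prompt lengths 100 and 200 tokens,
    # run with --max-output-tokens 128 and --num-sequences-to-sample 2.
    requests = [("prompt a", 100, 0), ("prompt b", 200, 0)]  # (prompt, prompt_len, unused)
    max_output_tokens, num_sequences_to_sample = 128, 2
    total_num_tokens = sum(
        prompt_len + max_output_tokens * num_sequences_to_sample
        for _, prompt_len, _ in requests
    )
    assert total_num_tokens == (100 + 256) + (200 + 256) == 812
    # If elapsed_time were 4.0 s, this would report 0.5 req/s and 203 tok/s.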
@@ -238,12 +242,49 @@ def main(args: argparse.Namespace):
     parser.add_argument(
         "--num-prompts", type=int, default=1000, help="Number of prompts to process."
     )
+    parser.add_argument(
+        "--max-output-tokens",
+        type=int,
+        default=128,
+        help="Maximum number of generation tokens.",
+    )
     parser.add_argument(
         "--report-path",
         type=str,
         default=None,
         help="Append the current result to the given path if provided.",
     )
+    # flags for vllm and deepspeed mii
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=None,
+        help="Model path. This is for vLLM and Deepspeed MII.",
+    )
+    parser.add_argument(
+        "--num-shards",
+        type=int,
+        default=None,
+        help="Number of GPUs. This is for vLLM and Deepspeed MII.",
+    )
+    # flags for vllm
+    parser.add_argument(
+        "--quantization",
+        "-q",
+        choices=["awq", "gptq", "squeezellm", None],
+        default=None,
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="auto",
+        choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
+        help="data type for model weights and activations. "
+        'The "auto" option will use FP16 precision '
+        "for FP32 and FP16 models, and BF16 precision "
+        "for BF16 models.",
+    )

     args = parser.parse_args()
     args = postproc_mlc_serve_args(args)
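
With these flags in place, the vLLM and Deepspeed MII backends no longer rely on hard-coded model paths and shard counts. Hypothetical invocations of the benchmark; the --backend and --dataset spellings are inferred from args.backend and args.dataset in the diff above, and the paths are placeholders, not from the commit:

    # vLLM backend on 2 GPUs
    python serve/benchmarks/benchmark_throughput.py --backend vllm \
        --model /path/to/model --num-shards 2 --dataset /path/to/dataset.json \
        --num-prompts 1000 --max-output-tokens 128 --dtype auto

    # Deepspeed MII backend on 1 GPU
    python serve/benchmarks/benchmark_throughput.py --backend mii \
        --model /path/to/model --num-shards 1 --dataset /path/to/dataset.json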
