Skip to content

Commit

Permalink
Add LLM Engine
Browse files Browse the repository at this point in the history
  • Loading branch information
JianyuZhan committed Sep 3, 2024
1 parent 1c136b2 commit a00e992
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 14 deletions.
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"test_triton_attn_backend.py",
"test_update_weights.py",
"test_vision_openai_server.py",
"test_llm_engine.py",
],
"sampling/penaltylib": glob.glob(
"sampling/penaltylib/**/test_*.py", recursive=True
Expand Down
58 changes: 58 additions & 0 deletions test/srt/test_llm_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import unittest

from sglang import LLM, SamplingParams
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST


class TestLLMGeneration(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model_name = DEFAULT_MODEL_NAME_FOR_TEST
cls.prompts_list = [
"Hello, my name is",
"The capital of China is",
"What is the meaning of life?",
"The future of AI is",
]
cls.single_prompt = "What is the meaning of life?"
# Turn off tokernizers parallelism to enable running multiple tests
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def test_generate_with_sampling_params(self):
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model=self.model_name)
outputs = llm.generate(self.prompts_list, sampling_params)

self.assertEqual(len(outputs), len(self.prompts_list))
for output in outputs:
self.assertIn(output["index"], range(len(self.prompts_list)))
self.assertTrue(output["text"].strip())

def test_generate_without_sampling_params(self):
llm = LLM(model=self.model_name)
outputs = llm.generate(self.prompts_list)

self.assertEqual(len(outputs), len(self.prompts_list))
for output in outputs:
self.assertIn(output["index"], range(len(self.prompts_list)))
self.assertTrue(output["text"].strip())

def test_generate_with_single_prompt_and_sampling_params(self):
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model=self.model_name)
outputs = llm.generate(self.single_prompt, sampling_params)

self.assertEqual(len(outputs), 1)
self.assertTrue(outputs[0]["text"].strip())

def test_generate_with_single_prompt_without_sampling_params(self):
llm = LLM(model=self.model_name)
outputs = llm.generate(self.single_prompt)

self.assertEqual(len(outputs), 1)
self.assertTrue(outputs[0]["text"].strip())


if __name__ == "__main__":
unittest.main()
12 changes: 6 additions & 6 deletions test/srt/test_moe_serving_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from types import SimpleNamespace

from sglang.bench_serving import run_benchmark
from sglang.srt.server_args import ServerArgs
from sglang.srt.serving.engine_args import EngineArgs
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
Expand Down Expand Up @@ -69,9 +69,9 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size

def test_default(self):
res = self.run_test(
disable_radix_cache=ServerArgs.disable_radix_cache,
disable_flashinfer=ServerArgs.disable_flashinfer,
chunked_prefill_size=ServerArgs.chunked_prefill_size,
disable_radix_cache=EngineArgs.disable_radix_cache,
disable_flashinfer=EngineArgs.disable_flashinfer,
chunked_prefill_size=EngineArgs.chunked_prefill_size,
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
Expand All @@ -80,8 +80,8 @@ def test_default(self):
def test_default_without_radix_cache(self):
res = self.run_test(
disable_radix_cache=True,
disable_flashinfer=ServerArgs.disable_flashinfer,
chunked_prefill_size=ServerArgs.chunked_prefill_size,
disable_flashinfer=EngineArgs.disable_flashinfer,
chunked_prefill_size=EngineArgs.chunked_prefill_size,
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
Expand Down
16 changes: 8 additions & 8 deletions test/srt/test_serving_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from types import SimpleNamespace

from sglang.bench_serving import run_benchmark
from sglang.srt.server_args import ServerArgs
from sglang.srt.serving.engine_args import EngineArgs
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
Expand Down Expand Up @@ -68,9 +68,9 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size

def test_default(self):
res = self.run_test(
disable_radix_cache=ServerArgs.disable_radix_cache,
disable_flashinfer=ServerArgs.disable_flashinfer,
chunked_prefill_size=ServerArgs.chunked_prefill_size,
disable_radix_cache=EngineArgs.disable_radix_cache,
disable_flashinfer=EngineArgs.disable_flashinfer,
chunked_prefill_size=EngineArgs.chunked_prefill_size,
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
Expand All @@ -79,17 +79,17 @@ def test_default(self):
def test_default_without_radix_cache(self):
res = self.run_test(
disable_radix_cache=True,
disable_flashinfer=ServerArgs.disable_flashinfer,
chunked_prefill_size=ServerArgs.chunked_prefill_size,
disable_flashinfer=EngineArgs.disable_flashinfer,
chunked_prefill_size=EngineArgs.chunked_prefill_size,
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
assert res["output_throughput"] > 2800

def test_default_without_chunked_prefill(self):
res = self.run_test(
disable_radix_cache=ServerArgs.disable_radix_cache,
disable_flashinfer=ServerArgs.disable_flashinfer,
disable_radix_cache=EngineArgs.disable_radix_cache,
disable_flashinfer=EngineArgs.disable_flashinfer,
chunked_prefill_size=-1,
)

Expand Down

0 comments on commit a00e992

Please sign in to comment.