From 688cb2c936a61251aff582281ed5d08b1d836f3f Mon Sep 17 00:00:00 2001 From: Jianyu Zhan Date: Mon, 26 Aug 2024 15:42:00 +0000 Subject: [PATCH] Add LLM Engine --- .../sglang/srt/managers/tokenizer_manager.py | 9 +-- test/srt/run_suite.py | 1 + test/srt/test_llm_engine.py | 58 +++++++++++++++++++ test/srt/test_moe_serving_throughput.py | 12 ++-- test/srt/test_serving_throughput.py | 16 ++--- 5 files changed, 76 insertions(+), 20 deletions(-) create mode 100644 test/srt/test_llm_engine.py diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 07ae04bc2a..8dee7061fb 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -84,13 +84,10 @@ def __init__( self.recv_from_detokenizer = context.socket(zmq.PULL) self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{engine_args.tokenizer_port}") -<<<<<<< HEAD self.send_to_controller = context.socket(zmq.PUSH) - self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}") -======= - self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"tcp://127.0.0.1:{engine_args.controller_port}") ->>>>>>> 7b5d19c (Add LLM Engine) + self.send_to_controller.connect( + f"tcp://127.0.0.1:{engine_args.controller_port}" + ) # Read model args self.model_path = engine_args.model_path diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index cafcf3f2d5..baab1261d8 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -19,6 +19,7 @@ "test_triton_attn_backend.py", "test_update_weights.py", "test_vision_openai_server.py", + "test_llm_engine.py", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True diff --git a/test/srt/test_llm_engine.py b/test/srt/test_llm_engine.py new file mode 100644 index 0000000000..fdac42a977 --- /dev/null +++ b/test/srt/test_llm_engine.py @@ -0,0 +1,58 @@ +import os +import unittest + +from sglang import LLM, 
SamplingParams
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
+
+
+class TestLLMGeneration(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model_name = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.prompts_list = [
+            "Hello, my name is",
+            "The capital of China is",
+            "What is the meaning of life?",
+            "The future of AI is",
+        ]
+        cls.single_prompt = "What is the meaning of life?"
+        # Turn off tokenizers parallelism to enable running multiple tests
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    def test_generate_with_sampling_params(self):
+        sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+        llm = LLM(model=self.model_name)
+        outputs = llm.generate(self.prompts_list, sampling_params)
+
+        self.assertEqual(len(outputs), len(self.prompts_list))
+        for output in outputs:
+            self.assertIn(output["index"], range(len(self.prompts_list)))
+            self.assertTrue(output["text"].strip())
+
+    def test_generate_without_sampling_params(self):
+        llm = LLM(model=self.model_name)
+        outputs = llm.generate(self.prompts_list)
+
+        self.assertEqual(len(outputs), len(self.prompts_list))
+        for output in outputs:
+            self.assertIn(output["index"], range(len(self.prompts_list)))
+            self.assertTrue(output["text"].strip())
+
+    def test_generate_with_single_prompt_and_sampling_params(self):
+        sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+        llm = LLM(model=self.model_name)
+        outputs = llm.generate(self.single_prompt, sampling_params)
+
+        self.assertEqual(len(outputs), 1)
+        self.assertTrue(outputs[0]["text"].strip())
+
+    def test_generate_with_single_prompt_without_sampling_params(self):
+        llm = LLM(model=self.model_name)
+        outputs = llm.generate(self.single_prompt)
+
+        self.assertEqual(len(outputs), 1)
+        self.assertTrue(outputs[0]["text"].strip())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py
index 2acf626c1c..b0ebb36a62 100644
--- 
a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from sglang.bench_serving import run_benchmark -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MOE_MODEL_NAME_FOR_TEST, @@ -69,9 +69,9 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size def test_default(self): res = self.run_test( - disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_radix_cache=EngineArgs.disable_radix_cache, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": @@ -80,8 +80,8 @@ def test_default(self): def test_default_without_radix_cache(self): res = self.run_test( disable_radix_cache=True, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": diff --git a/test/srt/test_serving_throughput.py b/test/srt/test_serving_throughput.py index d4ed12612a..52ad476ae3 100644 --- a/test/srt/test_serving_throughput.py +++ b/test/srt/test_serving_throughput.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from sglang.bench_serving import run_benchmark -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -68,9 +68,9 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size def test_default(self): res = self.run_test( - 
disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_radix_cache=EngineArgs.disable_radix_cache, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": @@ -79,8 +79,8 @@ def test_default(self): def test_default_without_radix_cache(self): res = self.run_test( disable_radix_cache=True, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": @@ -88,8 +88,8 @@ def test_default_without_radix_cache(self): def test_default_without_chunked_prefill(self): res = self.run_test( - disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, + disable_radix_cache=EngineArgs.disable_radix_cache, + disable_flashinfer=EngineArgs.disable_flashinfer, chunked_prefill_size=-1, )