From 60efeb70044bfc0a1bc3513d4cf30d97ee90b855 Mon Sep 17 00:00:00 2001 From: Jianyu Zhan Date: Mon, 26 Aug 2024 15:42:00 +0000 Subject: [PATCH] Add LLM Engine --- .github/workflows/e2e-test.yml | 8 + .gitignore | 1 + examples/usage/llm_engine.py | 25 ++ python/sglang/__init__.py | 4 + python/sglang/api.py | 4 +- python/sglang/bench_latency.py | 2 +- python/sglang/launch_server.py | 4 +- .../sglang/srt/managers/controller_multi.py | 29 +-- .../sglang/srt/managers/controller_single.py | 38 ++- .../srt/managers/detokenizer_manager.py | 27 +- .../sglang/srt/managers/tokenizer_manager.py | 210 ++++++++++----- python/sglang/srt/managers/tp_worker.py | 100 ++++--- .../srt/model_executor/forward_batch_info.py | 6 +- .../sglang/srt/model_executor/model_runner.py | 76 +++--- python/sglang/srt/serving/__init__.py | 0 python/sglang/srt/serving/engine.py | 243 ++++++++++++++++++ python/sglang/srt/serving/engine_args.py | 216 ++++++++++++++++ python/sglang/srt/{ => serving}/server.py | 215 ++++------------ .../sglang/srt/{ => serving}/server_args.py | 235 +++++++---------- python/sglang/srt/utils.py | 4 +- python/sglang/test/runners.py | 2 +- test/srt/run_suite.py | 1 + test/srt/test_chunked_prefill.py | 1 + test/srt/test_llm_engine.py | 58 +++++ test/srt/test_moe_serving_throughput.py | 12 +- test/srt/test_serving_throughput.py | 16 +- 26 files changed, 989 insertions(+), 548 deletions(-) create mode 100644 examples/usage/llm_engine.py create mode 100644 python/sglang/srt/serving/__init__.py create mode 100644 python/sglang/srt/serving/engine.py create mode 100644 python/sglang/srt/serving/engine_args.py rename python/sglang/srt/{ => serving}/server.py (66%) rename python/sglang/srt/{ => serving}/server_args.py (71%) create mode 100644 test/srt/test_llm_engine.py diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index 11c94775c1..76d442a8e2 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -32,6 +32,14 @@ jobs: pip install -e "python[all]" pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + - name: Set PYTHONPATH + run: | + echo "PYTHONPATH=$PYTHONPATH:$(pwd)/python" >> $GITHUB_ENV + + - name: Verify import + run: | + python3 -c "import sglang.srt.serving" + - name: Benchmark Serving Throughput timeout-minutes: 10 run: | diff --git a/.gitignore b/.gitignore index ca43e1ccba..15e29a02f7 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +human-eval/ # Translations *.mo diff --git a/examples/usage/llm_engine.py b/examples/usage/llm_engine.py new file mode 100644 index 0000000000..d0d73fc5eb --- /dev/null +++ b/examples/usage/llm_engine.py @@ -0,0 +1,25 @@ +from sglang import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The capital of China is", + "What is the meaning of life?", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="deepseek-ai/deepseek-llm-7b-chat", tensor_parallel_size=1) + +outputs = llm.generate(prompts, sampling_params) + +# Print the outputs. +for output in outputs: + index = output["index"] + prompt = prompts[index] + answer = output["text"] + print("===============================") + print(f"Prompt: {prompt}") + print(f"Generated text: {output['text']}") diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index 71d7bfeccf..93b408a324 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -1,7 +1,9 @@ # SGL API Components from sglang.api import ( + LLM, Runtime, + SamplingParams, assistant, assistant_begin, assistant_end, @@ -30,6 +32,8 @@ # SGLang DSL APIs __all__ = [ + "LLM", + "SamplingParams", "Runtime", "assistant", "assistant_begin", diff --git a/python/sglang/api.py b/python/sglang/api.py index 9405606b71..41dedbbc44 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -18,6 +18,8 @@ SglSelect, SglVideo, ) +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.serving.engine import LLM def function( @@ -35,7 +37,7 @@ def decorator(func): def Runtime(*args, **kwargs): # Avoid importing unnecessary dependency os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - from sglang.srt.server import Runtime + from sglang.srt.serving.server import Runtime return Runtime(*args, **kwargs) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 3a48740857..5875e66d65 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -55,7 +55,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.server_args import ServerArgs from sglang.srt.utils import suppress_other_loggers diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 1df64e848c..9372fe70c5 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -3,8 +3,8 @@ import argparse import os -from sglang.srt.server import launch_server -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.server import launch_server +from sglang.srt.serving.server_args import ServerArgs from sglang.srt.utils import kill_child_process if __name__ == "__main__": diff --git a/python/sglang/srt/managers/controller_multi.py b/python/sglang/srt/managers/controller_multi.py index d2b10e7fa2..1b7bcc76ed 100644 --- a/python/sglang/srt/managers/controller_multi.py +++ b/python/sglang/srt/managers/controller_multi.py @@ -34,7 +34,7 @@ FlushCacheReq, TokenizedGenerateReqInput, ) -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import configure_logger, kill_parent_process from sglang.utils import get_exception_traceback @@ -69,22 +69,19 @@ class ControllerMulti: def __init__( self, - server_args: ServerArgs, - port_args: PortArgs, - model_overide_args, + engine_args: EngineArgs, ): # Parse args - self.server_args = server_args - self.port_args = port_args - self.model_overide_args = model_overide_args + self.engine_args = engine_args + self.model_overide_args = engine_args.model_override_args self.load_balance_method = LoadBalanceMethod.from_str( - server_args.load_balance_method + engine_args.load_balance_method ) # Init communication context = zmq.Context() self.recv_from_tokenizer = context.socket(zmq.PULL) - self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}") + self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{engine_args.controller_port}") # Dispatch method self.round_robin_counter = 0 @@ -96,11 +93,11 @@ def __init__( # Start data parallel workers self.workers = [] - for i in range(server_args.dp_size): + for i in range(engine_args.dp_size): self.start_dp_worker(i) def start_dp_worker(self, dp_worker_id: int): - tp_size = self.server_args.tp_size + tp_size = self.engine_args.tp_size pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe( duplex=False @@ -111,7 +108,7 @@ def start_dp_worker(self, dp_worker_id: int): proc = multiprocessing.Process( target=start_controller_process_single, args=( - self.server_args, + self.engine_args, self.port_args, pipe_controller_writer, self.model_overide_args, @@ -186,17 +183,15 @@ def recv_requests(self): def start_controller_process( - server_args: ServerArgs, - port_args: PortArgs, + engine_args: EngineArgs, pipe_writer, - model_overide_args: dict, ): """Start a controller process.""" - configure_logger(server_args) + configure_logger(engine_args.log_level) try: - controller = ControllerMulti(server_args, port_args, model_overide_args) + controller = ControllerMulti(engine_args) except Exception: pipe_writer.send(get_exception_traceback()) raise diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py index 4a16a6f6e4..55e38f9175 100644 --- a/python/sglang/srt/managers/controller_single.py +++ b/python/sglang/srt/managers/controller_single.py @@ -26,7 +26,7 @@ broadcast_recv_input, launch_tp_servers, ) -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import configure_logger, kill_parent_process from sglang.utils import get_exception_traceback @@ -38,16 +38,14 @@ class ControllerSingle: def __init__( self, - server_args: ServerArgs, - port_args: PortArgs, - model_overide_args: dict, + engine_args: EngineArgs, gpu_ids: List[int], is_data_parallel_worker: bool, dp_worker_id: int, mp_queue: multiprocessing.Queue, ): # Parse args - self.tp_size = server_args.tp_size + self.tp_size = engine_args.tp_size self.is_dp_worker = is_data_parallel_worker self.dp_worker_id = dp_worker_id self.mp_queue = mp_queue @@ -58,34 +56,32 @@ def __init__( if not self.is_dp_worker: self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer.bind( - f"tcp://127.0.0.1:{port_args.controller_port}" + f"tcp://127.0.0.1:{engine_args.controller_port}" ) self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer.connect( - f"tcp://127.0.0.1:{port_args.detokenizer_port}" + f"tcp://127.0.0.1:{engine_args.detokenizer_port}" ) # Launch other tp ranks - tp_size_local = server_args.tp_size // server_args.nnodes + tp_size_local = engine_args.tp_size // engine_args.nnodes self.tp_procs = [] if tp_size_local > 1: tp_rank_range = range(1, tp_size_local) self.tp_procs = launch_tp_servers( gpu_ids, tp_rank_range, - server_args, - port_args.nccl_ports[dp_worker_id], - model_overide_args, + engine_args.nccl_ports[dp_worker_id], + engine_args, ) # Launch tp rank 0 self.tp_server = ModelTpServer( gpu_ids[0], 0, - server_args, - port_args.nccl_ports[dp_worker_id], - model_overide_args, + engine_args.nccl_ports[dp_worker_id], + engine_args, ) self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group @@ -123,10 +119,8 @@ def recv_requests_from_mp_queue(self): def start_controller_process( - server_args: ServerArgs, - port_args: PortArgs, + engine_args: EngineArgs, pipe_writer: multiprocessing.connection.Connection, - model_overide_args: dict, is_data_parallel_worker: bool = False, gpu_ids: List[int] = None, dp_worker_id: int = None, @@ -137,19 +131,17 @@ def start_controller_process( logger_prefix = f" DP{dp_worker_id} TP0" else: logger_prefix = " TP0" - configure_logger(server_args, prefix=logger_prefix) + configure_logger(engine_args.log_level, prefix=logger_prefix) if not is_data_parallel_worker: - tp_size_local = server_args.tp_size // server_args.nnodes - gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)] + tp_size_local = engine_args.tp_size // engine_args.nnodes + gpu_ids = [i for _ in range(engine_args.nnodes) for i in range(tp_size_local)] dp_worker_id = 0 queue = None try: controller = ControllerSingle( - server_args, - port_args, - model_overide_args, + engine_args, gpu_ids, is_data_parallel_worker, dp_worker_id, diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index cd5f63125c..84d277d51f 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -31,7 +31,7 @@ UpdateWeightReqOutput, ) from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.utils import find_printable_text, get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -53,24 +53,23 @@ class DetokenizerManager: def __init__( self, - server_args: ServerArgs, - port_args: PortArgs, + engine_args: EngineArgs, ): # Init inter-process communication context = zmq.asyncio.Context(2) self.recv_from_router = context.socket(zmq.PULL) - self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}") + self.recv_from_router.bind(f"tcp://127.0.0.1:{engine_args.detokenizer_port}") self.send_to_tokenizer = context.socket(zmq.PUSH) - self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}") + self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{engine_args.tokenizer_port}") - if server_args.skip_tokenizer_init: + if engine_args.skip_tokenizer_init: self.tokenizer = None else: self.tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + engine_args.tokenizer_path, + tokenizer_mode=engine_args.tokenizer_mode, + trust_remote_code=engine_args.trust_remote_code, ) self.decode_status = {} @@ -171,15 +170,17 @@ async def handle_loop(self): def start_detokenizer_process( - server_args: ServerArgs, - port_args: PortArgs, + engine_args: EngineArgs, pipe_writer, ): try: - manager = DetokenizerManager(server_args, port_args) + manager = DetokenizerManager(engine_args) except Exception: pipe_writer.send(get_exception_traceback()) raise pipe_writer.send("init ok") - loop = asyncio.get_event_loop() + # Create a new event loop for this process because asyncio.get_event_loop() + # does not return a loop in a new thread or process in Python 3.10+. + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) loop.run_until_complete(manager.handle_loop()) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 5ad4152ea9..abcd06c721 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -52,7 +52,7 @@ ) from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image from sglang.utils import get_exception_traceback @@ -75,44 +75,42 @@ class TokenizerManager: def __init__( self, - server_args: ServerArgs, - port_args: PortArgs, - model_overide_args: dict = None, + engine_args: EngineArgs, ): - self.server_args = server_args + self.engine_args = engine_args # Init inter-process communication context = zmq.asyncio.Context(2) self.recv_from_detokenizer = context.socket(zmq.PULL) - self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}") + self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{engine_args.tokenizer_port}") self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}") + self.send_to_router.connect(f"tcp://127.0.0.1:{engine_args.controller_port}") # Read model args - self.model_path = server_args.model_path - self.served_model_name = server_args.served_model_name + self.model_path = engine_args.model_path + self.served_model_name = engine_args.served_model_name self.hf_config = get_config( self.model_path, - trust_remote_code=server_args.trust_remote_code, - model_overide_args=model_overide_args, + trust_remote_code=engine_args.trust_remote_code, + model_overide_args=engine_args.model_override_args, ) self.is_generation = is_generation_model( - self.hf_config.architectures, self.server_args.is_embedding + self.hf_config.architectures, self.engine_args.is_embedding ) - self.context_len = server_args.context_length or get_context_length( + self.context_len = engine_args.context_length or get_context_length( self.hf_config ) # Create tokenizer - if server_args.skip_tokenizer_init: + if engine_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: if is_multimodal_model(self.hf_config.architectures): self.processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + engine_args.tokenizer_path, + tokenizer_mode=engine_args.tokenizer_mode, + trust_remote_code=engine_args.trust_remote_code, ) self.tokenizer = self.processor.tokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -122,17 +120,19 @@ def __init__( self.executor = concurrent.futures.ProcessPoolExecutor( initializer=init_global_processor, mp_context=mp.get_context("fork"), - initargs=(server_args,), + initargs=(engine_args,), ) else: self.tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + engine_args.tokenizer_path, + tokenizer_mode=engine_args.tokenizer_mode, + trust_remote_code=engine_args.trust_remote_code, ) # Store states self.to_create_loop = True + self.handle_loop_task = None + self.should_stop_loop = False self.rid_to_state: Dict[str, ReqState] = {} # For update model weights @@ -465,7 +465,7 @@ async def _wait_for_response( out["index"] = response_index # Log requests - if self.server_args.log_requests and state.finished: + if self.engine_args.log_requests and state.finished: logger.info(f"in={obj}, out={out}") state.out_list = [] @@ -515,9 +515,9 @@ async def update_weights( if self.to_create_loop: self.create_handle_loop() - # default the load format to the server_args + # default the load format to the engine_args if obj.load_format is None: - obj.load_format = self.server_args.load_format + obj.load_format = self.engine_args.load_format if not self.model_update_lock.locked(): async with self.model_update_lock: @@ -528,8 +528,8 @@ async def update_weights( self.model_update_result = asyncio.Future() result = await self.model_update_result if result.success: - self.server_args.model_path = obj.model_path - self.server_args.load_format = obj.load_format + self.engine_args.model_path = obj.model_path + self.engine_args.load_format = obj.load_format self.model_path = obj.model_path return result.success, result.message else: @@ -555,53 +555,74 @@ def create_handle_loop(self): self.to_create_loop = False loop = asyncio.get_event_loop() - loop.create_task(self.handle_loop()) + self.handle_loop_task = loop.create_task(self.handle_loop()) async def handle_loop(self): """The event loop that handles requests""" - - while True: - recv_obj: Union[ - BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput - ] = await self.recv_from_detokenizer.recv_pyobj() - - if isinstance(recv_obj, UpdateWeightReqOutput): - self.model_update_result.set_result(recv_obj) + poller = zmq.asyncio.Poller() + poller.register(self.recv_from_detokenizer, zmq.POLLIN) + + try: + while True: + if self.should_stop_loop and not poller.poll(timeout=0): + logger.info( + "No more messages and shutdown requested, exiting loop." + ) + break + + # any new events? + events = dict(await poller.poll(timeout=100)) # 100ms + if self.recv_from_detokenizer in events: + # yes and process it + recv_obj: Union[ + BatchStrOut, + BatchEmbeddingOut, + BatchTokenIDOut, + UpdateWeightReqOutput, + ] = await self.recv_from_detokenizer.recv_pyobj() + + if isinstance(recv_obj, UpdateWeightReqOutput): + self.model_update_result.set_result(recv_obj) + continue + + await self._process_recv_obj(recv_obj) + except Exception as e: + logger.error(f"Exception in handle_loop: {e}") + + async def _process_recv_obj(self, recv_obj): + assert isinstance( + recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) + ), f"Unexpected obj received: {type(recv_obj)}" + + for i, rid in enumerate(recv_obj.rids): + state = self.rid_to_state.get(rid, None) + if state is None: continue - assert isinstance( - recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) - ), f"Unexpected obj received: {type(recv_obj)}" + recv_obj.meta_info[i]["id"] = rid + if isinstance(recv_obj, BatchStrOut): + out_dict = { + "text": recv_obj.output_strs[i], + "meta_info": recv_obj.meta_info[i], + } + elif isinstance(recv_obj, BatchTokenIDOut): + read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1] + out_dict = { + "token_ids": recv_obj.decode_ids[ + read_start : recv_obj.read_offsets[i] + ], + "meta_info": recv_obj.meta_info[i], + } - for i, rid in enumerate(recv_obj.rids): - state = self.rid_to_state.get(rid, None) - if state is None: - continue - - recv_obj.meta_info[i]["id"] = rid - if isinstance(recv_obj, BatchStrOut): - out_dict = { - "text": recv_obj.output_strs[i], - "meta_info": recv_obj.meta_info[i], - } - elif isinstance(recv_obj, BatchTokenIDOut): - read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1] - out_dict = { - "token_ids": recv_obj.decode_ids[ - read_start : recv_obj.read_offsets[i] - ], - "meta_info": recv_obj.meta_info[i], - } - - else: - assert isinstance(recv_obj, BatchEmbeddingOut) - out_dict = { - "embedding": recv_obj.embeddings[i], - "meta_info": recv_obj.meta_info[i], - } - state.out_list.append(out_dict) - state.finished = recv_obj.finished_reason[i] is not None - state.event.set() + else: + assert isinstance(recv_obj, BatchEmbeddingOut) + out_dict = { + "embedding": recv_obj.embeddings[i], + "meta_info": recv_obj.meta_info[i], + } + state.out_list.append(out_dict) + state.finished = recv_obj.finished_reason[i] is not None + state.event.set() def convert_logprob_style( self, @@ -656,6 +677,53 @@ def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): ) return top_logprobs + def shutdown(self): + """Synchronous API for TokenizerManager shutdown. + + Note! This function is supposed to be called from Engine, and + we assume only one Engine instance is running, thus it is written + without synchronization prmitive. + + This API is synchronous, it will set the should_stop_loop flag to + bring the event loop down, and closes the sockets and the ZMQ context + in a safe manner. + """ + # This flags the handle_loop() to stop(when finishing its last job.) + self.should_stop_loop = True + + asyncio.run(self._shutdown_async()) + + logger.info("TokenizerManager shutdown completed!") + + async def _shutdown_async(self): + """Asynchronous part of shutdown logic. + This is not an exposed API. (Shall we do an async shutdown??) + """ + # Wait for the handle_loop() is done, which means + # self.recv_from_detokenizer is finished with its job. + if self.handle_loop_task is not None: + try: + await self.handle_loop_task + except asyncio.CancelledError: + pass + + # Close the sender socket first. + if not self.send_to_router.closed: + self.send_to_router.close() + logger.info("send_to_router socket closed.") + + # Now close the receiver + if not self.recv_from_detokenizer.closed: + self.recv_from_detokenizer.close() + logger.info("recv_from_detokenizer socket closed.") + + # Finally close the context, which is shared by + # recv_from_detokenizer and send_to_router, so we only + # close it once. + if not self.recv_from_detokenizer.context.closed: + self.recv_from_detokenizer.context.term() + logger.info("ZeroMQ context terminated.") + async def _get_pixel_values(self, image_data: List[Union[str, bytes]]): if not image_data: return None, None, None @@ -723,14 +791,14 @@ async def _process_single_image( global global_processor -def init_global_processor(server_args: ServerArgs): +def init_global_processor(engine_args: EngineArgs): """Init the global processor for multi modal models.""" global global_processor transformers.logging.set_verbosity_error() global_processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + engine_args.tokenizer_path, + tokenizer_mode=engine_args.tokenizer_mode, + trust_remote_code=engine_args.trust_remote_code, ) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 123b1f5d5d..83079e2a6a 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -54,7 +54,7 @@ from sglang.srt.model_config import ModelConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import ( configure_logger, is_multimodal_model, @@ -74,60 +74,59 @@ def __init__( self, gpu_id: int, tp_rank: int, - server_args: ServerArgs, nccl_port: int, - model_overide_args: dict, + engine_args: EngineArgs, ): suppress_other_loggers() # Copy arguments self.gpu_id = gpu_id self.tp_rank = tp_rank - self.tp_size = server_args.tp_size - self.dp_size = server_args.dp_size - self.schedule_policy = server_args.schedule_policy - self.disable_regex_jump_forward = server_args.disable_regex_jump_forward + self.tp_size = engine_args.tp_size + self.dp_size = engine_args.dp_size + self.schedule_policy = engine_args.schedule_policy + self.disable_regex_jump_forward = engine_args.disable_regex_jump_forward # Init model and tokenizer self.model_config = ModelConfig( - server_args.model_path, - server_args.trust_remote_code, - context_length=server_args.context_length, - model_overide_args=model_overide_args, + engine_args.model_path, + engine_args.trust_remote_code, + context_length=engine_args.context_length, + model_overide_args=engine_args.model_override_args, ) self.model_runner = ModelRunner( model_config=self.model_config, - mem_fraction_static=server_args.mem_fraction_static, + mem_fraction_static=engine_args.mem_fraction_static, gpu_id=gpu_id, tp_rank=tp_rank, - tp_size=server_args.tp_size, + tp_size=engine_args.tp_size, nccl_port=nccl_port, - server_args=server_args, + engine_args=engine_args, ) - if server_args.skip_tokenizer_init: + if engine_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: if is_multimodal_model(self.model_config.hf_config.architectures): self.processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + engine_args.tokenizer_path, + tokenizer_mode=engine_args.tokenizer_mode, + trust_remote_code=engine_args.trust_remote_code, ) self.tokenizer = self.processor.tokenizer else: self.tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, + engine_args.tokenizer_path, + tokenizer_mode=engine_args.tokenizer_mode, + trust_remote_code=engine_args.trust_remote_code, ) self.max_total_num_tokens = self.model_runner.max_total_num_tokens - self.max_prefill_tokens = server_args.max_prefill_tokens + self.max_prefill_tokens = engine_args.max_prefill_tokens self.max_running_requests = min( ( self.max_total_num_tokens // 2 - if server_args.max_running_requests is None - else server_args.max_running_requests + if engine_args.max_running_requests is None + else engine_args.max_running_requests ), self.model_runner.req_to_token_pool.size - 1, ) @@ -137,12 +136,12 @@ def __init__( ) # Sync random seed - server_args.random_seed = broadcast_recv_input( - [server_args.random_seed], + engine_args.random_seed = broadcast_recv_input( + [engine_args.random_seed], self.tp_rank, self.model_runner.tp_group.cpu_group, )[0] - set_random_seed(server_args.random_seed) + set_random_seed(engine_args.random_seed) # Print info logger.info( @@ -154,8 +153,8 @@ def __init__( # Init cache if ( - server_args.chunked_prefill_size is not None - and server_args.disable_radix_cache + engine_args.chunked_prefill_size is not None + and engine_args.disable_radix_cache ): self.tree_cache = ChunkCache( req_to_token_pool=self.model_runner.req_to_token_pool, @@ -165,7 +164,7 @@ def __init__( self.tree_cache = RadixCache( req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, - disable=server_args.disable_radix_cache, + disable=engine_args.disable_radix_cache, ) self.tree_cache_metrics = {"total": 0, "hit": 0} self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache) @@ -177,46 +176,46 @@ def __init__( self.running_batch: ScheduleBatch = None self.out_pyobjs = [] self.decode_forward_ct = 0 - self.stream_interval = server_args.stream_interval + self.stream_interval = engine_args.stream_interval self.num_generated_tokens = 0 self.last_stats_tic = time.time() # Chunked prefill - self.chunked_prefill_size = server_args.chunked_prefill_size + self.chunked_prefill_size = engine_args.chunked_prefill_size self.current_inflight_req = None self.is_mixed_chunk = ( - self.chunked_prefill_size is not None and server_args.enable_mixed_chunk + self.chunked_prefill_size is not None and engine_args.enable_mixed_chunk ) # Init the FSM cache for constrained generation - if not server_args.skip_tokenizer_init: + if not engine_args.skip_tokenizer_init: self.regex_fsm_cache = FSMCache( - server_args.tokenizer_path, + engine_args.tokenizer_path, { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, + "tokenizer_mode": engine_args.tokenizer_mode, + "trust_remote_code": engine_args.trust_remote_code, }, - skip_tokenizer_init=server_args.skip_tokenizer_init, + skip_tokenizer_init=engine_args.skip_tokenizer_init, json_schema_mode=False, ) self.json_fsm_cache = FSMCache( - server_args.tokenizer_path, + engine_args.tokenizer_path, { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, + "tokenizer_mode": engine_args.tokenizer_mode, + "trust_remote_code": engine_args.trust_remote_code, }, - skip_tokenizer_init=server_args.skip_tokenizer_init, + skip_tokenizer_init=engine_args.skip_tokenizer_init, json_schema_mode=True, ) self.jump_forward_cache = JumpForwardCache() # Init new token estimation assert ( - server_args.schedule_conservativeness >= 0 + engine_args.schedule_conservativeness >= 0 ), "Invalid schedule_conservativeness" self.min_new_token_ratio = min( global_config.base_min_new_token_ratio - * server_args.schedule_conservativeness, + * engine_args.schedule_conservativeness, 1.0, ) self.new_token_ratio = self.min_new_token_ratio @@ -874,20 +873,18 @@ def update_weights(self, recv_req): def run_tp_server( gpu_id: int, tp_rank: int, - server_args: ServerArgs, nccl_port: int, - model_overide_args: dict, + engine_args: EngineArgs, ): """Run a tensor parallel model server.""" - configure_logger(server_args, prefix=f" TP{tp_rank}") + configure_logger(engine_args.log_level, prefix=f" TP{tp_rank}") try: model_server = ModelTpServer( gpu_id, tp_rank, - server_args, nccl_port, - model_overide_args, + engine_args, ) tp_cpu_group = model_server.model_runner.tp_group.cpu_group @@ -902,16 +899,15 @@ def run_tp_server( def launch_tp_servers( gpu_ids: List[int], tp_rank_range: List[int], - server_args: ServerArgs, nccl_port: int, - model_overide_args: dict, + engine_args: EngineArgs, ): """Launch multiple tensor parallel servers.""" procs = [] for i in tp_rank_range: proc = multiprocessing.Process( target=run_tp_server, - args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args), + args=(gpu_ids[i], i, nccl_port, engine_args), ) proc.start() procs.append(proc) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 3d40c9d755..eff132e58d 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -192,18 +192,18 @@ def from_schedule_batch( if ( forward_mode != ForwardMode.DECODE - or model_runner.server_args.disable_flashinfer + or model_runner.engine_args.disable_flashinfer ): ret.total_num_tokens = int(torch.sum(ret.seq_lens)) if forward_mode != ForwardMode.DECODE: ret.init_multimuldal_info(batch) - if model_runner.server_args.disable_flashinfer: + if model_runner.engine_args.disable_flashinfer: ret.init_triton_args(batch) flashinfer_use_ragged = False - if not model_runner.server_args.disable_flashinfer: + if not model_runner.engine_args.disable_flashinfer: if ( forward_mode != ForwardMode.DECODE and int(torch.sum(ret.seq_lens)) > 4096 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e6f5e74311..7c6d176935 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -54,7 +54,7 @@ ) from sglang.srt.model_config import AttentionArch, ModelConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import ( get_available_gpu_memory, is_generation_model, @@ -77,7 +77,7 @@ def __init__( tp_rank: int, tp_size: int, nccl_port: int, - server_args: ServerArgs, + engine_args: EngineArgs, ): # Parse args self.model_config = model_config @@ -86,16 +86,16 @@ def __init__( self.tp_rank = tp_rank self.tp_size = tp_size self.nccl_port = nccl_port - self.server_args = server_args + self.engine_args = engine_args self.is_multimodal_model = is_multimodal_model( self.model_config.hf_config.architectures ) global_server_args_dict.update( { - "disable_flashinfer": server_args.disable_flashinfer, - "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling, - "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, - "enable_mla": server_args.enable_mla, + "disable_flashinfer": engine_args.disable_flashinfer, + "disable_flashinfer_sampling": engine_args.disable_flashinfer_sampling, + "triton_attention_reduce_in_fp32": engine_args.triton_attention_reduce_in_fp32, + "enable_mla": engine_args.enable_mla, } ) @@ -110,8 +110,8 @@ def __init__( self.load_model() self.init_memory_pool( min_per_gpu_memory, - server_args.max_num_reqs, - server_args.max_total_tokens, + engine_args.max_num_reqs, + engine_args.max_total_tokens, ) self.init_cublas() self.init_flashinfer() @@ -122,14 +122,14 @@ def init_torch_distributed(self): torch.cuda.set_device(self.gpu_id) logger.info("Init nccl begin.") - if not self.server_args.enable_p2p_check: + if not self.engine_args.enable_p2p_check: monkey_patch_vllm_p2p_access_check(self.gpu_id) - if self.server_args.nccl_init_addr: - nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}" + if self.engine_args.nccl_init_addr: + nccl_init_method = f"tcp://{self.engine_args.nccl_init_addr}" else: nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}" - set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) + set_custom_all_reduce(not self.engine_args.disable_custom_all_reduce) init_distributed_environment( backend="nccl", world_size=self.tp_size, @@ -146,7 +146,7 @@ def init_torch_distributed(self): # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph, # so we disable padding in cuda graph. if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)): - self.server_args.disable_cuda_graph_padding = True + self.engine_args.disable_cuda_graph_padding = True logger.info( "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism." ) @@ -169,20 +169,20 @@ def load_model(self): logger.info( "Compute capability below sm80. Use float16 due to lack of bfloat16 support." ) - self.server_args.dtype = "float16" + self.engine_args.dtype = "float16" if torch.cuda.get_device_capability()[1] < 5: raise RuntimeError("SGLang only supports sm75 and above.") monkey_patch_vllm_dummy_weight_loader() self.device_config = DeviceConfig() - self.load_config = LoadConfig(load_format=self.server_args.load_format) + self.load_config = LoadConfig(load_format=self.engine_args.load_format) self.vllm_model_config = VllmModelConfig( - model=self.server_args.model_path, - quantization=self.server_args.quantization, + model=self.engine_args.model_path, + quantization=self.engine_args.quantization, tokenizer=None, tokenizer_mode=None, - trust_remote_code=self.server_args.trust_remote_code, - dtype=self.server_args.dtype, + trust_remote_code=self.engine_args.trust_remote_code, + dtype=self.engine_args.dtype, seed=42, skip_tokenizer_init=True, ) @@ -215,7 +215,7 @@ def load_model(self): else None ) self.is_generation = is_generation_model( - self.model_config.hf_config.architectures, self.server_args.is_embedding + self.model_config.hf_config.architectures, self.engine_args.is_embedding ) logger.info( @@ -245,11 +245,11 @@ def update_weights(self, model_path: str, load_format: str): # TODO: Use a better method to check this vllm_model_config = VllmModelConfig( model=model_path, - quantization=self.server_args.quantization, + quantization=self.engine_args.quantization, tokenizer=None, tokenizer_mode=None, - trust_remote_code=self.server_args.trust_remote_code, - dtype=self.server_args.dtype, + trust_remote_code=self.engine_args.trust_remote_code, + dtype=self.engine_args.dtype, seed=42, skip_tokenizer_init=True, ) @@ -303,8 +303,8 @@ def model_load_weights(model, iter): return False, message self.model = model - self.server_args.model_path = model_path - self.server_args.load_format = load_format + self.engine_args.model_path = model_path + self.engine_args.load_format = load_format self.vllm_model_config = vllm_model_config self.load_config = load_config self.model_config.path = model_path @@ -318,7 +318,7 @@ def profile_max_num_token(self, total_gpu_memory: int): ) if ( self.model_config.attention_arch == AttentionArch.MLA - and self.server_args.enable_mla + and self.engine_args.enable_mla ): cell_size = ( (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) @@ -345,10 +345,10 @@ def init_memory_pool( max_num_reqs: int = None, max_total_tokens: int = None, ): - if self.server_args.kv_cache_dtype == "auto": + if self.engine_args.kv_cache_dtype == "auto": self.kv_cache_dtype = self.dtype - elif self.server_args.kv_cache_dtype == "fp8_e5m2": - if self.server_args.disable_flashinfer or self.server_args.enable_mla: + elif self.engine_args.kv_cache_dtype == "fp8_e5m2": + if self.engine_args.disable_flashinfer or self.engine_args.enable_mla: logger.warning( "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype" ) @@ -357,7 +357,7 @@ def init_memory_pool( self.kv_cache_dtype = torch.float8_e5m2 else: raise ValueError( - f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." + f"Unsupported kv_cache_dtype: {self.engine_args.kv_cache_dtype}." ) self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) @@ -392,7 +392,7 @@ def init_memory_pool( ) if ( self.model_config.attention_arch == AttentionArch.MLA - and self.server_args.enable_mla + and self.engine_args.enable_mla ): self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, @@ -403,7 +403,7 @@ def init_memory_pool( ) logger.info("using MLA Triton implementaion, flashinfer is disabled") # FIXME: temporarily only Triton MLA is supported - self.server_args.disable_flashinfer = True + self.engine_args.disable_flashinfer = True else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, @@ -428,7 +428,7 @@ def init_cublas(self): def init_flashinfer(self): """Init flashinfer attention kernel wrappers.""" - if self.server_args.disable_flashinfer: + if self.engine_args.disable_flashinfer: assert ( self.sliding_window_size is None ), "turn on flashinfer to support window attention" @@ -495,13 +495,13 @@ def init_cuda_graphs(self): from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner - if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer: + if self.engine_args.disable_cuda_graph or self.engine_args.disable_flashinfer: self.cuda_graph_runner = None return logger.info("Capture cuda graph begin. This can take up to several minutes.") - if self.server_args.disable_cuda_graph_padding: + if self.engine_args.disable_cuda_graph_padding: batch_size_list = list(range(1, 32)) + [64, 128] else: batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)] @@ -509,8 +509,8 @@ def init_cuda_graphs(self): self.cuda_graph_runner = CudaGraphRunner( self, max_batch_size_to_capture=max(batch_size_list), - use_torch_compile=self.server_args.enable_torch_compile, - disable_padding=self.server_args.disable_cuda_graph_padding, + use_torch_compile=self.engine_args.enable_torch_compile, + disable_padding=self.engine_args.disable_cuda_graph_padding, ) try: self.cuda_graph_runner.capture(batch_size_list) diff --git a/python/sglang/srt/serving/__init__.py b/python/sglang/srt/serving/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/sglang/srt/serving/engine.py b/python/sglang/srt/serving/engine.py new file mode 100644 index 0000000000..4957a82ad4 --- /dev/null +++ b/python/sglang/srt/serving/engine.py @@ -0,0 +1,243 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import logging +import multiprocessing as mp +import os +import sys +from dataclasses import fields +from typing import Dict, List, Optional, Union + +from sglang.srt.managers.controller_multi import ( + start_controller_process as start_controller_process_multi, +) +from sglang.srt.managers.controller_single import launch_tp_servers +from sglang.srt.managers.controller_single import ( + start_controller_process as start_controller_process_single, +) +from sglang.srt.managers.detokenizer_manager import start_detokenizer_process +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.serving.engine_args import EngineArgs +from sglang.srt.utils import kill_child_process, prepare_model, prepare_tokenizer + +logger = logging.getLogger(__name__) + + +class Engine: + """ + The core LLM Engine + """ + + def __init__(self, engine_args: EngineArgs): + self.engine_args = engine_args + + # Spin up the engine. + self.startup() + + def startup(self): + """ + Start the Engine, corresponding to the shutdown method. + """ + # Prepare model and tokenizer + self.engine_args.model_path = prepare_model(self.engine_args.model_path) + self.engine_args.tokenizer_path = prepare_tokenizer( + self.engine_args.tokenizer_path + ) + + # Launch processes for multi-node tensor parallelism + self.tp_procs = None + if self.engine_args.nnodes > 1 and self.engine_args.node_rank != 0: + tp_size_local = self.engine_args.tp_size // self.engine_args.nnodes + gpu_ids = [ + i for _ in range(self.engine_args.nnodes) for i in range(tp_size_local) + ] + tp_rank_range = list( + range( + self.engine_args.node_rank * tp_size_local, + (self.engine_args.node_rank + 1) * tp_size_local, + ) + ) + self.tp_procs = launch_tp_servers( + gpu_ids, + tp_rank_range, + self.engine_args.nccl_ports[0], + self.engine_args, + ) + try: + for p in self.tp_procs: + p.join() + finally: + kill_child_process(os.getpid(), including_parent=False) + return + + # Initialize TokenizerManager and other processes + self.tokenizer_manager = TokenizerManager(self.engine_args) + + pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) + pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) + + if self.engine_args.dp_size == 1: + start_process = start_controller_process_single + else: + start_process = start_controller_process_multi + + self.proc_controller = mp.Process( + target=start_process, + args=( + self.engine_args, + pipe_controller_writer, + ), + ) + self.proc_controller.start() + + self.proc_detoken = mp.Process( + target=start_detokenizer_process, + args=( + self.engine_args, + pipe_detoken_writer, + ), + ) + self.proc_detoken.start() + + # Wait for the model to finish loading + controller_init_state = pipe_controller_reader.recv() + detoken_init_state = pipe_detoken_reader.recv() + + if controller_init_state != "init ok" or detoken_init_state != "init ok": + self.proc_controller.kill() + self.proc_detoken.kill() + raise RuntimeError( + "Initialization failed. " + f"controller_init_state: {controller_init_state}, " + f"detoken_init_state: {detoken_init_state}" + ) + + assert self.proc_controller.is_alive() and self.proc_detoken.is_alive() + logger.info(f"Engine successfully started.") + + def shutdown(self): + # Shutdown the tokenizer_manager first, to make sure no more requests come in. + self.tokenizer_manager.shutdown() + + # Once tokenizer_manager is shut down, we can safely shutdown Engine + + # Terminate and join TP processes if they exist + if self.tp_procs: + for proc in self.tp_procs: + if proc.is_alive(): + proc.terminate() + proc.join() + + # Shutdown proc_controller(which processes requests from tokenizer_manager), and + # proc_detoken(which precoesses response to final ouput)/ + for proc in [self.proc_controller, self.proc_detoken]: + if proc.is_alive(): + proc.terminate() + proc.join() + + +class LLM: + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = True, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + seed: int = 0, + context_length: Optional[int] = None, + **kwargs, + ) -> None: + engine_arg_fields = {field.name for field in fields(EngineArgs)} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in engine_arg_fields} + + # Warn about any extra kwargs + extra_kwargs = {k: v for k, v in kwargs.items() if k not in engine_arg_fields} + if extra_kwargs: + logger.warn(f"Warning: Ignored unexpected kwargs: {extra_kwargs}") + + engine_args = EngineArgs( + model_path=model, + tokenizer_path=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tp_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + random_seed=seed, + context_length=context_length, + **filtered_kwargs, + ) + self.llm_engine = Engine(engine_args) + + def generate( + self, + prompts: Optional[Union[List[str], str]] = None, + sampling_params: Optional[ + Union["SamplingParams", List["SamplingParams"]] + ] = None, + prompt_token_ids: Optional[Union[List[List[int]], List[int]]] = None, + ): + if prompts is None and prompt_token_ids is None: + raise ValueError("Either 'prompts' or 'prompt_token_ids' must be provided.") + + if isinstance(prompts, str): + prompts = [prompts] + + if sampling_params is None: + sampling_params_dicts = [{} for _ in prompts] + elif isinstance(sampling_params, List): + sampling_params_dicts = [sp.to_dict() for sp in sampling_params] + else: + sampling_params_dicts = [sampling_params.to_srt_kwargs() for _ in prompts] + + gen_req_input = GenerateReqInput( + text=prompts, + input_ids=prompt_token_ids, + sampling_params=sampling_params_dicts, + ) + + try: + request = None + + # Use a synchronous call to run the async helper + results = asyncio.run(self._generate_async_helper(gen_req_input, request)) + + # Shutdown the engine + self.llm_engine.shutdown() + + return results + + except ValueError as e: + raise e + + async def _generate_async_helper(self, gen_req_input, request): + results = [] + async for response in self.llm_engine.tokenizer_manager.generate_request( + gen_req_input, request + ): + if isinstance(response, list): + # if gen_req_input is a list input, it is deemed a batched input, then the response is already a list + results.extend(response) + else: + results.append(response) + return results diff --git a/python/sglang/srt/serving/engine_args.py b/python/sglang/srt/serving/engine_args.py new file mode 100644 index 0000000000..974f44254b --- /dev/null +++ b/python/sglang/srt/serving/engine_args.py @@ -0,0 +1,216 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""The LLM engine arguments.""" + +import dataclasses +import logging +import os +import random +from typing import List, Optional, Union + +from sglang.srt.constrained import disable_cache +from sglang.srt.utils import ( + allocate_init_ports, + assert_pkg_version, + enable_show_time_cost, + maybe_set_triton_cache_manager, + set_ulimit, +) + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class EngineArgs: + # Model and tokenizer + model_path: str + tokenizer_path: Optional[str] = None + tokenizer_mode: str = "auto" + skip_tokenizer_init: bool = False + load_format: str = "auto" + dtype: str = "auto" + kv_cache_dtype: str = "auto" + trust_remote_code: bool = True + context_length: Optional[int] = None + quantization: Optional[str] = None + served_model_name: Optional[str] = None + random_seed: Optional[int] = None + stream_interval: int = 1 + tokenizer_port: Optional[int] = 0 + detokenizer_port: Optional[int] = 0 + controller_port: Optional[int] = 0 + is_embedding: bool = False + model_override_args: Optional[dict] = None + + # Scheduling + mem_fraction_static: Optional[float] = None + max_running_requests: Optional[int] = None + max_num_reqs: Optional[int] = None + max_total_tokens: Optional[int] = None + chunked_prefill_size: int = 8192 + max_prefill_tokens: int = 16384 + schedule_policy: str = "lpm" + schedule_conservativeness: float = 1.0 + + # Parallelism + tp_size: int = 1 + dp_size: int = 1 + load_balance_method: str = "round_robin" + nccl_init_addr: Optional[str] = None + nccl_ports: Optional[List[int]] = None + additional_ports: Optional[Union[List[int], int]] = None + nnodes: int = 1 + node_rank: Optional[int] = None + + # Optimization + disable_flashinfer: bool = False + disable_flashinfer_sampling: bool = False + disable_radix_cache: bool = False + disable_regex_jump_forward: bool = False + disable_cuda_graph: bool = False + disable_cuda_graph_padding: bool = False + disable_disk_cache: bool = False + disable_custom_all_reduce: bool = False + enable_mixed_chunk: bool = False + enable_torch_compile: bool = False + enable_p2p_check: bool = False + enable_mla: bool = False + triton_attention_reduce_in_fp32: bool = False + efficient_weight_load: bool = False + + # Observability + log_level: str = "info" + log_level_http: Optional[str] = None + log_requests: bool = False + show_time_cost: bool = False + + def __post_init__(self): + if self.tokenizer_path is None: + self.tokenizer_path = self.model_path + + if self.served_model_name is None: + self.served_model_name = self.model_path + + if self.chunked_prefill_size <= 0: + # Disable chunked prefill + self.chunked_prefill_size = None + + if self.mem_fraction_static is None: + if self.tp_size >= 16: + self.mem_fraction_static = 0.79 + elif self.tp_size >= 8: + self.mem_fraction_static = 0.83 + elif self.tp_size >= 4: + self.mem_fraction_static = 0.85 + elif self.tp_size >= 2: + self.mem_fraction_static = 0.87 + else: + self.mem_fraction_static = 0.88 + + if isinstance(self.additional_ports, int): + self.additional_ports = [self.additional_ports] + elif self.additional_ports is None: + self.additional_ports = [] + + if self.random_seed is None: + self.random_seed = random.randint(0, 1 << 30) + + self._check_args() + + self._alloc_port_args() + + self._set_envs_and_config() + + def _alloc_port_args(self): + if isinstance(self.additional_ports, int): + self.additional_ports = [self.additional_ports] + elif self.additional_ports is None: + self.additional_ports = [] + + _, ports = allocate_init_ports( + 30000, + self.additional_ports, + self.dp_size, + ) + self.tokenizer_port = ports[0] + self.controller_port = ports[1] + self.detokenizer_port = ports[2] + self.nccl_ports = ports[3:] + logger.info( + f"Allocated port args: tokenizer_port({self.tokenizer_port}), controller_port({self.controller_port})," + f"detokenizer_port({self.detokenizer_port}), nccl_ports({self.nccl_ports})" + ) + + def _check_args(self): + assert ( + self.tp_size % self.nnodes == 0 + ), "tp_size must be divisible by number of nodes" + + assert not ( + self.dp_size > 1 and self.node_rank is not None + ), "multi-node data parallel is not supported" + + if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: + logger.info( + "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True" + ) + self.trust_remote_code = False + + if "gemma-2" in self.model_path.lower(): + logger.info( + f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer." + ) + # FIXME: compatibility with radix attention + self.disable_radix_cache = True + # FIXME: compatibility with jump forward + self.disable_regex_jump_forward = True + self.disable_flashinfer = False + # FIXME: compatibility with chunked prefill + self.chunked_prefill_size = None + + def _set_envs_and_config(self): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + # Set ulimit + set_ulimit() + + # Enable show time cost for debugging + if self.show_time_cost: + enable_show_time_cost() + + # Disable disk cache + if self.disable_disk_cache: + disable_cache() + + # Fix triton bugs + if self.tp_size * self.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. + maybe_set_triton_cache_manager() + + # Check flashinfer version + if not self.disable_flashinfer: + assert_pkg_version( + "flashinfer", + "0.1.6", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/serving/server.py similarity index 66% rename from python/sglang/srt/server.py rename to python/sglang/srt/serving/server.py index 5ba2a45e70..eb4bf9802f 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/serving/server.py @@ -24,6 +24,7 @@ import logging import multiprocessing as mp import os +import sys import threading import time from http import HTTPStatus @@ -40,22 +41,12 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.srt.constrained import disable_cache from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.controller_multi import ( - start_controller_process as start_controller_process_multi, -) -from sglang.srt.managers.controller_single import launch_tp_servers -from sglang.srt.managers.controller_single import ( - start_controller_process as start_controller_process_single, -) -from sglang.srt.managers.detokenizer_manager import start_detokenizer_process from sglang.srt.managers.io_struct import ( EmbeddingReqInput, GenerateReqInput, UpdateWeightReqInput, ) -from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, v1_batches, @@ -70,18 +61,12 @@ v1_retrieve_file_content, ) from sglang.srt.openai_api.protocol import ModelCard, ModelList -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.serving.engine import Engine, EngineArgs +from sglang.srt.serving.server_args import ServerArgs from sglang.srt.utils import ( add_api_key_middleware, - allocate_init_ports, - assert_pkg_version, configure_logger, - enable_show_time_cost, kill_child_process, - maybe_set_triton_cache_manager, - prepare_model, - prepare_tokenizer, - set_ulimit, ) from sglang.utils import get_exception_traceback @@ -91,7 +76,10 @@ app = FastAPI() -tokenizer_manager = None +engine: Engine = None + +# for OpenAI files API +file_storage_pth: str @app.get("/health") @@ -107,7 +95,7 @@ async def health_generate(request: Request) -> Response: text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7} ) try: - async for _ in tokenizer_manager.generate_request(gri, request): + async for _ in engine.tokenizer_manager.generate_request(gri, request): break return Response(status_code=200) except Exception as e: @@ -118,20 +106,20 @@ async def health_generate(request: Request) -> Response: @app.get("/get_model_info") async def get_model_info(): result = { - "model_path": tokenizer_manager.model_path, - "is_generation": tokenizer_manager.is_generation, + "model_path": engine.tokenizer_manager.model_path, + "is_generation": engine.tokenizer_manager.is_generation, } return result @app.get("/get_server_args") async def get_server_args(): - return dataclasses.asdict(tokenizer_manager.server_args) + return dataclasses.asdict(engine.tokenizer_manager.server_args) @app.get("/flush_cache") async def flush_cache(): - tokenizer_manager.flush_cache() + engine.tokenizer_manager.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. " "(When there are running or waiting requests, the operation will not be performed.)\n", @@ -142,7 +130,7 @@ async def flush_cache(): @app.post("/update_weights") async def update_weights(obj: UpdateWeightReqInput, request: Request): - success, message = await tokenizer_manager.update_weights(obj, request) + success, message = await engine.tokenizer_manager.update_weights(obj, request) content = {"message": message, "success": str(success)} if success: return JSONResponse( @@ -162,7 +150,9 @@ async def generate_request(obj: GenerateReqInput, request: Request): async def stream_results(): try: - async for out in tokenizer_manager.generate_request(obj, request): + async for out in engine.tokenizer_manager.generate_request( + obj, request + ): yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" except ValueError as e: out = {"error": {"message": str(e)}} @@ -172,11 +162,13 @@ async def stream_results(): return StreamingResponse( stream_results(), media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(obj), + background=engine.tokenizer_manager.create_abort_task(obj), ) else: try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() + ret = await engine.tokenizer_manager.generate_request( + obj, request + ).__anext__() return ret except ValueError as e: return JSONResponse( @@ -191,7 +183,7 @@ async def stream_results(): async def encode_request(obj: EmbeddingReqInput, request: Request): """Handle an embedding request.""" try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() + ret = await engine.tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: return JSONResponse( @@ -205,24 +197,24 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): @app.post("/v1/completions") async def openai_v1_completions(raw_request: Request): - return await v1_completions(tokenizer_manager, raw_request) + return await v1_completions(engine.tokenizer_manager, raw_request) @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): - return await v1_chat_completions(tokenizer_manager, raw_request) + return await v1_chat_completions(engine.tokenizer_manager, raw_request) @app.post("/v1/embeddings") async def openai_v1_embeddings(raw_request: Request): - response = await v1_embeddings(tokenizer_manager, raw_request) + response = await v1_embeddings(engine.tokenizer_manager, raw_request) return response @app.get("/v1/models") def available_models(): """Show available models.""" - served_model_names = [tokenizer_manager.served_model_name] + served_model_names = [engine.tokenizer_manager.served_model_name] model_cards = [] for served_model_name in served_model_names: model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) @@ -231,9 +223,7 @@ def available_models(): @app.post("/v1/files") async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): - return await v1_files_create( - file, purpose, tokenizer_manager.server_args.file_storage_pth - ) + return await v1_files_create(file, purpose, file_storage_pth) @app.delete("/v1/files/{file_id}") @@ -244,13 +234,13 @@ async def delete_file(file_id: str): @app.post("/v1/batches") async def openai_v1_batches(raw_request: Request): - return await v1_batches(tokenizer_manager, raw_request) + return await v1_batches(engine.tokenizer_manager, raw_request) @app.post("/v1/batches/{batch_id}/cancel") async def cancel_batches(batch_id: str): # https://platform.openai.com/docs/api-reference/batch/cancel - return await v1_cancel_batch(tokenizer_manager, batch_id) + return await v1_cancel_batch(engine.tokenizer_manager, batch_id) @app.get("/v1/batches/{batch_id}") @@ -272,102 +262,23 @@ async def retrieve_file_content(file_id: str): def launch_server( server_args: ServerArgs, - model_overide_args: Optional[dict] = None, pipe_finish_writer: Optional[mp.connection.Connection] = None, ): """Launch an HTTP server.""" - global tokenizer_manager - configure_logger(server_args) + configure_logger(server_args.log_level) - server_args.check_server_args() - _set_envs_and_config(server_args) + global engine + engine = Engine(server_args.engine_args) - # Allocate ports for inter-process communications - server_args.port, server_args.additional_ports = allocate_init_ports( - server_args.port, - server_args.additional_ports, - server_args.dp_size, - ) - ports = server_args.additional_ports - port_args = PortArgs( - tokenizer_port=ports[0], - controller_port=ports[1], - detokenizer_port=ports[2], - nccl_ports=ports[3:], - ) - logger.info(f"{server_args=}") - - # Use model from www.modelscope.cn, first download the model. - server_args.model_path = prepare_model(server_args.model_path) - server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path) - - # Launch processes for multi-node tensor parallelism - if server_args.nnodes > 1 and server_args.node_rank != 0: - tp_size_local = server_args.tp_size // server_args.nnodes - gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)] - tp_rank_range = list( - range( - server_args.node_rank * tp_size_local, - (server_args.node_rank + 1) * tp_size_local, - ) - ) - procs = launch_tp_servers( - gpu_ids, - tp_rank_range, - server_args, - ports[3], - model_overide_args, - ) - - try: - for p in procs: - p.join() - finally: - kill_child_process(os.getpid(), including_parent=False) - return - - # Launch processes - tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) - pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) - pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) - - if server_args.dp_size == 1: - start_controller_process = start_controller_process_single - else: - start_controller_process = start_controller_process_multi - - proc_controller = mp.Process( - target=start_controller_process, - args=(server_args, port_args, pipe_controller_writer, model_overide_args), - ) - proc_controller.start() - - proc_detoken = mp.Process( - target=start_detokenizer_process, - args=( - server_args, - port_args, - pipe_detoken_writer, - ), - ) - proc_detoken.start() - - # Wait for the model to finish loading - controller_init_state = pipe_controller_reader.recv() - detoken_init_state = pipe_detoken_reader.recv() - - if controller_init_state != "init ok" or detoken_init_state != "init ok": - proc_controller.kill() - proc_detoken.kill() - raise RuntimeError( - "Initialization failed. " - f"controller_init_state: {controller_init_state}, " - f"detoken_init_state: {detoken_init_state}" + load_chat_template_for_openai_api( + engine.tokenizer_manager, server_args.chat_template ) - assert proc_controller.is_alive() and proc_detoken.is_alive() + + if server_args.file_storage_pth: + global file_storage_pth + file_storage_pth = server_args.file_storage_pth # Add api key authorization if server_args.api_key: @@ -393,41 +304,6 @@ def launch_server( t.join() -def _set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - - # Set ulimit - set_ulimit() - - # Enable show time cost for debugging - if server_args.show_time_cost: - enable_show_time_cost() - - # Disable disk cache - if server_args.disable_disk_cache: - disable_cache() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Check flashinfer version - if not server_args.disable_flashinfer: - assert_pkg_version( - "flashinfer", - "0.1.6", - "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.", - ) - - def _wait_and_warmup(server_args, pipe_finish_writer, pid): headers = {} url = server_args.url() @@ -501,18 +377,16 @@ class Runtime: def __init__( self, log_level: str = "error", - model_overide_args: Optional[dict] = None, + model_override_args: Optional[dict] = None, *args, **kwargs, ): """See the arguments in server_args.py::ServerArgs""" - self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) - - # Pre-allocate ports - self.server_args.port, self.server_args.additional_ports = allocate_init_ports( - self.server_args.port, - self.server_args.additional_ports, - self.server_args.dp_size, + self.server_args = ServerArgs.from_kwargs( + *args, + log_level=log_level, + model_override_args=model_override_args, + **kwargs, ) self.url = self.server_args.url() @@ -522,10 +396,9 @@ def __init__( self.pid = None pipe_reader, pipe_writer = mp.Pipe(duplex=False) - proc = mp.Process( target=launch_server, - args=(self.server_args, model_overide_args, pipe_writer), + args=(self.server_args, pipe_writer), ) proc.start() pipe_writer.close() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/serving/server_args.py similarity index 71% rename from python/sglang/srt/server_args.py rename to python/sglang/srt/serving/server_args.py index 8a56c02e16..ccbe794088 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/serving/server_args.py @@ -18,113 +18,79 @@ import argparse import dataclasses import logging -import random -from typing import List, Optional, Union +from dataclasses import fields +from typing import Dict, List, Optional + +from sglang.srt.serving.engine_args import EngineArgs logger = logging.getLogger(__name__) @dataclasses.dataclass class ServerArgs: - # Model and tokenizer - model_path: str - tokenizer_path: Optional[str] = None - tokenizer_mode: str = "auto" - skip_tokenizer_init: bool = False - load_format: str = "auto" - dtype: str = "auto" - kv_cache_dtype: str = "auto" - trust_remote_code: bool = True - context_length: Optional[int] = None - quantization: Optional[str] = None - served_model_name: Optional[str] = None - chat_template: Optional[str] = None - is_embedding: bool = False + # The core engine args + engine_args: EngineArgs # = field(default_factory=EngineArgs) - # Port + # + # The server specifc args + # + # Connection host: str = "127.0.0.1" port: int = 30000 - additional_ports: Optional[Union[List[int], int]] = None - - # Memory and scheduling - mem_fraction_static: Optional[float] = None - max_running_requests: Optional[int] = None - max_num_reqs: Optional[int] = None - max_total_tokens: Optional[int] = None - chunked_prefill_size: int = 8192 - max_prefill_tokens: int = 16384 - schedule_policy: str = "lpm" - schedule_conservativeness: float = 1.0 - - # Other runtime options - tp_size: int = 1 - stream_interval: int = 1 - random_seed: Optional[int] = None - - # Logging - log_level: str = "info" - log_level_http: Optional[str] = None - log_requests: bool = False - show_time_cost: bool = False - # Other - api_key: Optional[str] = None + # OpenAI API + chat_template: Optional[str] = None file_storage_pth: str = "SGLang_storage" - # Data parallelism - dp_size: int = 1 - load_balance_method: str = "round_robin" + # Authentication + api_key: Optional[str] = None - # Optimization/debug options - disable_flashinfer: bool = False - disable_flashinfer_sampling: bool = False - disable_radix_cache: bool = False - disable_regex_jump_forward: bool = False - disable_cuda_graph: bool = False - disable_cuda_graph_padding: bool = False - disable_disk_cache: bool = False - disable_custom_all_reduce: bool = False - enable_mixed_chunk: bool = False - enable_torch_compile: bool = False - enable_p2p_check: bool = False - enable_mla: bool = False - triton_attention_reduce_in_fp32: bool = False + def __post_init__(self): ... - # Distributed args - nccl_init_addr: Optional[str] = None - nnodes: int = 1 - node_rank: Optional[int] = None + def __getattr__(self, item): + # Avoid recursion by checking if `engine_args` exists first + if item == "engine_args" or "engine_args" not in self.__dict__: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) - def __post_init__(self): - if self.tokenizer_path is None: - self.tokenizer_path = self.model_path + # Forward attribute access to engine_args if not found in ServerArgs. + # For attribute in server_args, it will be found in ServerArgs's __dict__ + # and no entry into this function. + if hasattr(self.engine_args, item): + return getattr(self.engine_args, item) + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) - if self.served_model_name is None: - self.served_model_name = self.model_path + def __setattr__(self, key, value): + # If the attribute exists in ServerArgs, set it directly + if key in {f.name for f in fields(ServerArgs)}: + super().__setattr__(key, value) + # If the attribute exists in EngineArgs, forward it to engine_args + elif hasattr(self.engine_args, key): + setattr(self.engine_args, key, value) + else: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{key}'" + ) - if self.chunked_prefill_size <= 0: - # Disable chunked prefill - self.chunked_prefill_size = None + @classmethod + def from_kwargs( + cls, + *args, + **kwargs: Dict[str, any], + ) -> "ServerArgs": + """Creates a ServerArgs instance by separating EngineArgs and ServerArgs parameters.""" + engine_args_fields = {field.name for field in fields(EngineArgs)} + server_args_fields = {field.name for field in fields(cls)} - {"engine_args"} - if self.mem_fraction_static is None: - if self.tp_size >= 16: - self.mem_fraction_static = 0.79 - elif self.tp_size >= 8: - self.mem_fraction_static = 0.83 - elif self.tp_size >= 4: - self.mem_fraction_static = 0.85 - elif self.tp_size >= 2: - self.mem_fraction_static = 0.87 - else: - self.mem_fraction_static = 0.88 + engine_args_dict = {k: v for k, v in kwargs.items() if k in engine_args_fields} + server_args_dict = {k: v for k, v in kwargs.items() if k in server_args_fields} - if isinstance(self.additional_ports, int): - self.additional_ports = [self.additional_ports] - elif self.additional_ports is None: - self.additional_ports = [] + engine_args = EngineArgs(*args, **engine_args_dict) - if self.random_seed is None: - self.random_seed = random.randint(0, 1 << 30) + return cls(engine_args=engine_args, **server_args_dict) @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -137,7 +103,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--tokenizer-path", type=str, - default=ServerArgs.tokenizer_path, + default=EngineArgs.tokenizer_path, help="The path of the tokenizer.", ) parser.add_argument( @@ -156,7 +122,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--tokenizer-mode", type=str, - default=ServerArgs.tokenizer_mode, + default=EngineArgs.tokenizer_mode, choices=["auto", "slow"], help="Tokenizer mode. 'auto' will use the fast " "tokenizer if available, and 'slow' will " @@ -170,7 +136,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--load-format", type=str, - default=ServerArgs.load_format, + default=EngineArgs.load_format, choices=["auto", "pt", "safetensors", "npcache", "dummy"], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' @@ -186,7 +152,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--dtype", type=str, - default=ServerArgs.dtype, + default=EngineArgs.dtype, choices=["auto", "half", "float16", "bfloat16", "float", "float32"], help="Data type for model weights and activations.\n\n" '* "auto" will use FP16 precision for FP32 and FP16 models, and ' @@ -200,7 +166,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--kv-cache-dtype", type=str, - default=ServerArgs.kv_cache_dtype, + default=EngineArgs.kv_cache_dtype, choices=["auto", "fp8_e5m2"], help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.', ) @@ -217,13 +183,13 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--context-length", type=int, - default=ServerArgs.context_length, + default=EngineArgs.context_length, help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).", ) parser.add_argument( "--quantization", type=str, - default=ServerArgs.quantization, + default=EngineArgs.quantization, choices=[ "awq", "fp8", @@ -239,7 +205,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--served-model-name", type=str, - default=ServerArgs.served_model_name, + default=EngineArgs.served_model_name, help="Override the model name returned by the v1/models endpoint in OpenAI API server.", ) parser.add_argument( @@ -251,81 +217,81 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--mem-fraction-static", type=float, - default=ServerArgs.mem_fraction_static, + default=EngineArgs.mem_fraction_static, help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.", ) parser.add_argument( "--max-running-requests", type=int, - default=ServerArgs.max_running_requests, + default=EngineArgs.max_running_requests, help="The maximum number of running requests.", ) parser.add_argument( "--max-num-reqs", type=int, - default=ServerArgs.max_num_reqs, + default=EngineArgs.max_num_reqs, help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.", ) parser.add_argument( "--max-total-tokens", type=int, - default=ServerArgs.max_total_tokens, + default=EngineArgs.max_total_tokens, help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", type=int, - default=ServerArgs.chunked_prefill_size, + default=EngineArgs.chunked_prefill_size, help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill", ) parser.add_argument( "--max-prefill-tokens", type=int, - default=ServerArgs.max_prefill_tokens, + default=EngineArgs.max_prefill_tokens, help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.", ) parser.add_argument( "--schedule-policy", type=str, - default=ServerArgs.schedule_policy, + default=EngineArgs.schedule_policy, choices=["lpm", "random", "fcfs", "dfs-weight"], help="The scheduling policy of the requests.", ) parser.add_argument( "--schedule-conservativeness", type=float, - default=ServerArgs.schedule_conservativeness, + default=EngineArgs.schedule_conservativeness, help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.", ) parser.add_argument( "--tensor-parallel-size", "--tp-size", type=int, - default=ServerArgs.tp_size, + default=EngineArgs.tp_size, help="The tensor parallelism size.", ) parser.add_argument( "--stream-interval", type=int, - default=ServerArgs.stream_interval, + default=EngineArgs.stream_interval, help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher", ) parser.add_argument( "--random-seed", type=int, - default=ServerArgs.random_seed, + default=EngineArgs.random_seed, help="The random seed.", ) parser.add_argument( "--log-level", type=str, - default=ServerArgs.log_level, + default=EngineArgs.log_level, help="The logging level of all loggers.", ) parser.add_argument( "--log-level-http", type=str, - default=ServerArgs.log_level_http, + default=EngineArgs.log_level_http, help="The logging level of HTTP server. If not set, reuse --log-level by default.", ) parser.add_argument( @@ -356,13 +322,13 @@ def add_cli_args(parser: argparse.ArgumentParser): "--data-parallel-size", "--dp-size", type=int, - default=ServerArgs.dp_size, + default=EngineArgs.dp_size, help="The data parallelism size.", ) parser.add_argument( "--load-balance-method", type=str, - default=ServerArgs.load_balance_method, + default=EngineArgs.load_balance_method, help="The load balancing strategy for data parallelism.", choices=[ "round_robin", @@ -377,7 +343,7 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The nccl init address of multi-node server.", ) parser.add_argument( - "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes." + "--nnodes", type=int, default=EngineArgs.nnodes, help="The number of nodes." ) parser.add_argument("--node-rank", type=int, help="The node rank.") @@ -441,13 +407,13 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--enable-mla", action="store_true", - help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.", + help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2", ) parser.add_argument( - "--triton-attention-reduce-in-fp32", + "--attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." - "This only affects Triton attention kernels.", + "This only affects Triton attention kernels", ) parser.add_argument( "--efficient-weight-load", @@ -459,32 +425,23 @@ def add_cli_args(parser: argparse.ArgumentParser): def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size args.dp_size = args.data_parallel_size - attrs = [attr.name for attr in dataclasses.fields(cls)] - return cls(**{attr: getattr(args, attr) for attr in attrs}) - def url(self): - return f"http://{self.host}:{self.port}" + # Init EngineArgs + engine_args_fields = {field.name for field in dataclasses.fields(EngineArgs)} + engine_args_dict = { + key: getattr(args, key) for key in engine_args_fields if hasattr(args, key) + } + engine_args = EngineArgs(**engine_args_dict) - def check_server_args(self): - assert ( - self.tp_size % self.nnodes == 0 - ), "tp_size must be divisible by number of nodes" - assert not ( - self.dp_size > 1 and self.node_rank is not None - ), "multi-node data parallel is not supported" - if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: - logger.info( - "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True" - ) - self.trust_remote_code = False - if "gemma-2" in self.model_path.lower(): - logger.info("When using sliding window in gemma-2, turn on flashinfer.") - self.disable_flashinfer = False + # Init ServerArgs with the remaining fields... + server_args_fields = {field.name for field in dataclasses.fields(cls)} - { + "engine_args" + } + server_args_dict = { + key: getattr(args, key) for key in server_args_fields if hasattr(args, key) + } + return cls(engine_args=engine_args, **server_args_dict) -@dataclasses.dataclass -class PortArgs: - tokenizer_port: int - controller_port: int - detokenizer_port: int - nccl_ports: List[int] + def url(self): + return f"http://{self.host}:{self.port}" diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 66a5679d75..683733941a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -706,10 +706,10 @@ def prepare_tokenizer(tokenizer_path: str): return tokenizer_path -def configure_logger(server_args, prefix: str = ""): +def configure_logger(log_level, prefix: str = ""): format = f"[%(asctime)s{prefix}] %(message)s" logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), + level=getattr(logging, log_level.upper()), format=format, datefmt="%H:%M:%S", force=True, diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index ac69ab875b..9b9e6d531a 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -23,7 +23,7 @@ import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoTokenizer -from sglang.srt.server import Runtime +from sglang.srt.serving.server import Runtime from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER DEFAULT_PROMPTS = [ diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index cafcf3f2d5..baab1261d8 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -19,6 +19,7 @@ "test_triton_attn_backend.py", "test_update_weights.py", "test_vision_openai_server.py", + "test_llm_engine.py", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py index 2eb704dc91..7909aead2d 100644 --- a/test/srt/test_chunked_prefill.py +++ b/test/srt/test_chunked_prefill.py @@ -20,6 +20,7 @@ def run_mmlu(self, disable_radix_cache, enable_mixed_chunk): if enable_mixed_chunk: other_args += ["--enable-mixed-chunk"] + other_args += ["--disable-cuda-graph"] model = DEFAULT_MODEL_NAME_FOR_TEST base_url = DEFAULT_URL_FOR_TEST process = popen_launch_server( diff --git a/test/srt/test_llm_engine.py b/test/srt/test_llm_engine.py new file mode 100644 index 0000000000..fdac42a977 --- /dev/null +++ b/test/srt/test_llm_engine.py @@ -0,0 +1,58 @@ +import os +import unittest + +from sglang import LLM, SamplingParams +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST + + +class TestLLMGeneration(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_name = DEFAULT_MODEL_NAME_FOR_TEST + cls.prompts_list = [ + "Hello, my name is", + "The capital of China is", + "What is the meaning of life?", + "The future of AI is", + ] + cls.single_prompt = "What is the meaning of life?" + # Turn off tokernizers parallelism to enable running multiple tests + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + def test_generate_with_sampling_params(self): + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + llm = LLM(model=self.model_name) + outputs = llm.generate(self.prompts_list, sampling_params) + + self.assertEqual(len(outputs), len(self.prompts_list)) + for output in outputs: + self.assertIn(output["index"], range(len(self.prompts_list))) + self.assertTrue(output["text"].strip()) + + def test_generate_without_sampling_params(self): + llm = LLM(model=self.model_name) + outputs = llm.generate(self.prompts_list) + + self.assertEqual(len(outputs), len(self.prompts_list)) + for output in outputs: + self.assertIn(output["index"], range(len(self.prompts_list))) + self.assertTrue(output["text"].strip()) + + def test_generate_with_single_prompt_and_sampling_params(self): + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + llm = LLM(model=self.model_name) + outputs = llm.generate(self.single_prompt, sampling_params) + + self.assertEqual(len(outputs), 1) + self.assertTrue(outputs[0]["text"].strip()) + + def test_generate_with_single_prompt_without_sampling_params(self): + llm = LLM(model=self.model_name) + outputs = llm.generate(self.single_prompt) + + self.assertEqual(len(outputs), 1) + self.assertTrue(outputs[0]["text"].strip()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py index 4f6e8db82c..8bada5c4c3 100644 --- a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from sglang.bench_serving import run_benchmark -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MOE_MODEL_NAME_FOR_TEST, @@ -70,9 +70,9 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size def test_default(self): res = self.run_test( - disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_radix_cache=EngineArgs.disable_radix_cache, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": @@ -82,8 +82,8 @@ def test_default(self): def test_default_without_radix_cache(self): res = self.run_test( disable_radix_cache=True, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": diff --git a/test/srt/test_serving_throughput.py b/test/srt/test_serving_throughput.py index f1089a6a7b..9456fabe1a 100644 --- a/test/srt/test_serving_throughput.py +++ b/test/srt/test_serving_throughput.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from sglang.bench_serving import run_benchmark -from sglang.srt.server_args import ServerArgs +from sglang.srt.serving.engine_args import EngineArgs from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -68,9 +68,9 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size def test_default(self): res = self.run_test( - disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_radix_cache=EngineArgs.disable_radix_cache, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": @@ -80,8 +80,8 @@ def test_default(self): def test_default_without_radix_cache(self): res = self.run_test( disable_radix_cache=True, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=ServerArgs.chunked_prefill_size, + disable_flashinfer=EngineArgs.disable_flashinfer, + chunked_prefill_size=EngineArgs.chunked_prefill_size, ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": @@ -90,8 +90,8 @@ def test_default_without_radix_cache(self): def test_default_without_chunked_prefill(self): res = self.run_test( - disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, + disable_radix_cache=EngineArgs.disable_radix_cache, + disable_flashinfer=EngineArgs.disable_flashinfer, chunked_prefill_size=-1, )