Skip to content

Commit

Permalink
Add an LLM engine
Browse files Browse the repository at this point in the history
  • Loading branch information
JianyuZhan committed Aug 17, 2024
1 parent 3694f8f commit d724625
Show file tree
Hide file tree
Showing 13 changed files with 1,762 additions and 58 deletions.
20 changes: 20 additions & 0 deletions examples/usage/llm_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from sglang import LLM, SamplingParams

# Sample prompts.
prompts = [
"Hello, my name is",
"The capital of China is",
"What is the meaning of life?",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="deepseek-ai/deepseek-llm-7b-chat")

outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for prompt, output in zip(prompts, outputs):
print('===============================')
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
2 changes: 2 additions & 0 deletions python/sglang/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SGL API Components

from sglang.api import (
LLM,
SamplingParams,
Runtime,
assistant,
assistant_begin,
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
SglSelect,
SglVideo,
)

from sglang.srt.serving.engine import LLM
from sglang.srt.sampling_params import SamplingParams

def function(
func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
Expand Down
218 changes: 218 additions & 0 deletions python/sglang/srt/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
from dataclasses import dataclass, fields

from typing import Optional, Dict, Union, List

class ModelConfig:
def __init__(self,
model_path: str,
load_format: str = "auto",
tokenizer_path: Optional[str] = None,
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
dtype: str = "auto",
trust_remote_code: bool = True,
context_length: Optional[int] = None,
quantization: Optional[str] = None,
served_model_name: Optional[str] = None,
random_seed: Optional[int] = None,
stream_interval: int = 1,
tokenizer_port: int = 0,
detokenizer_port: int = 0,
controller_port: int = 0,
model_override_args: Optional[Dict] = None) -> None:
"""
ModelConfig for model and tokenizer configuration.
Args:
model_path: Path to the model file or directory.
load_format: Format to load the model. Default is 'auto'.
tokenizer_path: Path to the tokenizer file or directory. Default is None.
tokenizer_mode: Mode for loading the tokenizer. Default is 'auto'.
skip_tokenizer_init: Whether to skip the tokenizer initialization. Default is False.
dtype: Data type for the model. Default is 'auto'.
trust_remote_code: Whether to trust and execute remote code from the model repository. Default is True.
context_length: Maximum context length for the model. Default is None.
quantization: Quantization method. Default is None.
served_model_name: Custom name for the served model. Default is None.
random_seed: Seed for random number generation. Default is None.
stream_interval: Interval for streaming output. Default is 1.
tokenizer_port: Port number for the tokenizer. Default is 0.
detokenizer_port: Port number for the detokenizer. Default is 0.
controller_port: Port number for the controller. Default is 0.
model_override_args: Dictionary of model override arguments. Default is None.
"""
self.model_path = model_path
self.load_format = load_format
self.tokenizer_path = tokenizer_path
self.tokenizer_mode = tokenizer_mode
self.skip_tokenizer_init = skip_tokenizer_init
self.dtype = dtype
self.trust_remote_code = trust_remote_code
self.context_length = context_length
self.quantization = quantization
self.served_model_name = served_model_name
self.random_seed = random_seed
self.stream_interval = stream_interval
self.tokenizer_port = tokenizer_port
self.detokenizer_port = detokenizer_port
self.controller_port = controller_port
self.model_override_args = model_override_args

def __repr__(self):
return (f"ModelConfig(model_path={self.model_path}, load_format={self.load_format}, "
f"tokenizer_path={self.tokenizer_path}, tokenizer_mode={self.tokenizer_mode}, "
f"skip_tokenizer_init={self.skip_tokenizer_init}, dtype={self.dtype}, "
f"trust_remote_code={self.trust_remote_code}, context_length={self.context_length}, "
f"quantization={self.quantization}, served_model_name={self.served_model_name}, "
f"random_seed={self.random_seed}, stream_interval={self.stream_interval}, "
f"tokenizer_port={self.tokenizer_port}, detokenizer_port={self.detokenizer_port}, "
f"controller_port={self.controller_port}, model_override_args={self.model_override_args})")

class ScheduleConfig:
def __init__(self,
mem_fraction_static: Optional[float] = None,
max_running_requests: Optional[int] = None,
max_num_reqs: Optional[int] = None,
max_total_tokens: Optional[int] = None,
chunked_prefill_size: int = 8192,
max_prefill_tokens: int = 16384,
schedule_policy: str = "lpm",
schedule_conservativeness: float = 1.0) -> None:
"""
ScheduleConfig object for scheduling and memory management
Args:
mem_fraction_static: Fraction of memory statically allocated. Default is None.
max_running_requests: Maximum number of running requests. Default is None.
max_num_reqs: Maximum number of requests. Default is None.
max_total_tokens: Maximum total tokens allowed. Default is None.
chunked_prefill_size: Size for chunked prefill. Default is 8192.
max_prefill_tokens: Maximum tokens allowed in the prefill phase. Default is 16384.
schedule_policy: Scheduling policy (e.g., 'lpm'). Default is 'lpm'.
schedule_conservativeness: Conservativeness factor for scheduling. Default is 1.0.
"""
self.mem_fraction_static = mem_fraction_static
self.max_running_requests = max_running_requests
self.max_num_reqs = max_num_reqs
self.max_total_tokens = max_total_tokens
self.chunked_prefill_size = chunked_prefill_size
self.max_prefill_tokens = max_prefill_tokens
self.schedule_policy = schedule_policy
self.schedule_conservativeness = schedule_conservativeness

def __repr__(self):
return (f"ScheduleConfig(mem_fraction_static={self.mem_fraction_static}, "
f"max_running_requests={self.max_running_requests}, max_num_reqs={self.max_num_reqs}, "
f"max_total_tokens={self.max_total_tokens}, chunked_prefill_size={self.chunked_prefill_size}, "
f"max_prefill_tokens={self.max_prefill_tokens}, schedule_policy={self.schedule_policy}, "
f"schedule_conservativeness={self.schedule_conservativeness})")

class ParallelConfig:
def __init__(self,
tp_size: int = 1,
dp_size: int = 1,
load_balance_method: str = "round_robin",
nccl_init_addr: Optional[str] = None,
nccl_ports: List[int] = None,
additional_ports: Optional[Union[List[int], int]] = None,
nnodes: int = 1,
node_rank: Optional[int] = None) -> None:
"""
ParallelConfig object for parallelism and distributed settings.
Args:
tp_size: Tensor parallelism size. Default is 1.
dp_size: Data parallelism size. Default is 1.
load_balance_method: Method for load balancing across nodes. Default is 'round_robin'.
nccl_init_addr: NCCL initialization address. Default is None.
nccl_ports: List of ports for NCCL communication. Default is None.
additional_ports: Additional ports for distributed communication. Default is None.
nnodes: Number of nodes in the distributed setup. Default is 1.
node_rank: Rank of the current node. Default is None.
"""
self.tp_size = tp_size
self.dp_size = dp_size
self.load_balance_method = load_balance_method
self.nccl_init_addr = nccl_init_addr
self.nccl_ports = nccl_ports
self.additional_ports = additional_ports
self.nnodes = nnodes
self.node_rank = node_rank

def __repr__(self):
return (f"ParallelConfig(tp_size={self.tp_size}, dp_size={self.dp_size}, "
f"load_balance_method={self.load_balance_method}, nccl_init_addr={self.nccl_init_addr}, "
f"nccl_ports={self.nccl_ports}, additional_ports={self.additional_ports}, "
f"nnodes={self.nnodes}, node_rank={self.node_rank})")

class OptimizationConfig:
def __init__(self,
disable_flashinfer: bool = False,
disable_flashinfer_sampling: bool = False,
disable_radix_cache: bool = False,
disable_regex_jump_forward: bool = False,
disable_cuda_graph: bool = False,
disable_disk_cache: bool = False,
enable_torch_compile: bool = False,
enable_p2p_check: bool = False,
enable_mla: bool = False,
attention_reduce_in_fp32: bool = False,
efficient_weight_load: bool = False) -> None:
"""
OptimizationConfig object for optimization and debug options
Args:
disable_flashinfer: Disable flashinfer library. Default is False.
disable_flashinfer_sampling: Disable flashinfer sampling. Default is False.
disable_radix_cache: Disable radix cache optimization. Default is False.
disable_regex_jump_forward: Disable regex-based jump forward optimization. Default is False.
disable_cuda_graph: Disable CUDA graph optimization. Default is False.
disable_disk_cache: Disable disk caching. Default is False.
enable_torch_compile: Enable PyTorch compilation optimization. Default is False.
enable_p2p_check: Enable peer-to-peer communication checks. Default is False.
enable_mla: Enable Multi-Head Latent Attention from DeepSeek-V2. Default is False.
attention_reduce_in_fp32: Perform attention reduction in FP32 precision. Default is False.
efficient_weight_load: Enable efficient weight loading. Default is False.
"""
self.disable_flashinfer = disable_flashinfer
self.disable_flashinfer_sampling = disable_flashinfer_sampling
self.disable_radix_cache = disable_radix_cache
self.disable_regex_jump_forward = disable_regex_jump_forward
self.disable_cuda_graph = disable_cuda_graph
self.disable_disk_cache = disable_disk_cache
self.enable_torch_compile = enable_torch_compile
self.enable_p2p_check = enable_p2p_check
self.enable_mla = enable_mla
self.attention_reduce_in_fp32 = attention_reduce_in_fp32
self.efficient_weight_load = efficient_weight_load

def __repr__(self):
return (f"OptimizationConfig(disable_flashinfer={self.disable_flashinfer}, "
f"disable_flashinfer_sampling={self.disable_flashinfer_sampling}, disable_radix_cache={self.disable_radix_cache}, "
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, disable_cuda_graph={self.disable_cuda_graph}, "
f"disable_disk_cache={self.disable_disk_cache}, enable_torch_compile={self.enable_torch_compile}, "
f"enable_p2p_check={self.enable_p2p_check}, enable_mla={self.enable_mla}, "
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, efficient_weight_load={self.efficient_weight_load})")


@dataclass(frozen=True)
class EngineConfig:
"""Dataclass which contains all engine-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
model_config: ModelConfig
schedule_config: ScheduleConfig
parallel_config: ParallelConfig
optimization_config: OptimizationConfig

def __post_init__(self):
"""Verify configs are valid & consistent with each other.
"""
# TODO: Do validation
pass

def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs.
"""
return dict(
(field.name, getattr(self, field.name)) for field in fields(self))
24 changes: 17 additions & 7 deletions python/sglang/srt/managers/controller_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import multiprocessing
import os
from enum import Enum, auto
from typing import List

import numpy as np
import zmq
Expand All @@ -35,7 +36,7 @@
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback

Expand Down Expand Up @@ -71,12 +72,16 @@ class ControllerMulti:
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
controller_port: int,
detokenizer_port: int,
nccl_ports: List[int],
model_overide_args,
):
# Parse args
self.server_args = server_args
self.port_args = port_args
self.controller_port = controller_port
self.detokenizer_port = detokenizer_port
self.nccl_ports = nccl_ports
self.model_overide_args = model_overide_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
Expand All @@ -85,7 +90,7 @@ def __init__(
# Init communication
context = zmq.Context()
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{controller_port}")

# Dispatch method
self.round_robin_counter = 0
Expand Down Expand Up @@ -113,7 +118,9 @@ def start_dp_worker(self, dp_worker_id: int):
target=start_controller_process_single,
args=(
self.server_args,
self.port_args,
self.controller_port,
self.detokenizer_port,
self.nccl_ports,
pipe_controller_writer,
self.model_overide_args,
True,
Expand Down Expand Up @@ -188,7 +195,9 @@ def recv_requests(self):

def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
controller_port: int,
detokenizer_port: int,
nccl_ports: List[int],
pipe_writer,
model_overide_args: dict,
):
Expand All @@ -200,7 +209,8 @@ def start_controller_process(
)

try:
controller = ControllerMulti(server_args, port_args, model_overide_args)
controller = ControllerMulti(server_args, controller_port, detokenizer_port,
nccl_ports, model_overide_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
Expand Down
Loading

0 comments on commit d724625

Please sign in to comment.