diff --git a/python/sglang/srt/managers/controller_multi.py b/python/sglang/srt/managers/controller_multi.py index ba626d4cff..e27cef1da4 100644 --- a/python/sglang/srt/managers/controller_multi.py +++ b/python/sglang/srt/managers/controller_multi.py @@ -46,6 +46,7 @@ class LoadBalanceMethod(Enum): ROUND_ROBIN = auto() SHORTEST_QUEUE = auto() + PRE_RADIX = auto() @classmethod def from_str(cls, method: str): @@ -86,20 +87,31 @@ def __init__( self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}") + self.recv_from_tree_cache = context.socket(zmq.PULL) + self.recv_from_tree_cache.setsockopt(zmq.RCVHWM, 8) + self.recv_from_tree_cache.bind(f"tcp://127.0.0.1:10000") + # Dispatch method self.round_robin_counter = 0 dispatch_lookup = { LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, + LoadBalanceMethod.PRE_RADIX: self.pre_radix_scheduler, } self.dispatching = dispatch_lookup[self.load_balance_method] # Start data parallel workers self.workers = [] + + self.newest_tree_cache = {} + for i in range(server_args.dp_size): self.start_dp_worker(i) - def start_dp_worker(self, dp_worker_id: int): + def start_dp_worker( + self, + dp_worker_id: int, + ): tp_size = self.server_args.tp_size pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe( @@ -135,6 +147,13 @@ def start_dp_worker(self, dp_worker_id: int): ) ) + # TODO: determine which operations modify the tree + def pre_radix_scheduler(self, input_requests): + if len(input_requests) == 0: + return + + self.round_robin_scheduler(input_requests=input_requests) + def round_robin_scheduler(self, input_requests): for r in input_requests: self.workers[self.round_robin_counter].queue.put(r) @@ -151,8 +170,34 @@ def shortest_queue_scheduler(self, input_requests): def loop_for_forward(self): while True: recv_reqs = self.recv_requests() + + if len(recv_reqs) == 0: + continue + + self.recv_tree_cache() 
self.dispatching(recv_reqs) + def recv_tree_cache(self): + flag = False + while True: + try: + recv_radix_cache = self.recv_from_tree_cache.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break + + gpu_id = recv_radix_cache.gpu_id + if ( + gpu_id not in self.newest_tree_cache + or recv_radix_cache.time > self.newest_tree_cache[gpu_id].time + ): + self.newest_tree_cache[gpu_id] = recv_radix_cache + flag = True + + # Record the information using the logger + if flag: + # logger.info(f"latest_cache={len(self.newest_tree_cache)}") + pass + def recv_requests(self): recv_reqs = [] diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py index 2ae37059c1..4e7bf289da 100644 --- a/python/sglang/srt/managers/controller_single.py +++ b/python/sglang/srt/managers/controller_single.py @@ -17,7 +17,8 @@ import logging import multiprocessing -from typing import List +import multiprocessing.connection +from typing import Any, List import zmq diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 8fc03b8599..4540bfcc04 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -166,6 +166,7 @@ def __init__( req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, disable=server_args.disable_radix_cache, + gpu_id=gpu_id, ) self.tree_cache_metrics = {"total": 0, "hit": 0} self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index a1c685405a..8a334c8ba8 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -55,19 +55,51 @@ def _key_match(key0: List, key1: List): return i +from dataclasses import dataclass + +import zmq + + +@dataclass +class RadixCacheSend: + gpu_id: int + root_node: TreeNode + time: time + + class RadixCache(BasePrefixCache): def 
__init__( self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool, disable: bool = False, + gpu_id: int = 0, ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool = token_to_kv_pool self.disable = disable + + context = zmq.Context() + self.send_radix_tree = context.socket(zmq.PUSH) + self.send_radix_tree.setsockopt(zmq.SNDHWM, 8) + self.send_radix_tree.connect(f"tcp://127.0.0.1:10000") + self.gpu_id = gpu_id + self.reset() ##### Public API ##### + def send_prefix_tree(self): + try: + self.send_radix_tree.send_pyobj( + RadixCacheSend( + gpu_id=self.gpu_id, root_node=self.root_node, time=time.time() + ), + zmq.NOBLOCK, + ) + except zmq.Again as e: + print( + "=======================================Radix Cache Queue is full, drop out new radix cache tree=======================================" + ) def reset(self): self.root_node = TreeNode() @@ -76,6 +108,8 @@ def reset(self): self.root_node.lock_ref = 1 self.evictable_size_ = 0 + self.send_prefix_tree() + def match_prefix(self, key: List, **kwargs): if self.disable: return [], self.root_node @@ -95,7 +129,12 @@ def insert(self, key: List, value=None): if value is None: value = [x for x in key] - return self._insert_helper(self.root_node, key, value) + res = self._insert_helper(self.root_node, key, value) + + # insert changes the structure of the tree + self.send_prefix_tree() + + return res def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None): """Cache request when it finishes.""" @@ -176,6 +215,9 @@ def evict(self, num_tokens: int, evict_callback: Callable): if len(x.parent.children) == 0: heapq.heappush(leaves, x.parent) + # This changes the structure of the tree + self.send_prefix_tree() + def inc_lock_ref(self, node: TreeNode): if self.disable: return 0 diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8a56c02e16..8a6adddcac 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -367,6 +367,7 @@ def add_cli_args(parser: 
argparse.ArgumentParser): choices=[ "round_robin", "shortest_queue", + "pre_radix", ], )