add zmq radix cache #12

Open · wants to merge 1 commit into base: main
47 changes: 46 additions & 1 deletion python/sglang/srt/managers/controller_multi.py
@@ -46,6 +46,7 @@ class LoadBalanceMethod(Enum):

ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
PRE_RADIX = auto()
Owner commented:

This shouldn't be a separate Method; it can be folded into Resources_aware.
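A minimal sketch of the direction this comment points at, assuming a resources-aware scheduler that already tracks per-worker queue length; the scoring weight and helper below are illustrative only and not part of this PR:

    from typing import Dict, List

    def resources_aware_pick(queue_lens: List[int], prefix_hits: Dict[int, int]) -> int:
        """Pick a worker by combining queue pressure with radix-prefix reuse.

        queue_lens[i]  - pending requests on worker i
        prefix_hits[i] - matched prefix length for this request on worker i's cache
        """
        def score(i: int) -> float:
            # A longer cached prefix lowers the effective cost; the weight is a placeholder.
            return queue_lens[i] - 0.01 * prefix_hits.get(i, 0)

        return min(range(len(queue_lens)), key=score)

With this shape, PRE_RADIX would not need its own LoadBalanceMethod entry; the prefix-hit term simply becomes one more input to the existing resources-aware decision.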


@classmethod
def from_str(cls, method: str):
@@ -86,20 +87,31 @@ def __init__(
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")

self.recv_from_tree_cache = context.socket(zmq.PULL)
Owner commented:

Change the communication mechanism here.
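One possible shape for that, as a sketch only: take the endpoint from the existing port configuration (the radix_cache_port name is hypothetical; port_args does not currently define it) instead of hard-coding tcp://127.0.0.1:10000 on both sides.

    import zmq

    def bind_tree_cache_socket(context: zmq.Context, radix_cache_port: int) -> zmq.Socket:
        # Controller-side PULL socket for radix-tree snapshots; the worker-side
        # PUSH socket would connect to the same configurable endpoint.
        sock = context.socket(zmq.PULL)
        sock.setsockopt(zmq.RCVHWM, 8)
        sock.bind(f"tcp://127.0.0.1:{radix_cache_port}")
        return sock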

self.recv_from_tree_cache.setsockopt(zmq.RCVHWM, 8)
self.recv_from_tree_cache.bind(f"tcp://127.0.0.1:10000")

# Dispatch method
self.round_robin_counter = 0
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
LoadBalanceMethod.PRE_RADIX: self.pre_radix_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]

# Start data parallel workers
self.workers = []

self.newest_tree_cache = {}

for i in range(server_args.dp_size):
self.start_dp_worker(i)

def start_dp_worker(self, dp_worker_id: int):
def start_dp_worker(
self,
dp_worker_id: int,
):
tp_size = self.server_args.tp_size

pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
@@ -135,6 +147,13 @@ def start_dp_worker(self, dp_worker_id: int):
)
)

# TODO: identify which operations change the tree
def pre_radix_scheduler(self, input_requests):
if len(input_requests) == 0:
return

self.round_robin_scheduler(input_requests=input_requests)
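For context on the TODO above, a rough, self-contained sketch of what prefix-aware dispatch could look like once the snapshots are usable. It treats each worker's cache as a set of flat token paths rather than walking the real TreeNode structure, so everything below is an assumption, not code from this commit:

    from typing import Dict, List

    def shared_prefix_len(a: List[int], b: List[int]) -> int:
        n = 0
        for x, y in zip(a, b):
            if x != y:
                break
            n += 1
        return n

    def pick_worker_by_prefix(input_ids: List[int],
                              cached_paths: Dict[int, List[List[int]]],
                              num_workers: int,
                              rr_counter: int) -> int:
        """Dispatch to the worker whose cached token paths share the longest
        prefix with input_ids; fall back to round robin when nothing matches."""
        best_worker, best_len = None, 0
        for gpu_id, paths in cached_paths.items():
            for path in paths:
                m = shared_prefix_len(input_ids, path)
                if m > best_len:
                    best_worker, best_len = gpu_id, m
        return best_worker if best_worker is not None else rr_counter % num_workers

For example, pick_worker_by_prefix([1, 2, 3, 4], {0: [[1, 2, 3]], 1: [[9]]}, num_workers=2, rr_counter=0) returns 0, since worker 0 already caches the three-token prefix.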

def round_robin_scheduler(self, input_requests):
for r in input_requests:
self.workers[self.round_robin_counter].queue.put(r)
@@ -151,8 +170,34 @@ def shortest_queue_scheduler(self, input_requests):
def loop_for_forward(self):
while True:
recv_reqs = self.recv_requests()

if len(recv_reqs) == 0:
Owner commented:

Don't add unrelated code.

continue

self.recv_tree_cache()
self.dispatching(recv_reqs)

def recv_tree_cache(self):
flag = False
while True:
try:
recv_radix_cache = self.recv_from_tree_cache.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break

gpu_id = recv_radix_cache.gpu_id
if (
gpu_id not in self.newest_tree_cache
or recv_radix_cache.time > self.newest_tree_cache[gpu_id].time
):
self.newest_tree_cache[gpu_id] = recv_radix_cache
flag = True

# Record the update with the logger
if flag:
# logger.info(f"latest_cache={len(self.newest_tree_cache)}")
pass

def recv_requests(self):
recv_reqs = []

3 changes: 2 additions & 1 deletion python/sglang/srt/managers/controller_single.py
@@ -17,7 +17,8 @@

import logging
import multiprocessing
from typing import List
import multiprocessing.connection
from typing import Any, List

import zmq

1 change: 1 addition & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -166,6 +166,7 @@ def __init__(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
disable=server_args.disable_radix_cache,
gpu_id=gpu_id,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
44 changes: 43 additions & 1 deletion python/sglang/srt/mem_cache/radix_cache.py
@@ -55,19 +55,51 @@ def _key_match(key0: List, key1: List):
return i


from dataclasses import dataclass

import zmq


@dataclass
class RadixCacheSend:
gpu_id: int
root_node: TreeNode
time: float
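A side note on the transport: send_pyobj/recv_pyobj pickle the object, so every snapshot ships the full root_node subtree. A tiny helper like the one below (hypothetical, not in the PR) can be used to gauge that payload before deciding how often to send:

    import pickle

    def snapshot_nbytes(snapshot) -> int:
        # send_pyobj serializes with pickle, so this approximates the message
        # size the controller receives for one RadixCacheSend.
        return len(pickle.dumps(snapshot, protocol=pickle.HIGHEST_PROTOCOL))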


class RadixCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: BaseTokenToKVPool,
disable: bool = False,
gpu_id: int = 0,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.disable = disable

context = zmq.Context()
self.send_radix_tree = context.socket(zmq.PUSH)
self.send_radix_tree.setsockopt(zmq.SNDHWM, 8)
self.send_radix_tree.connect(f"tcp://127.0.0.1:10000")
self.gpu_id = gpu_id

self.reset()

##### Public API #####
def send_prefix_tree(self):
try:
self.send_radix_tree.send_pyobj(
RadixCacheSend(
gpu_id=self.gpu_id, root_node=self.root_node, time=time.time()
),
zmq.NOBLOCK,
)
except zmq.Again as e:
print(
"=======================================Radix Cache Queue is full, drop out new radix cache tree======================================="
)

def reset(self):
self.root_node = TreeNode()
@@ -76,6 +108,8 @@ def reset(self):
self.root_node.lock_ref = 1
self.evictable_size_ = 0

self.send_prefix_tree()

def match_prefix(self, key: List, **kwargs):
if self.disable:
return [], self.root_node
@@ -95,7 +129,12 @@ def insert(self, key: List, value=None):

if value is None:
value = [x for x in key]
return self._insert_helper(self.root_node, key, value)
res = self._insert_helper(self.root_node, key, value)

# insert changes the tree structure
self.send_prefix_tree()

return res

def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
"""Cache request when it finishes."""
@@ -176,6 +215,9 @@ def evict(self, num_tokens: int, evict_callback: Callable):
if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent)

# eviction changes the tree structure
self.send_prefix_tree()

def inc_lock_ref(self, node: TreeNode):
if self.disable:
return 0
1 change: 1 addition & 0 deletions python/sglang/srt/server_args.py
@@ -367,6 +367,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
choices=[
"round_robin",
"shortest_queue",
"pre_radix",
],
)
