KV cache refactor to decouple cache blocks and metadata about them #168

Merged: 1 commit, Jan 20, 2024
71 changes: 38 additions & 33 deletions serve/mlc_serve/model/paged_cache_manager.py
@@ -101,13 +101,11 @@ def replace_head_prompt_block_with(self, new_block):
         self.prompt_shared = False


-class KVCache:
+class KVCacheInfo:
     def __init__(
         self,
-        cache_blocks,
         block_size,
     ):
-        self.cache_blocks = cache_blocks
         self.block_size = block_size

         # SequenceId -> list[int]
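
Note: for orientation, this is roughly what the renamed container looks like after the patch. Only block_size is passed to the constructor; the table fields and their defaultdict-style initialization are inferred from how they are used later in this diff, so treat this as a sketch rather than the verbatim class.

    from collections import defaultdict

    class KVCacheInfo:
        def __init__(self, block_size):
            self.block_size = block_size

            # SequenceId -> list[int]: physical block numbers backing each prompt
            self.prompt_block_tables = defaultdict(list)
            # SequenceId -> list[int]: per-token slot indices into the block pool
            self.slot_mappings = defaultdict(list)
            # SequenceId -> DecodeBlockTable (defined earlier in this file)
            self.decode_block_tables = dict()
            # Flat [src, dst, src, dst, ...] block pairs awaiting a copy kernel
            self.pending_copy_from_to = []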
@@ -135,18 +133,17 @@ def get_cache_block_size(num_layers, num_heads, head_size):

     def __init__(
         self,
-        cache_blocks,  # This can be any type
         num_blocks: int,
         sliding_window: Optional[int] = None,
     ):
         self.num_blocks = num_blocks
         self.free_blocks = list(range(num_blocks))
-        self.kv_cache = KVCache(cache_blocks, self.block_size)
+        self.kv_cache_info = KVCacheInfo(self.block_size)
         self.token_counts = dict[SequenceId, int]()

         if sliding_window:
-            assert sliding_window % self.kv_cache.block_size == 0
-            self.block_sliding_window = sliding_window // self.kv_cache.block_size
+            assert sliding_window % self.kv_cache_info.block_size == 0
+            self.block_sliding_window = sliding_window // self.kv_cache_info.block_size
         else:
             self.block_sliding_window = None

@@ -160,20 +157,20 @@ def set_size(self, sequence_ids: List[SequenceId], target_sizes: List[int]):
                num_needed_block = min(num_needed_block, self.block_sliding_window)

            if size == 0:
-                if id in self.kv_cache.prompt_block_tables:
-                    self.free_blocks.extend(self.kv_cache.prompt_block_tables[id])
-                    del self.kv_cache.prompt_block_tables[id]
-                elif id in self.kv_cache.decode_block_tables:
+                if id in self.kv_cache_info.prompt_block_tables:
+                    self.free_blocks.extend(self.kv_cache_info.prompt_block_tables[id])
+                    del self.kv_cache_info.prompt_block_tables[id]
+                elif id in self.kv_cache_info.decode_block_tables:
                    self.free_blocks.extend(
-                        self.kv_cache.decode_block_tables[id].decode_blocks
+                        self.kv_cache_info.decode_block_tables[id].decode_blocks
                    )
-                    del self.kv_cache.decode_block_tables[id]
+                    del self.kv_cache_info.decode_block_tables[id]

-                if id in self.kv_cache.slot_mappings:
-                    del self.kv_cache.slot_mappings[id]
+                if id in self.kv_cache_info.slot_mappings:
+                    del self.kv_cache_info.slot_mappings[id]

-            elif id in self.kv_cache.decode_block_tables:
-                decode_block_table = self.kv_cache.decode_block_tables[id]
+            elif id in self.kv_cache_info.decode_block_tables:
+                decode_block_table = self.kv_cache_info.decode_block_tables[id]

                if len(decode_block_table) < num_needed_block:
                    # Need to allocate a new block for this request
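
Note: the freeing path above only shuffles integer block IDs; after this refactor the bytes themselves never move, since they live on the model. A toy illustration of the recycling invariant:

    # Toy version of the size == 0 branch: blocks go back to the shared pool.
    free_blocks = list(range(10))                        # ten-block pool
    prompt_block_tables = {"seq0": [free_blocks.pop(), free_blocks.pop()]}

    free_blocks.extend(prompt_block_tables.pop("seq0"))  # sequence finished
    assert sorted(free_blocks) == list(range(10))        # nothing leaked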
@@ -218,40 +215,42 @@ def get_block_circular_index(token_pos):

                    block_offset = pos % self.block_size
                    slot = block_number * self.block_size + block_offset
-                    self.kv_cache.slot_mappings[id].append(slot)
+                    self.kv_cache_info.slot_mappings[id].append(slot)

-            elif id not in self.kv_cache.prompt_block_tables:
+            elif id not in self.kv_cache_info.prompt_block_tables:
                assert (
                    len(self.free_blocks) >= num_needed_block
                ), "Not enough free blocks."

                for _ in range(num_needed_block):
-                    self.kv_cache.prompt_block_tables[id].append(self.free_blocks.pop())
+                    self.kv_cache_info.prompt_block_tables[id].append(
+                        self.free_blocks.pop()
+                    )

                for block_idx in range(math.floor(size / self.block_size)):
                    if self.block_sliding_window:
                        block_idx %= self.block_sliding_window

-                    block_number = self.kv_cache.prompt_block_tables[id][block_idx]
+                    block_number = self.kv_cache_info.prompt_block_tables[id][block_idx]
                    slots = [
                        block_number * self.block_size + block_offset
                        for block_offset in range(self.block_size)
                    ]
-                    self.kv_cache.slot_mappings[id] += slots
+                    self.kv_cache_info.slot_mappings[id] += slots

-                for i in range(len(self.kv_cache.slot_mappings[id]), size):
+                for i in range(len(self.kv_cache_info.slot_mappings[id]), size):
                    block_idx = i // self.block_size

                    if self.block_sliding_window:
                        block_idx %= self.block_sliding_window

-                    block_number = self.kv_cache.prompt_block_tables[id][block_idx]
+                    block_number = self.kv_cache_info.prompt_block_tables[id][block_idx]
                    block_offset = i % self.block_size
                    slot = block_number * self.block_size + block_offset
-                    self.kv_cache.slot_mappings[id].append(slot)
+                    self.kv_cache_info.slot_mappings[id].append(slot)

    def get_cache(self):
-        return self.kv_cache
+        return self.kv_cache_info
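
Note: the slot arithmetic above is untouched by the rename. A worked example, with a hypothetical block_size of 16 and block table [7, 3]:

    block_size = 16
    prompt_block_table = [7, 3]     # hypothetical physical block numbers

    pos = 20                        # 21st token of the prompt
    block_idx = pos // block_size   # -> 1, the second logical block
    block_number = prompt_block_table[block_idx]  # -> physical block 3
    block_offset = pos % block_size               # -> 4
    slot = block_number * block_size + block_offset
    assert slot == 52               # the token's flat index into the pool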

    def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
        """
@@ -267,7 +266,7 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
        if self.sliding_window:
            last_block_partially_shared &= num_tokens < self.sliding_window

-        prompt_blocks = self.kv_cache.prompt_block_tables[prompt_seq_id]
+        prompt_blocks = self.kv_cache_info.prompt_block_tables[prompt_seq_id]
        assert prompt_blocks

        prompt_shared = num_sequences > 1
@@ -277,7 +276,9 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
            self.token_counts[decode_seq_id] = num_tokens

            if not last_block_partially_shared:
-                self.kv_cache.decode_block_tables[decode_seq_id] = DecodeBlockTable(
+                self.kv_cache_info.decode_block_tables[
+                    decode_seq_id
+                ] = DecodeBlockTable(
                    prompt_blocks,
                    num_tokens,
                    self.block_size,
@@ -286,25 +287,29 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
                )
            else:
                if i < num_sequences:
-                    # Need to copy the last block in self.kv_cache.block_tables[prompt_seq_id]
-                    self.kv_cache.decode_block_tables[decode_seq_id] = DecodeBlockTable(
+                    # Need to copy the last block in self.kv_cache_info.block_tables[prompt_seq_id]
+                    self.kv_cache_info.decode_block_tables[
+                        decode_seq_id
+                    ] = DecodeBlockTable(
                        prompt_blocks[:-1],
                        num_tokens,
                        self.block_size,
                        self.block_sliding_window,
                        prompt_shared,
                    )
                    last_block_copy = self.free_blocks.pop()
-                    self.kv_cache.decode_block_tables[decode_seq_id].append(
+                    self.kv_cache_info.decode_block_tables[decode_seq_id].append(
                        last_block_copy
                    )
-                    self.kv_cache.pending_copy_from_to.extend(
+                    self.kv_cache_info.pending_copy_from_to.extend(
                        [prompt_blocks[-1], last_block_copy]
                    )
                else:
                    # The last sequence can directly overwrite the last block without copying it,
                    # since other sequences have its own copy of the last block.
-                    self.kv_cache.decode_block_tables[decode_seq_id] = DecodeBlockTable(
+                    self.kv_cache_info.decode_block_tables[
+                        decode_seq_id
+                    ] = DecodeBlockTable(
                        prompt_blocks,
                        num_tokens,
                        self.block_size,
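
Note: allocate now only records copy-on-write work in metadata; the actual block copy happens later in the model (see tvm_model.py below). A toy trace of the bookkeeping for three decode sequences sharing a partially filled last prompt block, assuming the loop index i is 1-based as the i < num_sequences test suggests; all block numbers are made up:

    pending_copy_from_to = []
    free_blocks = [11, 12]
    last_prompt_block = 5

    for i in [1, 2, 3]:            # three decode sequences
        if i < 3:                  # first two get their own copy of block 5
            copy = free_blocks.pop()
            pending_copy_from_to.extend([last_prompt_block, copy])
        # the last sequence overwrites block 5 in place; nothing to queue

    assert pending_copy_from_to == [5, 12, 5, 11]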
17 changes: 9 additions & 8 deletions serve/mlc_serve/model/tvm_model.py
@@ -10,7 +10,7 @@
 from tvm.runtime import disco as di

 from .base import ModelArtifactConfig
-from .paged_cache_manager import KVCache, CacheManager
+from .paged_cache_manager import KVCacheInfo, CacheManager
 from .model_common import (
     sample,
     prepare_inputs,
@@ -152,6 +152,8 @@ def __init__(
            "tvm.contrib.vllm.copy_blocks"
        )

+        self.cache_blocks = None
+
    def get_used_memory(self):
        if self.disco_session:
            params = self.params.debug_get_from_remote(0)
@@ -204,7 +206,7 @@ def profile_memory_usage(self, seq_lens):
    def generate(
        self,
        requests: Sequence[Union[PrefillRequest, DecodeRequest]],
-        cache: KVCache,
+        cache: KVCacheInfo,
    ) -> List[TextGenerationResult]:
        if len(requests) == 0:
            return []
@@ -266,7 +268,7 @@ def generate(
                input_ids,
                positions,
                seq_lens,
-                cache.cache_blocks,
+                self.cache_blocks,
                slot_mapping,
                indices_within_window,
                self.params,
@@ -276,7 +278,7 @@ def generate(
                input_ids,
                positions,
                seq_lens,
-                cache.cache_blocks,
+                self.cache_blocks,
                slot_mapping,
                self.params,
            )
@@ -297,7 +299,7 @@ def generate(
                input_ids,
                positions,
                seq_lens,
-                cache.cache_blocks,
+                self.cache_blocks,
                slot_mapping,
                block_tables,
                self.params,
@@ -324,7 +326,7 @@ def generate(
                "int64",
            )

-            self.copy_cache_blocks_func(cache.cache_blocks, block_mapping)
+            self.copy_cache_blocks_func(self.cache_blocks, block_mapping)
            cache.pending_copy_from_to = []

        try:
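
Note: this is the consumer of the pending_copy_from_to queue filled by the cache manager. A condensed sketch of the drain step; the numpy conversion stands in for the actual TVM array construction and is an assumption:

    import numpy as np

    def drain_pending_copies(cache_info, copy_blocks_func, cache_blocks):
        # The queue is a flat [src, dst, src, dst, ...] list of block numbers.
        if cache_info.pending_copy_from_to:
            block_mapping = np.array(cache_info.pending_copy_from_to, dtype="int64")
            copy_blocks_func(cache_blocks, block_mapping)  # physical copy
            cache_info.pending_copy_from_to = []           # queue consumed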
@@ -474,7 +476,7 @@ def init_tvm_model(
    else:
        init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache")

-    cache_blocks = init_cache_func(
+    model.cache_blocks = init_cache_func(
        head_size,
        model_artifact_config.num_hidden_layers,
        num_kv_heads,
@@ -483,7 +485,6 @@ def init_tvm_model(
    )

    cache_manager = CacheManager(
-        cache_blocks,
        num_blocks,
        model_artifact_config.sliding_window,
    )
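
Note: the net effect of the wiring above, as a minimal usage sketch. Here requests stands for a hypothetical batch of PrefillRequest/DecodeRequest objects; the other names follow the diff:

    cache_manager = CacheManager(num_blocks, sliding_window)  # metadata only
    cache_info = cache_manager.get_cache()                    # a KVCacheInfo

    # generate() reads the physical blocks from its own self.cache_blocks,
    # so the same KVCacheInfo works for any backend's block storage.
    results = model.generate(requests, cache_info)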
1 change: 1 addition & 0 deletions serve/mlc_serve/utils.py
@@ -24,6 +24,7 @@ def get_default_mlc_serve_argparser(description="", allow_override=False):
    parser.add_argument("--artifact-path", type=str, default="dist")
    parser.add_argument("--use-sync-engine", action="store_true")
    parser.add_argument("--max-num-batched-tokens", type=int, default=4096)
+    parser.add_argument("--num-sequences-to-sample", type=int, default=1)
@masahi (Member, Author) commented on Jan 19, 2024:

    This was accidentally removed in #162, apparently.

    parser.add_argument("--min-decode-steps", type=int, default=32)
    parser.add_argument("--max-decode-steps", type=int, default=56)
    parser.add_argument("--debug-logging", action="store_true")
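
Note: a quick sanity check of the restored flag, using the parser factory from this file; argparse maps the dashes to underscores:

    parser = get_default_mlc_serve_argparser("demo")
    args = parser.parse_args(["--num-sequences-to-sample", "4"])
    assert args.num_sequences_to_sample == 4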