Allow serving llama models with tensor parallel #592

Draft · wants to merge 1 commit into main
7 changes: 6 additions & 1 deletion src/petals/models/llama/block.py
@@ -137,6 +137,9 @@ def __init__(self, config: LlamaConfig):
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+        self.num_key_value_heads = config.num_key_value_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+
         self.pre_attn_graph = None
         self.post_attn_graph = None

@@ -283,14 +286,16 @@ def _reorder_cache_from_bloom_to_llama(
         key_states, value_states = key_value
         key_states = key_states.permute(0, 2, 1)
         key_states = key_states.view(
-            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+            batch_size, self.num_key_value_heads//2, seq_length, self.head_dim
+            #batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
         )
         value_states = value_states.view(*key_states.shape)
         return (key_states, value_states)

     def _reorder_cache_from_llama_to_bloom(
         self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
     ) -> Tuple[torch.Tensor]:
+        raise NotImplementedError
         key_states, value_states = key_value
         value_states = value_states.view(
             batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
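Under tensor parallelism each shard only holds `num_key_value_heads // world_size` KV heads, so the reshape above has to use the per-rank head count rather than the full one; the hard-coded `//2` assumes exactly two devices. A minimal illustration with made-up shapes (not part of the diff):

```python
# Illustrative only: reorder a per-rank KV cache from the BLOOM-style layout
# (batch * heads, head_dim, seq) into the LLaMA-style layout (batch, heads, seq, head_dim).
import torch

batch_size, seq_length, head_dim = 2, 16, 128
num_key_value_heads, world_size = 8, 2               # assumption: 2 tensor-parallel devices
heads_per_rank = num_key_value_heads // world_size   # this is what the hard-coded "//2" stands for

bloom_keys = torch.randn(batch_size * heads_per_rank, head_dim, seq_length)
llama_keys = bloom_keys.permute(0, 2, 1).reshape(batch_size, heads_per_rank, seq_length, head_dim)
assert llama_keys.shape == (2, 4, 16, 128)
```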
97 changes: 97 additions & 0 deletions src/petals/models/llama/slicing.py
Check failure on line 1 in src/petals/models/llama/slicing.py (GitHub Actions / isort): Imports are incorrectly sorted and/or formatted.
@@ -0,0 +1,97 @@
"""

Check failure on line 1 in src/petals/models/llama/slicing.py

View workflow job for this annotation

GitHub Actions / isort

Imports are incorrectly sorted and/or formatted.
Optimized configs for selected models. These configs are not necessary, but they can improve performance in some
cases, e.g. training with very small batches or inference with long sequences.

NB: some of these configs get fairly complicated in order to squeeze a bit of extra performance. When developing your
own config, you can get most of the performance benefits by using auto config -- and maybe splitting MLP layers.
"""
from functools import partial
from itertools import chain
from typing import Callable, Dict, Sequence

import torch
from transformers import PretrainedConfig, LlamaConfig

from tensor_parallel.communications import CollectiveOperation
from tensor_parallel.slicer_wrapper import Config
from tensor_parallel.tensor_parallel import PerDeviceTensors

ConfigGetter = Callable[[PretrainedConfig, Sequence[torch.device]], Config]

def get_llama_config(model_config: LlamaConfig, devices: Sequence[torch.device]) -> Config:
    assert model_config.model_type == "llama", f"Trying to pass {model_config.model_type} as llama config"

    world_size = len(devices)
    head_dim = model_config.hidden_size // model_config.num_attention_heads
    num_kv = model_config.num_key_value_heads
    q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads

    gather_kv_across_ranks = CollectiveOperation(
        world_size=world_size,
        func=lambda *kvs: [PerDeviceTensors(*chain(*(x or [None] for x in kvs)))] * world_size
    )

    select_kv_for_rank = lambda kvs, rank: (kvs[2 * rank], kvs[2 * rank + 1]) if kvs else None
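These two helpers perform a flatten/select round trip on the per-rank KV caches: each rank's `(key, value)` pair is packed into one flat sequence that every rank receives, and on the next step each rank takes back its own pair at positions `2 * rank` and `2 * rank + 1`. A toy illustration, with strings in place of tensors and a plain tuple standing in for `PerDeviceTensors`:

```python
# Toy sketch of the flatten/select logic used by gather_kv_across_ranks and select_kv_for_rank.
from itertools import chain

per_rank_kv = [("k0", "v0"), ("k1", "v1")]           # (key, value) produced by each of 2 ranks
flat = tuple(chain(*(kv or [None] for kv in per_rank_kv)))
assert flat == ("k0", "v0", "k1", "v1")

select_kv_for_rank = lambda kvs, rank: (kvs[2 * rank], kvs[2 * rank + 1]) if kvs else None
assert select_kv_for_rank(flat, 1) == ("k1", "v1")   # rank 1 recovers its own cache pair
```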

    config = Config(
        state_rules={
            # LlamaAttention
            r".*self_attn\.q_proj\.weight$": partial(split_heads, dim=0, head_dim=q_per_kv * head_dim, world_size=world_size),
            r".*self_attn\.k_proj\.weight$": partial(split_heads, dim=0, head_dim=head_dim, world_size=world_size),
            r".*self_attn\.v_proj\.weight$": partial(split_heads, dim=0, head_dim=head_dim, world_size=world_size),
            r".*self_attn\.o_proj\.weight$": partial(split_heads, dim=1, head_dim=q_per_kv * head_dim, world_size=world_size),
            # LlamaFeedForward
            r".*mlp\.gate_proj\.weight$": "split 0",
            r".*mlp\.down_proj\.weight$": "split 1",
            r".*mlp\.up_proj\.weight$": "split 0",
            # LlamaModel
            #r".*embed_tokens.weight$": "split 1",
            #r".*lm_head\.weight$": "split 0",
        },
        input_rules={
            r".*self_attn$": {"past_key_value": select_kv_for_rank},
        },
        output_rules={
            r".*self_attn$": {0: "sum", 2: gather_kv_across_ranks},
            r".*mlp$": {0: "sum"},
            r".*embed_tokens$": {0: "gather -1"},
            r".*lm_head$": {0: "gather -1"},
        },
        attr_rules={
            r".*self_attn$": {
                "hidden_size": partial(split_inner_dim, num_heads=num_kv, world_size=world_size),
                "num_heads": partial(split_num_heads, world_size=world_size),
                "num_key_value_heads": partial(split_num_heads, world_size=world_size),
            }
        },
        #attr_rules={
        #    r".*self_attn$": {
        #        "hidden_size": partial(split_inner_dim, num_heads=num_kv, world_size=world_size),
        #        "num_heads": lambda n, rank: q_per_kv * split_num_heads(n // q_per_kv, rank=rank, world_size=world_size),
        #    }
        #},
    )

    return config
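To make the head bookkeeping above concrete, here is the arithmetic for an illustrative grouped-query configuration; the numbers are chosen for the example and are not taken from the PR:

```python
# Hypothetical GQA geometry: 32 query heads share 8 KV heads, sharded over 2 devices.
hidden_size, num_attention_heads, num_key_value_heads, world_size = 4096, 32, 8, 2

head_dim = hidden_size // num_attention_heads           # 128
q_per_kv = num_attention_heads // num_key_value_heads   # 4 query heads per KV head

# q_proj (dim 0) and o_proj (dim 1) are split in groups of q_per_kv * head_dim so that every
# rank keeps whole KV groups; k_proj and v_proj are split in plain head_dim groups.
assert q_per_kv * head_dim == 512
assert num_key_value_heads // world_size == 4            # KV heads per rank
assert num_attention_heads // world_size == 16           # query heads per rank
```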



def split_heads(tensor: torch.Tensor, *, dim: int, head_dim: int, rank: int, world_size: int, optional: bool = False):
    """Split a tensor along dim such that each part size is divisible by head_dim"""
    if tensor is None and optional:
        return None
    assert tensor.shape[dim] % head_dim == 0, tensor.shape
    if dim < 0:
        dim = (tensor.ndim + dim) % tensor.ndim
    shape = list(tensor.shape)
    shape[dim] //= head_dim
    shape.insert(dim + 1, head_dim)
    tensor_part = tensor.reshape(shape).tensor_split(world_size, dim=dim)[rank].flatten(dim, dim + 1)
    return tensor_part
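A quick sanity check of `split_heads` with toy sizes; the import assumes this branch of Petals is installed so the new module is available:

```python
# Toy check: a (8 heads * head_dim 4, 6) weight split over 2 ranks keeps whole
# head_dim blocks on each rank, and the two parts reassemble into the original.
import torch
from petals.models.llama.slicing import split_heads  # module added in this PR

weight = torch.arange(8 * 4 * 6, dtype=torch.float32).reshape(8 * 4, 6)
parts = [split_heads(weight, dim=0, head_dim=4, rank=r, world_size=2) for r in range(2)]
assert all(p.shape == (16, 6) for p in parts)
assert torch.equal(torch.cat(parts, dim=0), weight)
```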


def split_num_heads(num_heads: int, *, rank: int, world_size: int):
    return torch.empty(num_heads, device="meta").tensor_split(world_size)[rank].numel()

def split_inner_dim(inner_dim: int, *, rank: int, num_heads: int, world_size: int):
    return split_num_heads(num_heads=num_heads, rank=rank, world_size=world_size) * (inner_dim // num_heads)
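The meta-device tensor in `split_num_heads` reuses `tensor_split`'s semantics to count heads per rank without allocating memory, which also covers head counts that do not divide evenly. Illustrative numbers (not from the PR):

```python
# Illustrative: per-rank head counts and attention hidden size, using the same
# tensor_split semantics as split_num_heads / split_inner_dim above.
import torch

world_size = 2
# 8 KV heads over 2 ranks -> 4 each; 5 heads would split unevenly into 3 and 2.
assert [torch.empty(8, device="meta").tensor_split(world_size)[r].numel() for r in range(world_size)] == [4, 4]
assert [torch.empty(5, device="meta").tensor_split(world_size)[r].numel() for r in range(world_size)] == [3, 2]

# With hidden_size=4096 tracked against num_kv=8 heads (head group size 512),
# each of the 2 ranks ends up with an attention hidden_size of 4 * 512 = 2048.
assert 4 * (4096 // 8) == 2048
```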
3 changes: 3 additions & 0 deletions src/petals/utils/convert_block.py
@@ -121,6 +121,9 @@ def make_tensor_parallel(
     if model_config.model_type == "bloom":
         tp_config = get_bloom_config(model_config, devices)
         del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
+    if model_config.model_type == "llama":
+        from petals.models.llama.slicing import get_llama_config
+        tp_config = get_llama_config(model_config, devices)
     else:
         if len(devices) > 1:
             logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution")
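For context, a self-contained sketch of the model-type dispatch this hunk extends. The stubs and the function name `pick_tp_config` are hypothetical, and the llama branch is chained with `elif` here on the assumption that BLOOM models should still skip the warning branch; the real code lives in `make_tensor_parallel`.

```python
# Hypothetical, simplified stand-in for the dispatch inside make_tensor_parallel.
import logging

logger = logging.getLogger(__name__)

def get_bloom_config(model_config, devices):    # stub for the existing BLOOM config helper
    return "bloom-tp-config"

def get_llama_config(model_config, devices):    # stub for the slicing module added by this PR
    return "llama-tp-config"

def pick_tp_config(model_config, devices):
    if model_config.model_type == "bloom":
        tp_config = get_bloom_config(model_config, devices)
    elif model_config.model_type == "llama":    # assumption: chained so BLOOM never hits the warning
        tp_config = get_llama_config(model_config, devices)
    else:
        if len(devices) > 1:
            logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution")
        tp_config = None
    return tp_config
```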