Smaug support (#212)
* wip, single gpu works

* wip

* works

* set vocab_size in benchmark correctly
masahi committed Feb 16, 2024
1 parent 5588d17 commit abe93a1
Showing 3 changed files with 43 additions and 8 deletions.
mlc_llm/relax_model/commons.py (20 additions, 0 deletions)
@@ -55,6 +55,25 @@ def shard_qkv_weight_scale(weight: relax.TensorStructInfo):
         func = te.create_prim_func([a, w])
         return func

+    def shard_bias(bias: relax.TensorStructInfo):
+        (hidden_dim,), dtype = bias.shape, bias.dtype
+        hidden_dim = int(hidden_dim)
+        if param_shape_is_already_sharded:
+            hidden_dim *= num_shards
+        head_dim = hidden_dim // (q_heads + 2 * kv_heads)
+        a = te.placeholder((hidden_dim,), dtype=dtype)
+        w = topi.reshape(a, (hidden_dim // head_dim, head_dim))
+        q = te.compute((q_heads, head_dim), lambda i, j: w[i, j])
+        k = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + i, j])
+        v = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + kv_heads + i, j])
+        q = topi.reshape(q, (num_shards, q_heads // num_shards, head_dim))
+        k = topi.reshape(k, (num_shards, kv_heads // num_shards, head_dim))
+        v = topi.reshape(v, (num_shards, kv_heads // num_shards, head_dim))
+        w = topi.concatenate((q, k, v), axis=1)
+        w = topi.reshape(w, (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim))
+        func = te.create_prim_func([a, w])
+        return func
+
     def shard_k_weight_scale(weight: relax.TensorStructInfo):
         (spatial, red), dtype = weight.shape, weight.dtype
         spatial, red = int(spatial), int(red)
@@ -112,6 +131,7 @@ def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo):

     return {
         "shard_qkv": shard_qkv_weight_scale,
+        "shard_qkv_bias": shard_bias,
         "shard_mlp_k": shard_k_weight_scale,
         "shard_o_proj_k": shard_k_weight_scale,
         "shard_gate_up": shard_gate_up_weight_scale,
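The new shard_qkv_bias strategy regroups a fused QKV bias so that each tensor-parallel shard gets one contiguous [q | k | v] slice. Below is a minimal NumPy sketch of the same reshape/split/concatenate sequence; the head counts and shard count are illustrative values, not taken from the commit.

import numpy as np

# Illustrative sizes only.
q_heads, kv_heads, head_dim, num_shards = 8, 4, 16, 2

# Stand-in fused QKV bias of length (q_heads + 2 * kv_heads) * head_dim.
bias = np.arange((q_heads + 2 * kv_heads) * head_dim, dtype="float32")

# Split into per-head rows, then regroup the heads by shard.
w = bias.reshape(q_heads + 2 * kv_heads, head_dim)
q = w[:q_heads].reshape(num_shards, q_heads // num_shards, head_dim)
k = w[q_heads:q_heads + kv_heads].reshape(num_shards, kv_heads // num_shards, head_dim)
v = w[q_heads + kv_heads:].reshape(num_shards, kv_heads // num_shards, head_dim)

sharded = np.concatenate((q, k, v), axis=1).reshape(
    num_shards, (q_heads + 2 * kv_heads) // num_shards * head_dim
)
print(sharded.shape)  # (2, 128): one contiguous [q | k | v] bias slice per shard

Row i of sharded is the bias slice intended for shard i, which should line up with how the existing shard_qkv strategy splits the corresponding fused weight rows.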
mlc_llm/relax_model/llama.py (22 additions, 8 deletions)
@@ -39,6 +39,7 @@ def __init__(
         build_model_only=False,
         num_shards=1,
         sliding_window=None,
+        attention_bias=False,
         **kwargs,
     ):
         self.dtype = dtype
@@ -59,6 +60,7 @@ def __init__(
         self.position_embedding_base = position_embedding_base
         self.combine_matmul = combine_matmul
         self.sliding_window = sliding_window
+        self.attention_bias = attention_bias

         if build_model_only and num_shards > 1:
             self.num_shards = num_shards
@@ -282,40 +284,45 @@ def __init__(self, config: LlamaConfig):
         self.position_embedding_base = config.position_embedding_base

         self.combine_matmul = config.combine_matmul
+
         if self.combine_matmul:
             self.query_key_value_proj = Linear(
                 self.hidden_size,
                 (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim,
                 dtype=dtype,
-                bias=False,
+                bias=config.attention_bias,
             )
             self.query_key_value_proj.weight.shard_dim = 0
             self.query_key_value_proj.weight.shard_strategy = "shard_qkv"
+
+            if config.attention_bias:
+                self.query_key_value_proj.bias.shard_dim = 0
+                self.query_key_value_proj.bias.shard_strategy = "shard_qkv_bias"
         else:
             self.q_proj = Linear(
                 self.hidden_size,
                 self.num_query_heads * self.head_dim,
                 dtype=dtype,
-                bias=False,
+                bias=config.attention_bias,
             )
             self.k_proj = Linear(
                 self.hidden_size,
                 self.num_key_value_heads * self.head_dim,
                 dtype=dtype,
-                bias=False,
+                bias=config.attention_bias,
             )
             self.v_proj = Linear(
                 self.hidden_size,
                 self.num_key_value_heads * self.head_dim,
                 dtype=dtype,
-                bias=False,
+                bias=config.attention_bias,
             )
             self.q_proj.weight.shard_dim = 0
             self.k_proj.weight.shard_dim = 0
             self.v_proj.weight.shard_dim = 0

         self.o_proj = Linear(
-            self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False
+            self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=config.attention_bias
         )
         self.o_proj.weight.shard_dim = 1
         self.o_proj.weight.shard_strategy = "shard_o_proj_k"
@@ -1365,9 +1372,16 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
            q_heads = config.num_attention_heads
            kv_heads = config.get_num_key_value_heads()
            q, k, v = torch_params
-           assert q.shape == (q_heads * head_dim, hidden_size)
-           assert k.shape == (kv_heads * head_dim, hidden_size)
-           assert v.shape == (kv_heads * head_dim, hidden_size)
+
+           if len(q.shape) == 2:
+               assert q.shape == (q_heads * head_dim, hidden_size)
+               assert k.shape == (kv_heads * head_dim, hidden_size)
+               assert v.shape == (kv_heads * head_dim, hidden_size)
+           elif len(q.shape) == 1:
+               assert q.shape == (q_heads * head_dim,)
+               assert k.shape == (kv_heads * head_dim,)
+               assert v.shape == (kv_heads * head_dim,)
+
            qkv = np.concatenate([q, k, v], axis=0).astype(dtype)
            return qkv
        if "gate_up_proj" in relax_pname:
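The loader change above extends the shape checks in f_compute_relax_param from 2-D weights to 1-D biases; the fusion itself is the same axis-0 concatenation in both cases. A small self-contained NumPy sketch, with assumed sizes rather than values from the commit:

import numpy as np

# Illustrative sizes only; the real values come from the model config.
q_heads, kv_heads, head_dim, hidden_size = 8, 4, 16, 128
dtype = "float16"

# 2-D case: q/k/v projection weights, fused along the output dimension.
q_w, k_w, v_w = (np.ones((h * head_dim, hidden_size)) for h in (q_heads, kv_heads, kv_heads))
qkv_w = np.concatenate([q_w, k_w, v_w], axis=0).astype(dtype)
print(qkv_w.shape)  # ((q_heads + 2 * kv_heads) * head_dim, hidden_size)

# 1-D case: the new attention-bias path; the same axis-0 concatenation applies.
q_b, k_b, v_b = (np.ones(h * head_dim) for h in (q_heads, kv_heads, kv_heads))
qkv_b = np.concatenate([q_b, k_b, v_b], axis=0).astype(dtype)
print(qkv_b.shape)  # ((q_heads + 2 * kv_heads) * head_dim,)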
serve/benchmarks/benchmark_throughput.py (1 addition, 0 deletions)
@@ -142,6 +142,7 @@ def run_mlc(engine, requests, args) -> float:
                 logprobs=args.sampling_setting["logprobs"],
                 top_logprobs=args.sampling_setting["top_logprobs"],
                 json_schema=args.sampling_setting["json_schema"],
+                vocab_size=engine.model_artifact_config.vocab_size,
             ),
             stopping_criteria=StoppingCriteria(
                 max_tokens=args.num_output_tokens, stop_sequences=None
