From eedd931ddca8f634a0e84d369f4e5cb111a3d07e Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 15 Feb 2024 04:50:22 +0000
Subject: [PATCH 1/4] wip, single gpu works

---
 mlc_llm/relax_model/llama.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py
index 145c522810..93ae4c9fd1 100644
--- a/mlc_llm/relax_model/llama.py
+++ b/mlc_llm/relax_model/llama.py
@@ -39,6 +39,7 @@ def __init__(
         build_model_only=False,
         num_shards=1,
         sliding_window=None,
+        attention_bias=False,
         **kwargs,
     ):
         self.dtype = dtype
@@ -59,6 +60,7 @@ def __init__(
         self.position_embedding_base = position_embedding_base
         self.combine_matmul = combine_matmul
         self.sliding_window = sliding_window
+        self.attention_bias = attention_bias
 
         if build_model_only and num_shards > 1:
             self.num_shards = num_shards
@@ -282,12 +284,13 @@ def __init__(self, config: LlamaConfig):
         self.position_embedding_base = config.position_embedding_base
         self.combine_matmul = config.combine_matmul
 
+
         if self.combine_matmul:
             self.query_key_value_proj = Linear(
                 self.hidden_size,
                 (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim,
                 dtype=dtype,
-                bias=False,
+                bias=config.attention_bias,
             )
             self.query_key_value_proj.weight.shard_dim = 0
             self.query_key_value_proj.weight.shard_strategy = "shard_qkv"
@@ -296,26 +299,26 @@ def __init__(self, config: LlamaConfig):
                 self.hidden_size,
                 self.num_query_heads * self.head_dim,
                 dtype=dtype,
-                bias=False,
+                bias=config.attention_bias,
             )
             self.k_proj = Linear(
                 self.hidden_size,
                 self.num_key_value_heads * self.head_dim,
                 dtype=dtype,
-                bias=False,
+                bias=config.attention_bias,
             )
             self.v_proj = Linear(
                 self.hidden_size,
                 self.num_key_value_heads * self.head_dim,
                 dtype=dtype,
-                bias=False,
+                bias=config.attention_bias,
             )
             self.q_proj.weight.shard_dim = 0
             self.k_proj.weight.shard_dim = 0
             self.v_proj.weight.shard_dim = 0
 
         self.o_proj = Linear(
-            self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False
+            self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=config.attention_bias
         )
         self.o_proj.weight.shard_dim = 1
         self.o_proj.weight.shard_strategy = "shard_o_proj_k"
@@ -1365,9 +1368,16 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
             q_heads = config.num_attention_heads
             kv_heads = config.get_num_key_value_heads()
             q, k, v = torch_params
-            assert q.shape == (q_heads * head_dim, hidden_size)
-            assert k.shape == (kv_heads * head_dim, hidden_size)
-            assert v.shape == (kv_heads * head_dim, hidden_size)
+
+            if len(q.shape) == 2:
+                assert q.shape == (q_heads * head_dim, hidden_size)
+                assert k.shape == (kv_heads * head_dim, hidden_size)
+                assert v.shape == (kv_heads * head_dim, hidden_size)
+            elif len(q.shape) == 1:
+                assert q.shape == (q_heads * head_dim,)
+                assert k.shape == (kv_heads * head_dim,)
+                assert v.shape == (kv_heads * head_dim,)
+
             qkv = np.concatenate([q, k, v], axis=0).astype(dtype)
             return qkv
         if "gate_up_proj" in relax_pname:

From 5247b22aa68b9986c00c9fe52f0a6730a8f1b5bb Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 15 Feb 2024 05:38:23 +0000
Subject: [PATCH 2/4] wip

---
 mlc_llm/relax_model/commons.py | 21 +++++++++++++++++++++
 mlc_llm/relax_model/llama.py   |  8 ++++++++
 2 files changed, 29 insertions(+)

diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py
index 676ff610a2..8632916471 100644
--- a/mlc_llm/relax_model/commons.py
+++ b/mlc_llm/relax_model/commons.py
@@ -55,6 +55,25 @@ def shard_qkv_weight_scale(weight: relax.TensorStructInfo):
         func = te.create_prim_func([a, w])
         return func
 
+    def shard_bias(bias: relax.TensorStructInfo):
+        (hidden_dim,), dtype = bias.shape, bias.dtype
+        hidden_dim = int(hidden_dim)
+        if param_shape_is_already_sharded:
+            hidden_dim *= num_shards
+        head_dim = hidden_dim // (q_heads + 2 * kv_heads)
+        a = te.placeholder((hidden_dim,), dtype=dtype)
+        w = topi.reshape(a, (hidden_dim // head_dim, head_dim))
+        q = te.compute((q_heads, head_dim), lambda i, j: w[i, j])
+        k = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + i, j])
+        v = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + kv_heads + i, j])
+        q = topi.reshape(q, (num_shards, q_heads // num_shards, head_dim))
+        k = topi.reshape(k, (num_shards, kv_heads // num_shards, head_dim))
+        v = topi.reshape(v, (num_shards, kv_heads // num_shards, head_dim))
+        w = topi.concatenate((q, k, v), axis=1)
+        w = topi.reshape(w, (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim))
+        func = te.create_prim_func([a, w])
+        return func
+
     def shard_k_weight_scale(weight: relax.TensorStructInfo):
         (spatial, red), dtype = weight.shape, weight.dtype
         spatial, red = int(spatial), int(red)
@@ -112,8 +131,10 @@ def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
 
     return {
         "shard_qkv": shard_qkv_weight_scale,
+        "shard_qkv_bias": shard_bias,
         "shard_mlp_k": shard_k_weight_scale,
         "shard_o_proj_k": shard_k_weight_scale,
+        "shard_o_proj_k_bias": shard_bias,
         "shard_gate_up": shard_gate_up_weight_scale,
         "moe_shard_mlp_k": moe_shard_k_weight_scale,
         "moe_shard_gate_up": moe_shard_gate_up_weight_scale,
diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py
index 93ae4c9fd1..4475ba9785 100644
--- a/mlc_llm/relax_model/llama.py
+++ b/mlc_llm/relax_model/llama.py
@@ -294,6 +294,10 @@ def __init__(self, config: LlamaConfig):
             )
             self.query_key_value_proj.weight.shard_dim = 0
             self.query_key_value_proj.weight.shard_strategy = "shard_qkv"
+
+            if config.attention_bias:
+                self.query_key_value_proj.bias.shard_dim = 0
+                self.query_key_value_proj.bias.shard_strategy = "shard_qkv_bias"
         else:
             self.q_proj = Linear(
                 self.hidden_size,
@@ -323,6 +327,10 @@ def __init__(self, config: LlamaConfig):
         self.o_proj.weight.shard_dim = 1
         self.o_proj.weight.shard_strategy = "shard_o_proj_k"
 
+        if config.attention_bias:
+            self.o_proj.bias.shard_dim = 0
+            self.o_proj.bias.shard_strategy = "shard_o_proj_k_bias"
+
     def project_qkv(self, hidden_states, query_output_shape, kv_output_shape):
         from tvm.relax.op import reshape, split
 

From e1573b62a88604982986ce3420d85fe099a7b24d Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 15 Feb 2024 06:18:31 +0000
Subject: [PATCH 3/4] works

---
 mlc_llm/relax_model/commons.py | 1 -
 mlc_llm/relax_model/llama.py   | 4 ----
 2 files changed, 5 deletions(-)

diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py
index 8632916471..01ce81297f 100644
--- a/mlc_llm/relax_model/commons.py
+++ b/mlc_llm/relax_model/commons.py
@@ -134,7 +134,6 @@ def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
         "shard_qkv_bias": shard_bias,
         "shard_mlp_k": shard_k_weight_scale,
         "shard_o_proj_k": shard_k_weight_scale,
-        "shard_o_proj_k_bias": shard_bias,
         "shard_gate_up": shard_gate_up_weight_scale,
         "moe_shard_mlp_k": moe_shard_k_weight_scale,
         "moe_shard_gate_up": moe_shard_gate_up_weight_scale,
diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py
index 4475ba9785..5a085c5e03 100644
--- a/mlc_llm/relax_model/llama.py
+++ b/mlc_llm/relax_model/llama.py
@@ -327,10 +327,6 @@ def __init__(self, config: LlamaConfig):
         self.o_proj.weight.shard_dim = 1
         self.o_proj.weight.shard_strategy = "shard_o_proj_k"
 
-        if config.attention_bias:
-            self.o_proj.bias.shard_dim = 0
-            self.o_proj.bias.shard_strategy = "shard_o_proj_k_bias"
-
     def project_qkv(self, hidden_states, query_output_shape, kv_output_shape):
         from tvm.relax.op import reshape, split
 

From 76185953d5c35a235625bd282994ac6badf83841 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 15 Feb 2024 08:50:33 +0000
Subject: [PATCH 4/4] set vocab_size in benchmark correctly

---
 serve/benchmarks/benchmark_throughput.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py
index 3cb6958e47..d3ab1587c9 100644
--- a/serve/benchmarks/benchmark_throughput.py
+++ b/serve/benchmarks/benchmark_throughput.py
@@ -142,6 +142,7 @@ def run_mlc(engine, requests, args) -> float:
                         logprobs=args.sampling_setting["logprobs"],
                         top_logprobs=args.sampling_setting["top_logprobs"],
                         json_schema=args.sampling_setting["json_schema"],
+                        vocab_size=engine.model_artifact_config.vocab_size,
                     ),
                     stopping_criteria=StoppingCriteria(
                         max_tokens=args.num_output_tokens, stop_sequences=None
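
For reference, a minimal NumPy sketch of the re-layout that the "shard_qkv_bias" strategy above performs on a fused QKV bias vector; the shard and head counts below are hypothetical illustration values, not taken from any model config.

import numpy as np

num_shards, q_heads, kv_heads, head_dim = 2, 8, 4, 16  # hypothetical sizes

# Fused QKV bias as stored in the checkpoint: all q heads, then all k heads, then all v heads.
bias = np.arange((q_heads + 2 * kv_heads) * head_dim, dtype="float32")
heads = bias.reshape(q_heads + 2 * kv_heads, head_dim)
q, k, v = np.split(heads, [q_heads, q_heads + kv_heads])

# Give each shard its own contiguous block of q, k and v heads, then flatten per shard.
q = q.reshape(num_shards, q_heads // num_shards, head_dim)
k = k.reshape(num_shards, kv_heads // num_shards, head_dim)
v = v.reshape(num_shards, kv_heads // num_shards, head_dim)
sharded = np.concatenate([q, k, v], axis=1).reshape(
    num_shards, (q_heads + 2 * kv_heads) // num_shards * head_dim
)
# sharded[i] is the QKV bias slice consumed by device i.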