Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smaug support #212

Merged
merged 4 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mlc_llm/relax_model/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@ def shard_qkv_weight_scale(weight: relax.TensorStructInfo):
func = te.create_prim_func([a, w])
return func

def shard_bias(bias: relax.TensorStructInfo):
(hidden_dim,), dtype = bias.shape, bias.dtype
hidden_dim = int(hidden_dim)
if param_shape_is_already_sharded:
hidden_dim *= num_shards
head_dim = hidden_dim // (q_heads + 2 * kv_heads)
a = te.placeholder((hidden_dim,), dtype=dtype)
w = topi.reshape(a, (hidden_dim // head_dim, head_dim))
q = te.compute((q_heads, head_dim), lambda i, j: w[i, j])
k = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + i, j])
v = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + kv_heads + i, j])
q = topi.reshape(q, (num_shards, q_heads // num_shards, head_dim))
k = topi.reshape(k, (num_shards, kv_heads // num_shards, head_dim))
v = topi.reshape(v, (num_shards, kv_heads // num_shards, head_dim))
w = topi.concatenate((q, k, v), axis=1)
w = topi.reshape(w, (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim))
func = te.create_prim_func([a, w])
return func

def shard_k_weight_scale(weight: relax.TensorStructInfo):
(spatial, red), dtype = weight.shape, weight.dtype
spatial, red = int(spatial), int(red)
Expand Down Expand Up @@ -112,6 +131,7 @@ def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo):

return {
"shard_qkv": shard_qkv_weight_scale,
"shard_qkv_bias": shard_bias,
"shard_mlp_k": shard_k_weight_scale,
"shard_o_proj_k": shard_k_weight_scale,
Copy link
Member Author

@masahi masahi Feb 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why the bias for output projection must not be sharded. Initially I sharded it as well but the result was incorrect. Then I remember that the 1D scale for output projection in FT quantization must not be sharded as well https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/relax_model/commons.py#L316-L320. So I skipped the bias shard for output proj and it worked.

Copy link
Member

@vinx13 vinx13 Feb 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the shading is done for the reduction dimension, bias doesn't need to be shared, instead, bias need to be after all reduce or divided by num_shards

"shard_gate_up": shard_gate_up_weight_scale,
Expand Down
30 changes: 22 additions & 8 deletions mlc_llm/relax_model/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
build_model_only=False,
num_shards=1,
sliding_window=None,
attention_bias=False,
**kwargs,
):
self.dtype = dtype
Expand All @@ -59,6 +60,7 @@ def __init__(
self.position_embedding_base = position_embedding_base
self.combine_matmul = combine_matmul
self.sliding_window = sliding_window
self.attention_bias = attention_bias

if build_model_only and num_shards > 1:
self.num_shards = num_shards
Expand Down Expand Up @@ -282,40 +284,45 @@ def __init__(self, config: LlamaConfig):
self.position_embedding_base = config.position_embedding_base

self.combine_matmul = config.combine_matmul

if self.combine_matmul:
self.query_key_value_proj = Linear(
self.hidden_size,
(self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim,
dtype=dtype,
bias=False,
bias=config.attention_bias,
)
self.query_key_value_proj.weight.shard_dim = 0
self.query_key_value_proj.weight.shard_strategy = "shard_qkv"

if config.attention_bias:
self.query_key_value_proj.bias.shard_dim = 0
self.query_key_value_proj.bias.shard_strategy = "shard_qkv_bias"
else:
self.q_proj = Linear(
self.hidden_size,
self.num_query_heads * self.head_dim,
dtype=dtype,
bias=False,
bias=config.attention_bias,
)
self.k_proj = Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
dtype=dtype,
bias=False,
bias=config.attention_bias,
)
self.v_proj = Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
dtype=dtype,
bias=False,
bias=config.attention_bias,
)
self.q_proj.weight.shard_dim = 0
self.k_proj.weight.shard_dim = 0
self.v_proj.weight.shard_dim = 0

self.o_proj = Linear(
self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False
self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=config.attention_bias
)
self.o_proj.weight.shard_dim = 1
self.o_proj.weight.shard_strategy = "shard_o_proj_k"
Expand Down Expand Up @@ -1365,9 +1372,16 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
q_heads = config.num_attention_heads
kv_heads = config.get_num_key_value_heads()
q, k, v = torch_params
assert q.shape == (q_heads * head_dim, hidden_size)
assert k.shape == (kv_heads * head_dim, hidden_size)
assert v.shape == (kv_heads * head_dim, hidden_size)

if len(q.shape) == 2:
assert q.shape == (q_heads * head_dim, hidden_size)
assert k.shape == (kv_heads * head_dim, hidden_size)
assert v.shape == (kv_heads * head_dim, hidden_size)
elif len(q.shape) == 1:
assert q.shape == (q_heads * head_dim,)
assert k.shape == (kv_heads * head_dim,)
assert v.shape == (kv_heads * head_dim,)

qkv = np.concatenate([q, k, v], axis=0).astype(dtype)
return qkv
if "gate_up_proj" in relax_pname:
Expand Down
Loading