From ddd08aa2ec618a39cebef85a28a47ab8e8ce7fba Mon Sep 17 00:00:00 2001
From: ispobock
Date: Tue, 24 Sep 2024 08:10:20 +0800
Subject: [PATCH] moe torch compile

---
 python/sglang/srt/layers/fused_moe/patch.py  | 117 ++++++++++++++++++
 .../srt/model_executor/cuda_graph_runner.py  |  14 ++-
 2 files changed, 126 insertions(+), 5 deletions(-)
 create mode 100644 python/sglang/srt/layers/fused_moe/patch.py

diff --git a/python/sglang/srt/layers/fused_moe/patch.py b/python/sglang/srt/layers/fused_moe/patch.py
new file mode 100644
index 0000000000..65fcd78779
--- /dev/null
+++ b/python/sglang/srt/layers/fused_moe/patch.py
@@ -0,0 +1,117 @@
+from typing import Optional
+
+import torch
+from torch.nn import functional as F
+
+
+def fused_topk_native(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    M, _ = hidden_states.shape
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+# This is used by the DeepSeek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def select_experts_native(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+):
+    # DeepSeek-V2 uses grouped_top_k
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+        )
+    else:
+        topk_weights, topk_ids = fused_topk_native(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    return topk_weights, topk_ids
+
+
+def fused_moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+) -> torch.Tensor:
+    topk_weights, topk_ids = select_experts_native(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+    )
+    w13_weights = layer.w13_weight[topk_ids]
+    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
+    w2_weights = layer.w2_weight[topk_ids]
+    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
+    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
+    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 3f73e734dd..4eb2197aac 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -25,6 +25,7 @@
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
@@ -41,14 +42,15 @@
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
-            # NOTE: FusedMoE torch native implementaiton is not efficient
-            if "FusedMoE" in sub.__class__.__name__:
-                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
             else:
-                sub._forward_method = sub.forward_native
+                # NOTE: Temporary workaround for MoE
+                if "FusedMoE" in sub.__class__.__name__:
+                    sub._forward_method = fused_moe_forward_native
+                else:
+                    sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
             if isinstance(sub, torch.nn.Module):
                 _to_torch(sub, reverse)
@@ -67,7 +69,9 @@ def patch_model(
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             tp_group.ca_comm = None
-            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+            yield torch.compile(
+                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+            )
         else:
             yield model.forward
     finally:
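
Note (not part of the commit): a minimal smoke test of the new torch-native
MoE path, assuming the patch above is applied. Everything except
fused_moe_forward_native itself is an illustrative stand-in: the
SimpleNamespace "layer" mimics only the two weight tensors the function
touches (the real FusedMoE module carries more state), and all sizes are
made up.

    # Sketch: exercise fused_moe_forward_native on random weights (CPU is fine).
    from types import SimpleNamespace

    import torch

    from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native

    # tokens, hidden size, intermediate size, num experts, top-k (toy values)
    T, H, INTER, E, TOPK = 4, 16, 32, 8, 2

    layer = SimpleNamespace(
        # w13 stacks the gate (w1) and up (w3) projections per expert: [E, 2*INTER, H]
        w13_weight=torch.randn(E, 2 * INTER, H),
        # w2 is the per-expert down projection: [E, H, INTER]
        w2_weight=torch.randn(E, H, INTER),
    )
    x = torch.randn(T, H)
    router_logits = torch.randn(T, E)

    out = fused_moe_forward_native(
        layer,
        x,
        use_grouped_topk=False,  # grouped top-k is only for DeepSeek-V2-style gating
        top_k=TOPK,
        router_logits=router_logits,
        renormalize=True,
    )
    assert out.shape == (T, H)

Because the function is pure PyTorch (indexing plus einsum, no Triton
kernels), the same snippet should also run under torch.compile, which is what
_to_torch relies on when it swaps FusedMoE's _forward_method during capture.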