# Copyright (c) OpenMMLab. All rights reserved.

import contextlib
from typing import Callable, List

import torch
import torch.distributed as dist

from lmdeploy.pytorch.backends.moe import FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl
from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager
from lmdeploy.utils import get_logger

from .ep_utils import gather_outputs_by_attn_tp, split_inputs_by_attn_tp

logger = get_logger('lmdeploy')


class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl):
    """Triton fused MoE blocked FP8 implementation."""

    def __init__(self,
                 top_k: int,
                 num_experts: int,
                 renormalize: bool = False,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.float16):
        self.num_experts = num_experts
        self.top_k = top_k
        self.renormalize = renormalize
        self.block_size = block_size
        self.out_dtype = out_dtype

    def ep_expert_list(self, world_size: int, rank: int):
        """Experts list of current rank."""
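        # Illustrative example (hypothetical numbers, not a real deployment):
        # with num_experts=8 and world_size=3, expert_per_rank = 3, so rank 0
        # owns experts [0, 1, 2], rank 1 owns [3, 4, 5] and rank 2 owns the
        # remainder [6, 7]; the last rank may hold fewer experts.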
        num_experts = self.num_experts
        expert_per_rank = (num_experts + world_size - 1) // world_size
        first_expert = rank * expert_per_rank
        last_expert = min(first_expert + expert_per_rank, num_experts)
        return list(range(first_expert, last_expert))

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None):
        """forward."""
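        # Data-flow sketch (shapes are illustrative assumptions): activations
        # are flattened to [num_tokens, hidden_dim], quantized to fp8 in groups
        # of `block_size` with per-block scales, then passed to the fused
        # blocked-fp8 MoE kernel together with the router outputs
        # (topk_weights / topk_ids, typically [num_tokens, top_k]).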
        input_size = hidden_states.shape
        hidden_states = hidden_states.flatten(0, -2)
        input_quant, input_scale = quant_fp8(hidden_states, self.block_size, dtype=gate_up_weights.dtype)

        expert_offset = 0
        num_experts = None
        if expert_list is not None and len(expert_list) != self.num_experts:
            expert_offset = expert_list[0]
            num_experts = self.num_experts
        output = fused_moe_blocked_fp8(input_quant,
                                       input_scale,
                                       gate_up_weights,
                                       gate_up_scale,
                                       down_weights,
                                       down_scale,
                                       topk_weights=topk_weights,
                                       topk_ids=topk_ids,
                                       topk=self.top_k,
                                       w1_bias=gate_up_bias,
                                       w2_bias=down_bias,
                                       out_dtype=hidden_states.dtype,
                                       expert_offset=expert_offset,
                                       num_experts=num_experts,
                                       renormalize=self.renormalize,
                                       act_func=act_func)
        output = output.unflatten(0, input_size[:-1])
        return output


@contextlib.contextmanager
def monk_deep_gemm():
    """Temporarily patch deep_gemm and dlblas to use the deep_gemm kernels
    shipped in lmdeploy.pytorch.third_party; originals are restored on exit."""
    from dlblas.kernels.fused_moe_v3 import use_deep_gemm
    if use_deep_gemm:
        yield
        return

    # patch deep_gemm
    import deep_gemm
    import dlblas

    from lmdeploy.pytorch.third_party import deep_gemm as patched_deep_gemm
    func0_ = getattr(deep_gemm, 'get_col_major_tma_aligned_tensor', None)
    func1_ = getattr(deep_gemm, 'm_grouped_gemm_fp8_fp8_bf16_nt_masked', None)
    deep_gemm.get_col_major_tma_aligned_tensor = patched_deep_gemm.get_mn_major_tma_aligned_tensor
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = patched_deep_gemm.m_grouped_fp8_gemm_nt_masked

    # patch dlblas
    dlblas.kernels.fused_moe_v3.use_deep_gemm = True
    dlblas.kernels.fused_moe_v3.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous = \
        patched_deep_gemm.m_grouped_fp8_gemm_nt_contiguous
    try:
        yield
    finally:
        # unpatch dlblas
        dlblas.kernels.fused_moe_v3.use_deep_gemm = False

        # unpatch deep_gemm
        if func0_ is not None:
            deep_gemm.get_col_major_tma_aligned_tensor = func0_
            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = func1_
        else:
            del deep_gemm.get_col_major_tma_aligned_tensor
            del deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked


class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl):
    """DeepEP-based fused MoE blocked FP8 implementation for expert parallelism."""

    def __init__(self,
                 ep_size: int,
                 ep_group: dist.ProcessGroup,
                 top_k: int,
                 num_experts: int,
                 hidden_dim: int,
                 renormalize: bool = False,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.bfloat16,
                 layer_idx: int = 0):
        super().__init__(top_k, num_experts, renormalize, block_size, out_dtype)
        self.num_experts = num_experts
        self.ep_size = ep_size
        self.ep_group = ep_group
        self.hidden_dim = hidden_dim
        self.block_size = block_size
        self.out_dtype = out_dtype
        self.layer_idx = layer_idx
        try:
            import deep_gemm  # noqa: F401
            self.use_deep_gemm = True
        except ImportError:
            self.use_deep_gemm = False
            logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')

        # pre-allocate buffer
        self.fusedmoe_build(True)

    def ep_expert_list(self, world_size: int, rank: int):
        """Experts list of current rank."""
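        # With EPLB enabled, `phy2log` maps physical expert slots to logical
        # expert ids (hot experts may be replicated across slots); each rank
        # still owns a contiguous slice of physical slots but reports the
        # corresponding logical expert ids.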
        if get_dist_manager().current_context().dist_config.enable_eplb:
            from dlblas.layers.moe.eplb import get_eplb_phy2log_metadata_by_layer
            phy2log = get_eplb_phy2log_metadata_by_layer(self.layer_idx)
            expert_per_rank = (self.num_experts + world_size - 1) // world_size
            first_expert = rank * expert_per_rank
            last_expert = min(first_expert + expert_per_rank, self.num_experts)
            sliced_phy2log = phy2log[first_expert:last_expert].tolist()
            return sliced_phy2log
        else:
            return super().ep_expert_list(world_size=world_size, rank=rank)

    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None,
                **kwargs):
        """forward."""
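        # Overall flow: split the inputs across attention TP ranks, renormalize
        # the routing weights if required, run the DeepEP fused MoE (the
        # low-latency path is chosen for decoding steps when DeepGEMM is
        # available), then gather the partial outputs back.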
        hidden_states, topk_weights, topk_ids, split_size = split_inputs_by_attn_tp(hidden_states, topk_weights,
                                                                                     topk_ids)

        topk_weights = self.do_renormalize(topk_weights)
        step_ctx = get_step_ctx_manager().current_context()
        low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm
        moe = self.fusedmoe_build(low_latency_mode)
        out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights,
                                 down_scale, expert_list)

        out_states = gather_outputs_by_attn_tp(out_states, split_size)
        return out_states

    def do_renormalize(self, topk_weights):
        return _renormalize(topk_weights, self.renormalize)

    def fusedmoe_build(self, low_latency_mode: bool = False):
        from dlblas.layers.moe.ep_moe import build_deepep_moe
        deepep_moe = build_deepep_moe(low_latency_mode,
                                      self.ep_size,
                                      self.ep_group,
                                      self.num_experts,
                                      self.hidden_dim,
                                      self.block_size,
                                      self.top_k,
                                      self.out_dtype,
                                      layer_idx=self.layer_idx,
                                      chunk_size=16 * 1024)

        # patch forward
        _origin_forward = deepep_moe.forward
        _origin_fusedmoe_forward = deepep_moe.fusedmoe_forward

        def _patched_forward(*args, **kwargs):
            with monk_deep_gemm():
                out = _origin_forward(*args, **kwargs)
            return out

        def _patched_fusedmoe_forward(*args, **kwargs):
            with monk_deep_gemm():
                out = _origin_fusedmoe_forward(*args, **kwargs)
            return out

        deepep_moe.forward = _patched_forward
        deepep_moe.fusedmoe_forward = _patched_fusedmoe_forward

        return deepep_moe


class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
    """Triton fused MoE blocked FP8 builder."""

    @staticmethod
    def build(top_k: int,
              num_experts: int,
              hidden_dim: int = 1,
              renormalize: bool = False,
              block_size: int = 128,
              ep_size: int = 1,
              ep_group: dist.ProcessGroup = None,
              out_dtype: torch.dtype = torch.float16,
              layer_idx: int = 0,
              custom_gateup_act: bool = False):
        """Build from mlp."""
        if ep_size > 1:
            assert custom_gateup_act is False, 'Custom gate up activation is not supported in EP MoE.'
            return FusedDeepEpMoEBlockedF8Impl(ep_size=ep_size,
                                               ep_group=ep_group,
                                               top_k=top_k,
                                               num_experts=num_experts,
                                               hidden_dim=hidden_dim,
                                               renormalize=renormalize,
                                               block_size=block_size,
                                               out_dtype=out_dtype,
                                               layer_idx=layer_idx)
        else:
            return TritonFusedMoEBlockedF8Impl(top_k=top_k,
                                               num_experts=num_experts,
                                               renormalize=renormalize,
                                               block_size=block_size,
                                               out_dtype=out_dtype)
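

# Usage sketch (illustrative only; the argument values and tensor shapes below
# are assumptions, not requirements of the API):
#
#   impl = TritonFusedMoEBlockedF8Builder.build(top_k=2,
#                                               num_experts=8,
#                                               hidden_dim=2048,
#                                               renormalize=True,
#                                               out_dtype=torch.bfloat16)
#   # hidden_states: [batch, seq_len, hidden_dim] activations
#   # topk_weights / topk_ids: router outputs, e.g. [num_tokens, top_k]
#   # gate_up_* / down_*: blocked-fp8 expert weights and per-block scales
#   out = impl.forward(hidden_states, topk_weights, topk_ids, gate_up_weights,
#                      gate_up_scale, down_weights, down_scale)
#
# With ep_size > 1 and a valid ep_group, `build` returns the DeepEP-based
# implementation instead.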