Commit d0865d4

committed
add ep support
1 parent d767bca commit d0865d4

File tree

17 files changed: +1339 −1034 lines changed

lmdeploy/pytorch/backends/cuda/moe.py

Lines changed: 0 additions & 627 deletions
This file was deleted.
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .blocked_fp8 import TritonFusedMoEBlockedF8Builder  # noqa: F401
from .default import TritonFusedMoEBuilder  # noqa: F401
from .w8a8 import TritonFusedMoEW8A8Builder  # noqa: F401
Lines changed: 259 additions & 0 deletions

@@ -0,0 +1,259 @@
# Copyright (c) OpenMMLab. All rights reserved.

import contextlib
from typing import Callable, List

import torch
import torch.distributed as dist

from lmdeploy.pytorch.backends.moe import FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl
from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager
from lmdeploy.utils import get_logger

from .ep_utils import gather_outputs_by_attn_tp, split_inputs_by_attn_tp

logger = get_logger('lmdeploy')


class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl):
    """Triton fused moe blocked f8 implementation."""

    def __init__(self,
                 top_k: int,
                 num_experts: int,
                 renormalize: bool = False,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.float16):
        self.num_experts = num_experts
        self.top_k = top_k
        self.renormalize = renormalize
        self.block_size = block_size
        self.out_dtype = out_dtype

    def ep_expert_list(self, world_size: int, rank: int):
        """Experts list of current rank."""
        num_experts = self.num_experts
        expert_per_rank = (num_experts + world_size - 1) // world_size
        first_expert = rank * expert_per_rank
        last_expert = min(first_expert + expert_per_rank, num_experts)
        return list(range(first_expert, last_expert))

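    # ep_expert_list partitioning, with illustrative numbers: for num_experts=64
    # and world_size=4, expert_per_rank is 16 and rank 1 owns experts [16, 32);
    # for num_experts=62 the last rank keeps only the remainder, experts [48, 62).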
    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None):
        """forward."""
        input_size = hidden_states.shape
        hidden_states = hidden_states.flatten(0, -2)
        input_quant, input_scale = quant_fp8(hidden_states, self.block_size, dtype=gate_up_weights.dtype)

        expert_offset = 0
        num_experts = None
        if expert_list is not None and len(expert_list) != self.num_experts:
            expert_offset = expert_list[0]
            num_experts = self.num_experts
        output = fused_moe_blocked_fp8(input_quant,
                                       input_scale,
                                       gate_up_weights,
                                       gate_up_scale,
                                       down_weights,
                                       down_scale,
                                       topk_weights=topk_weights,
                                       topk_ids=topk_ids,
                                       topk=self.top_k,
                                       w1_bias=gate_up_bias,
                                       w2_bias=down_bias,
                                       out_dtype=hidden_states.dtype,
                                       expert_offset=expert_offset,
                                       num_experts=num_experts,
                                       renormalize=self.renormalize,
                                       act_func=act_func)
        output = output.unflatten(0, input_size[:-1])
        return output


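# monk_deep_gemm() below temporarily patches deep_gemm and dlblas.kernels.fused_moe_v3
# so that dlblas' fused_moe_v3 path goes through the deep_gemm wrappers in
# lmdeploy.pytorch.third_party; on exit it restores (or deletes) the patched
# deep_gemm attributes and switches use_deep_gemm back to False.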
@contextlib.contextmanager
def monk_deep_gemm():
    from dlblas.kernels.fused_moe_v3 import use_deep_gemm
    if use_deep_gemm:
        yield
        return

    # patch deep_gemm
    import deep_gemm
    import dlblas

    from lmdeploy.pytorch.third_party import deep_gemm as patched_deep_gemm
    func0_ = getattr(deep_gemm, 'get_col_major_tma_aligned_tensor', None)
    func1_ = getattr(deep_gemm, 'm_grouped_gemm_fp8_fp8_bf16_nt_masked', None)
    deep_gemm.get_col_major_tma_aligned_tensor = patched_deep_gemm.get_mn_major_tma_aligned_tensor
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = patched_deep_gemm.m_grouped_fp8_gemm_nt_masked

    # patch dlblas
    dlblas.kernels.fused_moe_v3.use_deep_gemm = True
    dlblas.kernels.fused_moe_v3.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous = \
        patched_deep_gemm.m_grouped_fp8_gemm_nt_contiguous
    yield

    # unpatch dlblas
    dlblas.kernels.fused_moe_v3.use_deep_gemm = False

    # unpatch deep_gemm
    if func0_ is not None:
        deep_gemm.get_col_major_tma_aligned_tensor = func0_
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = func1_
    else:
        del deep_gemm.get_col_major_tma_aligned_tensor
        del deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked


class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl):

    def __init__(self,
                 ep_size: int,
                 ep_group: dist.ProcessGroup,
                 top_k: int,
                 num_experts: int,
                 hidden_dim: int,
                 renormalize: bool = False,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.bfloat16,
                 layer_idx: int = 0):
        super().__init__(top_k, num_experts, renormalize, block_size, out_dtype)
        self.num_experts = num_experts
        self.ep_size = ep_size
        self.ep_group = ep_group
        self.hidden_dim = hidden_dim
        self.block_size = block_size
        self.out_dtype = out_dtype
        self.layer_idx = layer_idx
        try:
            import deep_gemm  # noqa: F401
            self.use_deep_gemm = True
        except ImportError:
            self.use_deep_gemm = False
            logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')

        # pre-allocate buffer
        self.fusedmoe_build(True)

    def ep_expert_list(self, world_size: int, rank: int):
        """Experts list of current rank."""
        if get_dist_manager().current_context().dist_config.enable_eplb:
            from dlblas.layers.moe.eplb import get_eplb_phy2log_metadata_by_layer
            phy2log = get_eplb_phy2log_metadata_by_layer(self.layer_idx)
            expert_per_rank = (self.num_experts + world_size - 1) // world_size
            first_expert = rank * expert_per_rank
            last_expert = min(first_expert + expert_per_rank, self.num_experts)
            sliced_phy2log = phy2log[first_expert:last_expert].tolist()
            return sliced_phy2log
        else:
            return super().ep_expert_list(world_size=world_size, rank=rank)

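    # The forward below presumably splits the incoming tokens across the attention
    # TP ranks, runs the DeepEP fused MoE on each shard (taking the DeepGEMM-backed
    # low-latency path only while decoding and only if deep_gemm is installed),
    # and gathers the per-shard outputs back across the attention TP group.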
    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None,
                **kwargs):
        """forward."""
        hidden_states, topk_weights, topk_ids, split_size = split_inputs_by_attn_tp(hidden_states, topk_weights,
                                                                                    topk_ids)

        topk_weights = self.do_renormalize(topk_weights)
        step_ctx = get_step_ctx_manager().current_context()
        low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm
        moe = self.fusedmoe_build(low_latency_mode)
        out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights,
                                 down_scale, expert_list)

        out_states = gather_outputs_by_attn_tp(out_states, split_size)
        return out_states

    def do_renormalize(self, topk_weights):
        return _renormalize(topk_weights, self.renormalize)

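    # fusedmoe_build constructs a DeepEP MoE via dlblas and wraps its forward
    # methods in monk_deep_gemm() so the patched DeepGEMM kernels are picked up
    # whenever deep_gemm is available.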
    def fusedmoe_build(self, low_latency_mode: bool = False):
        from dlblas.layers.moe.ep_moe import build_deepep_moe
        deepep_moe = build_deepep_moe(low_latency_mode,
                                      self.ep_size,
                                      self.ep_group,
                                      self.num_experts,
                                      self.hidden_dim,
                                      self.block_size,
                                      self.top_k,
                                      self.out_dtype,
                                      layer_idx=self.layer_idx,
                                      chunk_size=16 * 1024)

        # patch forward
        _origin_forward = deepep_moe.forward
        _origin_fusedmoe_forward = deepep_moe.fusedmoe_forward

        def _patched_forward(*args, **kwargs):
            with monk_deep_gemm():
                out = _origin_forward(*args, **kwargs)
            return out

        def _patched_fusedmoe_forward(*args, **kwargs):
            with monk_deep_gemm():
                out = _origin_fusedmoe_forward(*args, **kwargs)
            return out

        deepep_moe.forward = _patched_forward
        deepep_moe.fusedmoe_forward = _patched_fusedmoe_forward

        return deepep_moe


class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
    """Triton fused moe blocked f8 builder."""

    @staticmethod
    def build(top_k: int,
              num_experts: int,
              hidden_dim: int = 1,
              renormalize: bool = False,
              block_size: int = 128,
              ep_size: int = 1,
              ep_group: dist.ProcessGroup = None,
              out_dtype: torch.dtype = torch.float16,
              layer_idx: int = 0,
              custom_gateup_act: bool = False):
        """Build from mlp."""
        if ep_size > 1:
            assert custom_gateup_act is False, 'Custom gate up activation is not supported in EP MoE.'
            return FusedDeepEpMoEBlockedF8Impl(ep_size=ep_size,
                                               ep_group=ep_group,
                                               top_k=top_k,
                                               num_experts=num_experts,
                                               hidden_dim=hidden_dim,
                                               renormalize=renormalize,
                                               block_size=block_size,
                                               out_dtype=out_dtype,
                                               layer_idx=layer_idx)
        else:
            return TritonFusedMoEBlockedF8Impl(top_k=top_k,
                                               num_experts=num_experts,
                                               renormalize=renormalize,
                                               block_size=block_size,
                                               out_dtype=out_dtype)

0 commit comments
