Commit 9007739

yewentao256 and gemini-code-assist[bot] authored and committed
[Feature] Migrate DeepGEMM API from get_m_alignment_for_contiguous_layout to get_mk_alignment_for_contiguous_layout (vllm-project#26935)
Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: Wentao Ye <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent ba7cca0 commit 9007739
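In short, the commit replaces every use of DeepGEMM's scalar M-alignment helper with a vLLM-level wrapper that returns the [M, K] alignment pair. A minimal before/after sketch of the call-site change, taken from the test diff below (assumes DeepGEMM is installed):

# Before: build the block shape from DeepGEMM's scalar alignment.
from deep_gemm import get_m_alignment_for_contiguous_layout
block_m = get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]

# After: vLLM's wrapper returns the [block_m, block_k] pair directly.
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
block_size = get_mk_alignment_for_contiguous_layout()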

File tree

8 files changed: +57 −46 lines

tests/kernels/moe/test_block_fp8.py

Lines changed: 5 additions & 6 deletions
@@ -22,13 +22,13 @@
 )
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)

 dg_available = has_deep_gemm()

-if dg_available:
-    from deep_gemm import get_m_alignment_for_contiguous_layout
-
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
@@ -218,8 +218,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
     torch.manual_seed(seed)

     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
-    block_m = get_m_alignment_for_contiguous_layout()
-    block_size = [block_m, block_m]
+    block_size = get_mk_alignment_for_contiguous_layout()
     dtype = torch.bfloat16

     a = torch.randn((M, K), dtype=dtype) / 10

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 6 additions & 3 deletions
@@ -6,14 +6,17 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,
 )
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    fp8_m_grouped_gemm_nt_masked,
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)

 logger = init_logger(__name__)
@@ -227,7 +230,7 @@ def __init__(
             quant_config: Quantization configuration
         """
         super().__init__(quant_config)
-        assert self.block_shape == deep_gemm_block_shape()
+        assert self.block_shape == get_mk_alignment_for_contiguous_layout()
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
     BatchedDeepGemmExperts,
 )
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
+from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout


 class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -31,7 +31,7 @@ def __init__(
         self.allow_deep_gemm = (
             allow_deep_gemm
             and self.quant_config.use_fp8_w8a8
-            and self.block_shape == deep_gemm_block_shape()
+            and self.block_shape == get_mk_alignment_for_contiguous_layout()
         )

         self.batched_deep_gemm_experts = (

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 10 additions & 8 deletions
@@ -13,7 +13,6 @@
 )
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
     compute_aligned_M,
-    deep_gemm_block_shape,
     deepgemm_moe_permute,
     deepgemm_unpermute_and_reduce,
 )
@@ -28,14 +27,17 @@
     per_token_group_quant_fp8,
 )
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    m_grouped_fp8_gemm_nt_contiguous,
+)
 from vllm.utils.functools import run_once

 logger = init_logger(__name__)


 def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
-    align = deep_gemm_block_shape()[0]
+    align = get_mk_alignment_for_contiguous_layout()[0]
     return align <= M and N % align == 0 and K % align == 0
@@ -54,7 +56,7 @@ def _valid_deep_gemm(
     M = hidden_states.size(0)
     _, K, N = w2.size()

-    align = deep_gemm_block_shape()[0]
+    align = get_mk_alignment_for_contiguous_layout()[0]

     if not _valid_deep_gemm_shape(M, N, K):
         logger.debug_once(
@@ -124,7 +126,7 @@ def warmup_deepgemm_gg_contiguous_kernels(

     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"

-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     num_experts = w1.size(0)
     device = w1.device
@@ -173,7 +175,7 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         super().__init__(quant_config)
-        assert quant_config.block_shape == deep_gemm_block_shape()
+        assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
         assert quant_config.quant_dtype == torch.float8_e4m3fn
         assert not quant_config.per_act_token_quant
         assert not quant_config.per_out_ch_quant
@@ -255,7 +257,7 @@ def apply(
             M=topk_ids.size(0),
             num_topk=topk_ids.size(1),
             local_num_experts=local_num_experts,
-            alignment=deep_gemm_block_shape()[0],
+            alignment=get_mk_alignment_for_contiguous_layout()[0],
             expert_tokens_meta=expert_tokens_meta,
         )
@@ -364,7 +366,7 @@ def deep_gemm_moe_fp8(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
-        block_shape=deep_gemm_block_shape(),
+        block_shape=get_mk_alignment_for_contiguous_layout(),
     )

     fn = mk.FusedMoEModularKernel(

vllm/model_executor/layers/fused_moe/deep_gemm_utils.py

Lines changed: 2 additions & 13 deletions
@@ -5,23 +5,13 @@
 and updated to fit vllm needs and terminology.
 """

-import functools
-
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
 from vllm.triton_utils import tl, triton
 from vllm.utils import round_up
-
-
-@functools.cache
-def deep_gemm_block_shape() -> list[int]:
-    # Lazy import to avoid CUDA initialization problems.
-    import deep_gemm as dg
-
-    block = dg.get_m_alignment_for_contiguous_layout()
-    return [block, block]
+from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout


 def expert_num_tokens_round_up_and_sum(
@@ -354,8 +344,7 @@ def deepgemm_moe_permute(
     H = aq.size(1)
     device = aq.device

-    block_m = deep_gemm_block_shape()[0]
-    block_k = deep_gemm_block_shape()[1]
+    block_m, block_k = get_mk_alignment_for_contiguous_layout()

     M_sum = compute_aligned_M(
         M=topk_ids.size(0),
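Note: the locally cached deep_gemm_block_shape() helper removed above is functionally replaced by the shared wrapper in vllm.utils.deep_gemm, which handles the lazy DeepGEMM import and caching itself. A rough equivalence sketch (assuming a DeepGEMM build that exposes both the old scalar API and the new one, returning the same alignment value):

import deep_gemm as dg
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

block = dg.get_m_alignment_for_contiguous_layout()  # old scalar API
# What deep_gemm_block_shape() used to return is now the wrapper's [M, K] pair.
assert get_mk_alignment_for_contiguous_layout() == [block, block]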

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 5 additions & 3 deletions
@@ -10,9 +10,11 @@
     _valid_deep_gemm,
     _valid_deep_gemm_shape,
 )
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)


 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -28,7 +30,7 @@ def __init__(
         self.allow_deep_gemm = (
             allow_deep_gemm
             and self.quant_config.use_fp8_w8a8
-            and self.block_shape == deep_gemm_block_shape()
+            and self.block_shape == get_mk_alignment_for_contiguous_layout()
         )

         self.deep_gemm_expert = (

vllm/model_executor/warmup/deep_gemm_warmup.py

Lines changed: 11 additions & 10 deletions
@@ -12,18 +12,19 @@
 import vllm.envs as envs
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
-    compute_aligned_M,
-    deep_gemm_block_shape,
-)
+from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
     TritonOrDeepGemmExperts,
 )
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
-from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
+from vllm.utils.deep_gemm import (
+    fp8_gemm_nt,
+    get_mk_alignment_for_contiguous_layout,
+    m_grouped_fp8_gemm_nt_contiguous,
+)


 def _generate_optimal_warmup_m_values(
@@ -129,7 +130,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
     """
     Return True if the input module/layer could be processed with DeepGEMM.
    """
-    block_size = deep_gemm_block_shape()[0]
+    block_size = get_mk_alignment_for_contiguous_layout()[0]
     if not (
         isinstance(module, LinearBase)
         and isinstance(module.quant_method, Fp8LinearMethod)
@@ -139,7 +140,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:

     w, _, block_sizes = _extract_data_from_linear_base_module(module)
     return (
-        block_sizes == deep_gemm_block_shape()
+        block_sizes == get_mk_alignment_for_contiguous_layout()
         and w.ndim == 2
         and w.shape[0] % block_size == 0
         and w.shape[1] % block_size == 0
@@ -155,7 +156,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
     if (
         moe_quant_config is None
         or moe_quant_config.quant_dtype != torch.float8_e4m3fn
-        or moe_quant_config.block_shape != deep_gemm_block_shape()
+        or moe_quant_config.block_shape != get_mk_alignment_for_contiguous_layout()
     ):
         return False
@@ -176,7 +177,7 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
         return

     n, k = w.size()
-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]

     device = w.device
     a1q = torch.empty((max_tokens, k), device=device, dtype=torch.float8_e4m3fn)
@@ -229,7 +230,7 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(

     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"

-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     num_experts = w1.size(0)
     device = w1.device

vllm/utils/deep_gemm.py

Lines changed: 16 additions & 1 deletion
@@ -75,6 +75,7 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
 _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
 _get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
 _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
+_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None


 def _lazy_init() -> None:
@@ -83,7 +84,7 @@ def _lazy_init() -> None:
     global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
     global _get_paged_mqa_logits_metadata_impl
     global _get_mn_major_tma_aligned_tensor_impl
-
+    global _get_mk_alignment_for_contiguous_layout_impl
     # fast path
     if (
         _fp8_gemm_nt_impl is not None
@@ -92,6 +93,7 @@ def _lazy_init() -> None:
         or _fp8_mqa_logits_impl is not None
         or _fp8_paged_mqa_logits_impl is not None
        or _get_paged_mqa_logits_metadata_impl is not None
+        or _get_mk_alignment_for_contiguous_layout_impl is not None
     ):
         return
@@ -118,6 +120,9 @@ def _lazy_init() -> None:
     _get_mn_major_tma_aligned_tensor_impl = getattr(
         _dg, "get_mn_major_tma_aligned_tensor", None
     )
+    _get_mk_alignment_for_contiguous_layout_impl = getattr(
+        _dg, "get_mk_alignment_for_contiguous_layout", None
+    )


 def get_num_sms() -> int:
@@ -126,6 +131,15 @@ def get_num_sms() -> int:
     return int(_dg.get_num_sms())


+@functools.cache
+def get_mk_alignment_for_contiguous_layout() -> list[int]:
+    _lazy_init()
+    if _get_mk_alignment_for_contiguous_layout_impl is None:
+        return _missing()
+    mk_align_size = _get_mk_alignment_for_contiguous_layout_impl()
+    return [mk_align_size, mk_align_size]
+
+
 def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
     """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
     _lazy_init()
@@ -338,4 +352,5 @@ def should_use_deepgemm_for_fp8_linear(
     "get_num_sms",
     "should_use_deepgemm_for_fp8_linear",
     "get_col_major_tma_aligned_tensor",
+    "get_mk_alignment_for_contiguous_layout",
 ]
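As a usage sketch of the new public wrapper (not part of the commit): it lazily resolves DeepGEMM on first call, delegates to _missing() when DeepGEMM does not provide the symbol, and memoizes the [block_m, block_k] pair via functools.cache. The matches_deep_gemm_block_shape helper below is hypothetical, mirroring how the experts classes compare self.block_shape in the diffs above:

from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

block_m, block_k = get_mk_alignment_for_contiguous_layout()

# functools.cache returns the same list object on every call.
assert get_mk_alignment_for_contiguous_layout() is get_mk_alignment_for_contiguous_layout()

# Typical use: gate a DeepGEMM path on a matching quantization block shape.
def matches_deep_gemm_block_shape(block_shape: list[int]) -> bool:
    return block_shape == get_mk_alignment_for_contiguous_layout()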
