Commit 2a6c86c

feat(rocm): enable TritonScaledMM fallback on ROCm; add CUDA fallback entry
Signed-off-by: Shivam <[email protected]>
1 parent c528b90 commit 2a6c86c

3 files changed: +121 -16 lines


mini_tests/select_triton_rocm.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys
import types

os.environ["VLLM_TARGET_DEVICE"] = "rocm"

# Mock amdsmi to simulate ROCm
amdsmi = types.ModuleType("amdsmi")
amdsmi.amdsmi_init = lambda: None
amdsmi.amdsmi_shut_down = lambda: None
amdsmi.amdsmi_get_processor_handles = lambda: [1]
amdsmi.AmdSmiException = Exception
sys.modules["amdsmi"] = amdsmi
sys.modules["vllm._rocm_C"] = types.ModuleType("_rocm_C")

# Prevent CPU platform from conflicting with ROCm on macOS
import vllm.platforms as platforms_module  # noqa: E402

_orig_cpu = platforms_module.cpu_platform_plugin
platforms_module.cpu_platform_plugin = (
    lambda: None if os.environ.get("VLLM_TARGET_DEVICE") == "rocm" else _orig_cpu()
)
platforms_module.builtin_platform_plugins["cpu"] = platforms_module.cpu_platform_plugin

# Mock torch to look like ROCm
import torch  # noqa: E402

torch.version.hip = "5.7.0"
torch.cuda.get_device_properties = lambda d=0: types.SimpleNamespace(
    gcnArchName="gfx900", major=9, minor=0
)
torch.cuda.get_device_capability = lambda d=0: (9, 0)

# Stub custom ops
_ops = types.ModuleType("_custom_ops")
for op in [
    "cutlass_scaled_mm_supports_fp4",
    "cutlass_scaled_fp4_mm",
    "scaled_fp4_quant",
    "scaled_fp8_quant",
    "apply_repetition_penalties",
    "merge_attn_states",
    "scaled_int8_quant",
]:
    setattr(_ops, op, lambda *a, **k: None)
sys.modules["vllm._custom_ops"] = _ops

from vllm.model_executor.layers.quantization.kernels.scaled_mm import (  # noqa: E402, I001
    ScaledMMLinearLayerConfig,
    choose_scaled_mm_linear_kernel,
)

cfg = ScaledMMLinearLayerConfig(
    is_channelwise=False,
    is_static_input_scheme=True,
    input_symmetric=True,
)

kernel = choose_scaled_mm_linear_kernel(cfg, compute_capability=None)

print("Selected kernel:", kernel.__name__)
assert "TritonScaledMMLinearKernel" in kernel.__name__
print("OK: TritonScaledMMLinearKernel chosen on ROCm fallback.")

vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -27,7 +27,7 @@
 # in priority/performance order (when available)
 _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
     PlatformEnum.CPU: [CPUScaledMMLinearKernel],
-    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
+    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
     PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
     PlatformEnum.TPU: [XLAScaledMMLinearKernel],
 }

@@ -68,6 +68,13 @@ def choose_scaled_mm_linear_kernel(
             )
             continue

+        # Check if kernel is supported on this platform/capability
+        if hasattr(kernel, "is_supported"):
+            supported, reason = kernel.is_supported(compute_capability)
+            if not supported:
+                failure_reasons.append(f" {kernel.__name__}: {reason}")
+                continue
+
         # If the current platform uses compute_capability,
         # make sure the kernel supports the compute cability.
         if compute_capability is not None:
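
For context, a minimal sketch of the selection behaviour this hunk enables (simplified, with hypothetical names; not the actual vLLM loop): kernels are tried in the priority order listed in _POSSIBLE_KERNELS, and any kernel whose is_supported() check fails is skipped with a recorded reason, which is what lets CUDA fall back from the Cutlass kernel to the Triton one.

# Hypothetical, simplified sketch of priority-order kernel selection;
# the real logic lives in choose_scaled_mm_linear_kernel.
def pick_kernel(candidates, config, compute_capability):
    failure_reasons = []
    for kernel in candidates:
        # New in this commit: skip kernels that report no platform support.
        if hasattr(kernel, "is_supported"):
            supported, reason = kernel.is_supported(compute_capability)
            if not supported:
                failure_reasons.append(f" {kernel.__name__}: {reason}")
                continue
        # Existing check: can the kernel implement this particular config?
        ok, reason = kernel.can_implement(config)
        if not ok:
            failure_reasons.append(f" {kernel.__name__}: {reason}")
            continue
        return kernel
    raise ValueError("No compatible kernel:\n" + "\n".join(failure_reasons))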

vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py

Lines changed: 48 additions & 15 deletions
@@ -4,39 +4,72 @@

 import torch

+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa: E501
+    triton_scaled_mm,
+)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.platforms import current_platform

-from .cutlass import CutlassScaledMMLinearKernel
-from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig


-class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
     @classmethod
     def get_min_capability(cls) -> int:
         return 75

+    @classmethod
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if current_platform.is_cuda_alike():
+            return True, None
+        return False, "Requires ROCm or CUDA."
+
     @classmethod
     def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
-        if current_platform.is_cpu():
-            return (
-                False,
-                "TritonScaledMMLinearKernel requires Triton which is not "
-                + "currently supported on CPU.",
-            )
         if not c.input_symmetric:
-            return (
-                False,
-                "TritonScaledMMLinearKernel only supports symmetric " + "quantization.",
-            )
+            return False, "Only symmetric input is supported."
         return True, None

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        super().process_weights_after_loading(layer)
+        weight = getattr(layer, self.w_q_name)
+        replace_parameter(
+            layer,
+            self.w_q_name,
+            torch.nn.Parameter(weight.t().data, requires_grad=False),
+        )
+
+        # INPUT SCALE
+        if self.config.is_static_input_scheme:
+            input_scale = getattr(layer, self.i_s_name)
+            replace_parameter(
+                layer,
+                self.i_s_name,
+                torch.nn.Parameter(input_scale.max(), requires_grad=False),
+            )
+            setattr(layer, self.i_zp_name, None)
+        else:
+            setattr(layer, self.i_s_name, None)
+            setattr(layer, self.i_zp_name, None)
+
+        setattr(layer, self.azp_adj_name, None)

     def apply_weights(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        return super().apply_weights(layer, x, bias)
+        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
+
+        x_q, x_s, x_zp = ops.scaled_int8_quant(
+            x.contiguous(), i_s, i_zp, symmetric=True
+        )
+
+        assert x_zp is None, "Triton kernel only supports symmetric quantization"
+
+        return triton_scaled_mm(
+            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
+        )
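
For reference, the result the new apply_weights path produces can be sketched in plain PyTorch (an illustrative stand-in only, assuming per-tensor symmetric scales; it ignores the Triton kernel's tiling and the custom scaled_int8_quant op):

import torch


def reference_scaled_mm(
    x: torch.Tensor,          # activations, e.g. fp16/bf16, shape [M, K]
    w_q: torch.Tensor,        # int8 weight, already transposed to [K, N]
    w_s: torch.Tensor,        # per-tensor weight scale
    i_s: torch.Tensor,        # per-tensor (static) input scale
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # Symmetric per-tensor int8 quantization of the activations,
    # a stand-in for ops.scaled_int8_quant(x, i_s, symmetric=True).
    x_q = torch.clamp(torch.round(x / i_s), -128, 127).to(torch.int8)
    # Dequantized matmul: (x_q * i_s) @ (w_q * w_s), the same result
    # triton_scaled_mm(x_q, w_q, scale_a=i_s, scale_b=w_s) is expected to give.
    out = (x_q.float() * i_s) @ (w_q.float() * w_s)
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)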
