Commit 99018da

feat(rocm): enable TritonScaledMM fallback on ROCm; add CUDA fallback entry
Signed-off-by: Shivam <[email protected]>
1 parent 01653a9 commit 99018da
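
At a glance, this commit (a) registers TritonScaledMMLinearKernel as a fallback entry after CutlassScaledMMLinearKernel in the CUDA kernel list, (b) adds a duck-typed is_supported() hook consulted during kernel selection, and (c) turns the Triton kernel into a standalone ScaledMMLinearKernel with its own weight processing and int8 apply path instead of delegating to the CUTLASS kernel. A quick sanity check on a real ROCm or CUDA build might look like the sketch below; the one-argument form of choose_scaled_mm_linear_kernel is an assumption about the current signature, which is why the mocked mini test added further down also guards the call with a try/except.

# Sketch only: on a ROCm box without a usable AITER kernel the Triton kernel is
# expected to be selected; on CUDA, CUTLASS stays first with Triton as fallback.
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig,
    choose_scaled_mm_linear_kernel,
)

cfg = ScaledMMLinearLayerConfig(
    is_channelwise=False,
    is_static_input_scheme=True,
    input_symmetric=True,
)
kernel = choose_scaled_mm_linear_kernel(cfg, compute_capability=None)
print("Selected kernel:", kernel.__name__)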

File tree

3 files changed, +103 -17 lines
mini_tests/select_triton_rocm.py

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+import os, sys, types
+
+os.environ["VLLM_TARGET_DEVICE"] = "rocm"
+
+# Mock amdsmi to simulate ROCm
+amdsmi = types.ModuleType("amdsmi")
+amdsmi.amdsmi_init = lambda: None
+amdsmi.amdsmi_shut_down = lambda: None
+amdsmi.amdsmi_get_processor_handles = lambda: [1]
+amdsmi.AmdSmiException = Exception
+sys.modules["amdsmi"] = amdsmi
+sys.modules["vllm._rocm_C"] = types.ModuleType("_rocm_C")
+
+# Prevent CPU platform from conflicting with ROCm on macOS
+import vllm.platforms as platforms_module
+_orig_cpu = platforms_module.cpu_platform_plugin
+platforms_module.cpu_platform_plugin = lambda: None if os.environ.get("VLLM_TARGET_DEVICE") == "rocm" else _orig_cpu()
+platforms_module.builtin_platform_plugins["cpu"] = platforms_module.cpu_platform_plugin
+
+# Mock torch to look like ROCm
+import torch
+torch.version.hip = "5.7.0"
+torch.cuda.get_device_properties = lambda d=0: types.SimpleNamespace(gcnArchName="gfx900", major=9, minor=0)
+torch.cuda.get_device_capability = lambda d=0: (9, 0)
+
+# Stub custom ops
+_ops = types.ModuleType("_custom_ops")
+for op in ["cutlass_scaled_mm_supports_fp4", "cutlass_scaled_fp4_mm", "scaled_fp4_quant",
+           "scaled_fp8_quant", "apply_repetition_penalties", "merge_attn_states", "scaled_int8_quant"]:
+    setattr(_ops, op, lambda *a, **k: None)
+sys.modules["vllm._custom_ops"] = _ops
+
+from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
+    choose_scaled_mm_linear_kernel, ScaledMMLinearLayerConfig,
+)
+
+cfg = ScaledMMLinearLayerConfig(
+    is_channelwise=False,
+    is_static_input_scheme=True,
+    input_symmetric=True,
+)
+
+try:
+    kernel = choose_scaled_mm_linear_kernel(cfg, compute_capability=None)
+except TypeError:
+    from vllm.platforms import PlatformEnum
+    kernel, _ = choose_scaled_mm_linear_kernel(PlatformEnum.ROCM, cfg, compute_capability=None)
+
+print("Selected kernel:", kernel.__name__)
+assert "TritonScaledMMLinearKernel" in kernel.__name__
+print("OK: TritonScaledMMLinearKernel chosen on ROCm fallback.")

vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py

Lines changed: 8 additions & 1 deletion

@@ -28,7 +28,7 @@
 # in priority/performance order (when available)
 _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
     PlatformEnum.CPU: [CPUScaledMMLinearKernel],
-    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
+    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
     PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
     PlatformEnum.TPU: [XLAScaledMMLinearKernel],
 }
@@ -69,6 +69,13 @@ def choose_scaled_mm_linear_kernel(
             )
             continue
 
+        # Check if kernel is supported on this platform/capability
+        if hasattr(kernel, "is_supported"):
+            supported, reason = kernel.is_supported(compute_capability)
+            if not supported:
+                failure_reasons.append(f" {kernel.__name__}: {reason}")
+                continue
+
         # If the current platform uses compute_capability,
         # make sure the kernel supports the compute capability.
         if compute_capability is not None:
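
For orientation, the added block is a duck-typed gate: any kernel class that exposes is_supported() can rule itself out for the current platform before the existing compute-capability and can_implement() checks run. A self-contained sketch of that control flow follows; the stub classes and the pick() helper are hypothetical, only the hasattr/is_supported pattern comes from the diff.

from typing import Optional


class CpuOnlyKernel:
    """Hypothetical kernel that vetoes itself via is_supported()."""

    @classmethod
    def is_supported(cls, compute_capability: Optional[int] = None):
        return False, "Requires ROCm or CUDA."


class FallbackKernel:
    """Hypothetical kernel with no is_supported(); it is never gated."""


def pick(kernels, compute_capability=None):
    failure_reasons = []
    for kernel in kernels:
        # Same pattern as the added block: only gate kernels that opt in.
        if hasattr(kernel, "is_supported"):
            supported, reason = kernel.is_supported(compute_capability)
            if not supported:
                failure_reasons.append(f" {kernel.__name__}: {reason}")
                continue
        return kernel, failure_reasons
    return None, failure_reasons


kernel, reasons = pick([CpuOnlyKernel, FallbackKernel])
print(kernel.__name__)  # FallbackKernel
print(reasons)          # [' CpuOnlyKernel: Requires ROCm or CUDA.']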

vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py

Lines changed: 44 additions & 16 deletions

@@ -5,39 +5,67 @@
 
 import torch
 
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (
+    triton_scaled_mm,
+)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.platforms import current_platform
 
-from .cutlass import CutlassScaledMMLinearKernel
-from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
 
 
-class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
     @classmethod
     def get_min_capability(cls) -> int:
         return 75
 
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
-        if current_platform.is_cpu():
-            return (
-                False,
-                "TritonScaledMMLinearKernel requires Triton which is not "
-                + "currently supported on CPU.",
-            )
+    def is_supported(
+        cls, compute_capability: Optional[int] = None
+    ) -> tuple[bool, Optional[str]]:
+        if current_platform.is_rocm() or current_platform.is_cuda():
+            return True, None
+        return False, "Requires ROCm or CUDA."
+
+    @classmethod
+    def can_implement(
+        cls, c: ScaledMMLinearLayerConfig
+    ) -> tuple[bool, Optional[str]]:
         if not c.input_symmetric:
-            return (
-                False,
-                "TritonScaledMMLinearKernel only supports symmetric " + "quantization.",
-            )
+            return False, "Only symmetric input is supported."
         return True, None
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        super().process_weights_after_loading(layer)
+        # INPUT SCALE
+        if self.config.is_static_input_scheme:
+            input_scale = getattr(layer, self.i_s_name)
+            replace_parameter(
+                layer,
+                self.i_s_name,
+                torch.nn.Parameter(input_scale.max(), requires_grad=False),
+            )
+            setattr(layer, self.i_zp_name, None)
+        else:
+            setattr(layer, self.i_s_name, None)
+            setattr(layer, self.i_zp_name, None)
+
+        setattr(layer, self.azp_adj_name, None)
 
     def apply_weights(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return super().apply_weights(layer, x, bias)
+        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
+
+        x_q, x_s, x_zp = ops.scaled_int8_quant(
+            x.contiguous(), i_s, i_zp, symmetric=True
+        )
+
+        assert x_zp is None, "Triton kernel only supports symmetric quantization"
+
+        return triton_scaled_mm(
+            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
+        )
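
For reference, the rewritten apply_weights boils down to dynamic symmetric int8 quantization of the activations followed by a scaled int8 matmul. Below is a pure-PyTorch sketch of that arithmetic, runnable on CPU; the per-token activation scales and per-output-channel weight scales are layout assumptions for illustration, not read from the kernel code.

import torch

M, K, N = 4, 64, 32
x = torch.randn(M, K)                                      # fp activations
w_q = torch.randint(-127, 128, (K, N), dtype=torch.int8)   # int8 weights
w_s = torch.rand(N) * 0.02                                 # assumed per-column weight scales

# Dynamic symmetric per-token quantization (the role of ops.scaled_int8_quant).
x_s = x.abs().amax(dim=1, keepdim=True) / 127.0
x_q = (x / x_s).round().clamp(-127, 127).to(torch.int8)

# What the scaled matmul computes, up to rounding in the real kernel:
# dequantize both operands with their scales and multiply.
out = (x_q.float() * x_s) @ (w_q.float() * w_s)
print(out.shape)  # torch.Size([4, 32])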
