tests/kernels/quantization/test_scaled_mm_kernel_selection.py (69 additions, 0 deletions)
@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ScaledMM kernel selection logic

Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
"""

import pytest

from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig,
    choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
    TritonScaledMMLinearKernel,
)
from vllm.platforms import current_platform


@pytest.mark.skipif(not current_platform.is_rocm(), reason="ROCm-specific test")
def test_triton_kernel_selected_on_rocm():
    """Test that TritonScaledMMLinearKernel is selected on ROCm
    when Aiter is not available."""
    config = ScaledMMLinearLayerConfig(
        is_channelwise=False,
        is_static_input_scheme=True,
        input_symmetric=True,
    )

    kernel = choose_scaled_mm_linear_kernel(config, compute_capability=None)

    assert kernel == TritonScaledMMLinearKernel, (
        f"Expected TritonScaledMMLinearKernel on ROCm, got {kernel.__name__}"
    )


@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA-specific test")
def test_triton_kernel_available_on_cuda():
    """Test that TritonScaledMMLinearKernel can be selected on CUDA."""
    config = ScaledMMLinearLayerConfig(
        is_channelwise=False,
        is_static_input_scheme=True,
        input_symmetric=True,
    )

    # Triton should be supported on CUDA
    supported, reason = TritonScaledMMLinearKernel.is_supported()
    assert supported, (
        f"TritonScaledMMLinearKernel should be supported on CUDA: {reason}"
    )

    # It should be able to implement symmetric configs
    can_impl, reason = TritonScaledMMLinearKernel.can_implement(config)
    assert can_impl, (
        f"TritonScaledMMLinearKernel should implement symmetric config: {reason}"
    )


def test_triton_kernel_rejects_asymmetric():
    """Test that TritonScaledMMLinearKernel rejects asymmetric quantization."""
    config = ScaledMMLinearLayerConfig(
        is_channelwise=False,
        is_static_input_scheme=True,
        input_symmetric=False,  # Asymmetric
    )

    can_impl, reason = TritonScaledMMLinearKernel.can_implement(config)
    assert not can_impl, "TritonScaledMMLinearKernel should reject asymmetric config"
    assert "symmetric" in reason.lower(), f"Unexpected rejection reason: {reason}"
@@ -6,6 +6,8 @@

import torch

from vllm.platforms import current_platform


@dataclass
class ScaledMMLinearLayerConfig:
@@ -15,6 +17,22 @@ class ScaledMMLinearLayerConfig:


class ScaledMMLinearKernel(ABC):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if compute_capability is None:
            _cc = current_platform.get_device_capability()
            if _cc is not None:
                compute_capability = _cc.major * 10 + _cc.minor

        if compute_capability is not None:
            min_cap = cls.get_min_capability()
            if min_cap is not None and min_cap > compute_capability:
                return False, f"requires capability {min_cap}, got {compute_capability}"

        return True, None

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
@@ -35,6 +53,7 @@ def __init__(
        azp_adj_param_name: str,
    ) -> None:
        assert self.can_implement(c)[0]
        assert self.is_supported()[0]
        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
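The new base-class hook splits platform/capability gating (is_supported) from config gating (can_implement). A minimal hypothetical subclass, shown only to illustrate the contract (the class name and the capability value 80 are examples, not part of this change):

    class ExampleScaledMMLinearKernel(ScaledMMLinearKernel):
        @classmethod
        def get_min_capability(cls) -> int:
            # Hypothetical requirement: SM80 (Ampere) or newer.
            return 80

        @classmethod
        def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
            # Only config-level constraints belong here; platform and
            # compute-capability checks are handled by is_supported().
            if not c.input_symmetric:
                return False, "Only symmetric input is supported."
            return True, None

With that in place, ExampleScaledMMLinearKernel.is_supported() inherits the capability check above and fails on devices below capability 80, while can_implement() only inspects the ScaledMMLinearLayerConfig.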
@@ -27,7 +27,7 @@
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
    PlatformEnum.CPU: [CPUScaledMMLinearKernel],
-    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
+    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
    PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
    PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}
@@ -55,41 +55,25 @@ def choose_scaled_mm_linear_kernel(
        type[ScaledMMLinearKernel]: Chosen kernel.
    """

-    if compute_capability is None:
-        _cc = current_platform.get_device_capability()
-        if _cc is not None:
-            compute_capability = _cc[0] * 10 + _cc[1]

    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
        if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
-            failure_reasons.append(
-                f" {kernel.__name__} disabled by environment variable"
-            )
+            failure_reasons.append(f"{kernel.__name__}: disabled by env var")
            continue

-        # If the current platform uses compute_capability,
-        # make sure the kernel supports the compute cability.
-        if compute_capability is not None:
-            kernel_min_capability = kernel.get_min_capability()
-            if (
-                kernel_min_capability is not None
-                and kernel_min_capability > compute_capability
-            ):
-                failure_reasons.append(
-                    f"{kernel.__name__} requires capability "
-                    f"{kernel_min_capability}, current compute capability "
-                    f"is {compute_capability}"
-                )
-                continue
+        is_supported, reason = kernel.is_supported(compute_capability)
+        if not is_supported:
+            failure_reasons.append(f"{kernel.__name__}: {reason}")
+            continue

+        can_implement, reason = kernel.can_implement(config)
+        if not can_implement:
+            failure_reasons.append(f"{kernel.__name__}: {reason}")
+            continue

-        can_implement, failure_reason = kernel.can_implement(config)
-        if can_implement:
-            return kernel
-        else:
-            failure_reasons.append(
-                f" {kernel.__name__} cannot implement due to: {failure_reason}"
-            )
+        return kernel

    raise ValueError(
        "Failed to find a kernel that can implement the "
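Selection still walks the per-platform kernel list in priority order and returns the first kernel whose is_supported() and can_implement() both pass; everything else only contributes to failure_reasons. A rough sketch of how the new CUDA fallback could be exercised on a CUDA machine (the use of VLLM_DISABLED_KERNELS here is illustrative; only the variable name and the call signatures come from the code above):

    import os

    # Hypothetically disable the CUTLASS kernel so selection falls through to Triton.
    os.environ["VLLM_DISABLED_KERNELS"] = "CutlassScaledMMLinearKernel"

    from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
        ScaledMMLinearLayerConfig,
        choose_scaled_mm_linear_kernel,
    )

    config = ScaledMMLinearLayerConfig(
        is_channelwise=False, is_static_input_scheme=True, input_symmetric=True
    )
    kernel_cls = choose_scaled_mm_linear_kernel(config, compute_capability=None)
    print(kernel_cls.__name__)  # expected: TritonScaledMMLinearKernel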
@@ -57,6 +57,14 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_rocm():
            return False, "Requires ROCm."
        return super().is_supported(compute_capability)

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_rocm():
@@ -23,10 +23,15 @@ def get_min_capability(cls) -> int:
        return 75

    @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
        if not current_platform.is_cpu():
-            return False, "CPUScaledMM requires running on CPU."
+            return False, "Requires CPU."
+        return True, None

+    @classmethod
+    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -20,10 +20,15 @@ def get_min_capability(cls) -> int:
        return 75

    @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
        if not current_platform.is_cuda():
-            return False, "CutlassScaledMM requires running on CUDA."
+            return False, "Requires CUDA."
+        return super().is_supported(compute_capability)

+    @classmethod
+    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -4,39 +4,72 @@

import torch

+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa: E501
+    triton_scaled_mm,
+)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import current_platform

-from .cutlass import CutlassScaledMMLinearKernel
-from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig


-class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if current_platform.is_cuda_alike():
+            return True, None
+        return False, "Requires ROCm or CUDA."
+
+    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
-        if current_platform.is_cpu():
-            return (
-                False,
-                "TritonScaledMMLinearKernel requires Triton which is not "
-                + "currently supported on CPU.",
-            )
        if not c.input_symmetric:
-            return (
-                False,
-                "TritonScaledMMLinearKernel only supports symmetric " + "quantization.",
-            )
+            return False, "Only symmetric input is supported."
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        super().process_weights_after_loading(layer)
+        weight = getattr(layer, self.w_q_name)
+        replace_parameter(
+            layer,
+            self.w_q_name,
+            torch.nn.Parameter(weight.t().data, requires_grad=False),
+        )
+
+        # INPUT SCALE
+        if self.config.is_static_input_scheme:
+            input_scale = getattr(layer, self.i_s_name)
+            replace_parameter(
+                layer,
+                self.i_s_name,
+                torch.nn.Parameter(input_scale.max(), requires_grad=False),
+            )
+            setattr(layer, self.i_zp_name, None)
+        else:
+            setattr(layer, self.i_s_name, None)
+            setattr(layer, self.i_zp_name, None)
+
+        setattr(layer, self.azp_adj_name, None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
-        return super().apply_weights(layer, x, bias)
+        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
+
+        x_q, x_s, x_zp = ops.scaled_int8_quant(
+            x.contiguous(), i_s, i_zp, symmetric=True
+        )
+
+        assert x_zp is None, "Triton kernel only supports symmetric quantization"
+
+        return triton_scaled_mm(
+            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
+        )
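For reference, a standalone sketch of the data flow the new apply_weights implements, using dynamic activation quantization; the tensor shapes, device, and values are assumptions for illustration, while scaled_int8_quant and triton_scaled_mm are the same calls used above:

    import torch

    from vllm import _custom_ops as ops
    from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa: E501
        triton_scaled_mm,
    )

    # Hypothetical shapes: activations [M, K], int8 weight already transposed to [K, N].
    x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    w_q = torch.randint(-128, 127, (128, 256), dtype=torch.int8, device="cuda")
    w_s = torch.rand(256, 1, dtype=torch.float32, device="cuda")  # per-output-channel scales

    # Dynamic symmetric int8 quantization of the activations (no static scale or zero point).
    x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(), None, None, symmetric=True)
    assert x_zp is None

    out = triton_scaled_mm(x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=None)
    print(out.shape)  # torch.Size([4, 256])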
@@ -16,12 +16,17 @@


class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
+    @classmethod
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_tpu():
+            return False, "Requires TPU."
+        return True, None
+
    @classmethod
    def get_min_capability(cls) -> int:
-        raise NotImplementedError(
-            "TPU platform does have a concept of compute capability, "
-            "this method should not be called."
-        )
+        raise NotImplementedError("TPU does not have compute capability.")

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: