diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index 7448bb122152..34653e58616f 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -14,6 +14,9 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    check_aiter_fp8_linear_support,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
@@ -75,6 +78,85 @@ def register(self, pm_pass: PatternMatcherPass):
         raise NotImplementedError
 
 
+if check_aiter_fp8_linear_support():
+    import aiter as rocm_aiter
+    from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
+
+    from vllm.utils import direct_register_custom_op
+
+    rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
+    rocm_aiter_fp8_quant_group_size = 128
+
+    def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return act_mul_and_fp8_group_quant(
+            x,
+            activation="silu",
+            group_size=rocm_aiter_fp8_quant_group_size,
+            dtype_quant=rocm_aiter_fp8_dtype,
+        )
+
+    def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        M, N = x.shape
+        assert N % 2 == 0
+        N_half = N // 2
+        x_fp8 = torch.empty((M, N_half), dtype=rocm_aiter_fp8_dtype, device=x.device)
+        out_bs = torch.empty(
+            (
+                M,
+                (N_half + rocm_aiter_fp8_quant_group_size - 1)
+                // rocm_aiter_fp8_quant_group_size,
+            ),
+            dtype=torch.float32,
+            device=x.device,
+        )
+        return x_fp8, out_bs
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_act_mul_and_fp8_group_quant",
+        op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
+        mutates_args=[],
+        fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+    AITER_BLOCK_QUANT_OP = torch.ops.vllm.rocm_aiter_per1x128_quant.default
+    FUSED_SILU_MUL_QUANT_OP = (
+        torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
+    )
+
+    class AiterSiluMulFp8BlockQuantPattern(ActivationQuantPattern):
+        def __init__(self):
+            pass
+
+        def register(self, pm_pass: PatternMatcherPass):
+            def pattern(
+                input: torch.Tensor,
+                result_silu_mul: torch.Tensor,
+            ):
+                at1 = auto_functionalized(
+                    SILU_MUL_OP, result=result_silu_mul, input=input
+                )
+                at2 = AITER_BLOCK_QUANT_OP(x=at1[1])
+                return at2[0], at2[1]
+
+            def replacement(
+                input: torch.Tensor,
+                result_silu_mul: torch.Tensor,
+            ):
+                at = FUSED_SILU_MUL_QUANT_OP(x=input)
+                return at[0], at[1]
+
+            inputs = [
+                empty_bf16(5, 4),  # input
+                empty_bf16(5, 4),  # result_silu_mul
+            ]
+
+            register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
+
+
 class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
     """
     Fusion for SiluMul+Fp8StaticQuant Pattern
@@ -198,6 +280,10 @@ def __init__(self, config: VllmConfig):
         pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
         pattern_silu_mul_nvfp4.register(self.patterns)
 
+        if check_aiter_fp8_linear_support():
+            pattern_silu_mul_aiter_block_fp8 = AiterSiluMulFp8BlockQuantPattern()
+            pattern_silu_mul_aiter_block_fp8.register(self.patterns)
+
         self.dump_patterns(config, self.patterns)
 
     @VllmInductorPass.time_and_log
@@ -206,9 +292,11 @@ def __call__(self, graph: torch.fx.Graph):
         logger.debug("Replaced %s patterns", self.matched_count)
 
     def uuid(self):
-        return VllmInductorPass.hash_source(
-            self,
+        fusion_patterns = [
             ActivationQuantPattern,
             SiluMulFp8StaticQuantPattern,
             SiluMulNvfp4QuantPattern,
-        )
+        ]
+        if check_aiter_fp8_linear_support():
+            fusion_patterns.append(AiterSiluMulFp8BlockQuantPattern)
+        return VllmInductorPass.hash_source(self, *fusion_patterns)
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index df54e94a03db..5a15e528f535 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -11,6 +11,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import is_rocm_aiter_rmsnorm_enabled
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     QuantKey,
@@ -88,6 +89,22 @@ def __str__(self):
 }
 
 
+if is_rocm_aiter_rmsnorm_enabled():
+    AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
+    AITER_RMS_ADD_GROUP_QUANT_OP = (
+        torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
+    )
+
+    AITER_BLOCK_QUANT_OP = torch.ops.vllm.rocm_aiter_per1x128_quant.default
+    AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
+    AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
+
+    import aiter as rocm_aiter
+
+    rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
+    rocm_aiter_fp8_quant_group_size = 128
+
+
 class RMSNormQuantPattern:
     def __init__(self, epsilon: float, key: FusedRMSQuantKey):
         self.epsilon = epsilon
@@ -382,6 +399,90 @@ def replacement(
         )
 
 
+if is_rocm_aiter_rmsnorm_enabled():
+
+    class AiterRMSGroupQuantFP8Pattern:
+        def __init__(self, epsilon: float, quant_dtype: torch.dtype):
+            self.epsilon = epsilon
+            self.quant_dtype = quant_dtype
+
+        def register(self, pm_pass: PatternMatcherPass):
+            def pattern(
+                input: torch.Tensor,
+                weight: torch.Tensor,  # result_rms: torch.Tensor,
+            ):
+                at1 = AITER_RMS_OP(
+                    x=input, weight=weight, variance_epsilon=self.epsilon
+                )
+
+                at2 = AITER_BLOCK_QUANT_OP(x=at1[0])
+
+                return at2[0], at2[1]
+
+            def replacement(
+                input: torch.Tensor,
+                weight: torch.Tensor,
+            ):
+                at = AITER_RMS_GROUP_QUANT_OP(
+                    x=input, residual=None, weight=weight, variance_epsilon=self.epsilon
+                )
+
+                return at[0], at[1]
+
+            inputs = [
+                empty_bf16(5, 4),  # input
+                empty_bf16(1, 5),  # weight
+            ]
+
+            pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
+
+    class AiterFusedAddRMSGroupQuantPattern:
+        def __init__(self, epsilon: float, quant_dtype: torch.dtype):
+            self.epsilon = epsilon
+            self.quant_dtype = quant_dtype
+
+        def register(self, pm_pass: PatternMatcherPass):
+            def pattern(
+                input: torch.Tensor,
+                residual: torch.Tensor,
+                weight: torch.Tensor,
+            ):
+                at1 = AITER_RMS_ADD_OP(
+                    x=input,
+                    residual=residual,
+                    weight=weight,
+                    variance_epsilon=self.epsilon,
+                )
+
+                at2 = AITER_BLOCK_QUANT_OP(x=at1[0])
+
+                # result, scale, residual
+                return at2[0], at2[1], at1[1]
+
+            def replacement(
+                input: torch.Tensor,
+                residual: torch.Tensor,
+                weight: torch.Tensor,
+            ):
+                at = AITER_RMS_ADD_GROUP_QUANT_OP(
+                    x=input,
+                    residual=residual,
+                    weight=weight,
+                    variance_epsilon=self.epsilon,
+                )
+
+                # result, scale, residual
+                return at[0], at[1], at[2]
+
+            inputs = [
+                empty_bf16(5, 4),  # input
+                empty_bf16(5, 4),  # residual
+                empty_bf16(1, 5),  # weight
+            ]
+
+            pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
+
+
 class RMSNormQuantFusionPass(VllmPatternMatcherPass):
     """
     This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
@@ -413,6 +514,14 @@ def __init__(self, config: VllmConfig):
             self.patterns
         )
 
+        if is_rocm_aiter_rmsnorm_enabled():
+            # Fuse rms_norm + dynamic group fp8 quant
+            AiterRMSGroupQuantFP8Pattern(epsilon, FP8_DTYPE).register(self.patterns)
+
+            AiterFusedAddRMSGroupQuantPattern(epsilon, FP8_DTYPE).register(
+                self.patterns
+            )
+
         self.dump_patterns(config, self.patterns)
 
     @VllmInductorPass.time_and_log
@@ -421,11 +530,15 @@ def __call__(self, graph: fx.Graph):
         logger.debug("Replaced %s patterns", self.matched_count)
 
     def uuid(self) -> Any:
-        return self.hash_source(
-            self,
+        fusion_patterns = [
             RMSNormQuantPattern,
             RMSNormStaticQuantPattern,
             RMSNormDynamicQuantPattern,
             FusedAddRMSNormStaticQuantPattern,
             FusedAddRMSNormDynamicQuantPattern,
-        )
+        ]
+        if is_rocm_aiter_rmsnorm_enabled():
+            fusion_patterns.extend(
+                [AiterRMSGroupQuantFP8Pattern, AiterFusedAddRMSGroupQuantPattern]
+            )
+        return self.hash_source(self, *fusion_patterns)
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 135fbda2d540..cd5827570834 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -16,6 +16,14 @@ def is_rocm_aiter_rmsnorm_enabled() -> bool:
     return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER
 
 
+if current_platform.is_rocm() and is_rocm_aiter_rmsnorm_enabled():
+    import aiter as rocm_aiter
+    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
+
+    rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
+    rocm_aiter_fp8_quant_group_size = 128
+
+
 def rms_norm(
     x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
 ) -> torch.Tensor:
@@ -99,6 +107,46 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
     return output, residual_out
 
 
+def rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    variance_epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
+        x,
+        weight,
+        variance_epsilon,
+        None,
+        None,
+        None,
+        group_size=rocm_aiter_fp8_quant_group_size,
+        dtype_quant=rocm_aiter_fp8_dtype,
+        res1=residual,
+    )
+    return (x_quant, x_quant_scales, res)
+
+
+def rocm_aiter_rmsnorm_fp8_group_quant_impl(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    variance_epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
+        x,
+        weight,
+        variance_epsilon,
+        None,
+        None,
+        None,
+        group_size=rocm_aiter_fp8_quant_group_size,
+        dtype_quant=rocm_aiter_fp8_dtype,
+        res1=residual,
+    )
+    return (x_quant, x_quant_scales)
+
+
 def rocm_aiter_rms_norm_fake(
     x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
 ) -> torch.Tensor:
@@ -114,17 +162,72 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
     return torch.empty_like(x), torch.empty_like(residual)
 
 
+def rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    variance_epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    M, N = x.shape
+    scale_shape = (
+        M,
+        (N + rocm_aiter_fp8_quant_group_size - 1) // rocm_aiter_fp8_quant_group_size,
+    )
+    return (
+        torch.empty_like(x, dtype=rocm_aiter_fp8_dtype, device=x.device),
+        torch.empty(scale_shape, dtype=torch.float32, device=x.device),
+        torch.empty_like(residual, device=residual.device),
+    )
+
+
+def rocm_aiter_rmsnorm_fp8_group_quant_fake(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    variance_epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    M, N = x.shape
+    scale_shape = (
+        M,
+        (N + rocm_aiter_fp8_quant_group_size - 1) // rocm_aiter_fp8_quant_group_size,
+    )
+    return (
+        torch.empty_like(x, dtype=rocm_aiter_fp8_dtype, device=x.device),
+        torch.empty(scale_shape, dtype=torch.float32, device=x.device),
+    )
+
+
 if current_platform.is_rocm():
     direct_register_custom_op(
         op_name="rocm_aiter_rms_norm",
         op_func=rocm_aiter_rms_norm_impl,
+        mutates_args=[],
         fake_impl=rocm_aiter_rms_norm_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
 
     direct_register_custom_op(
         op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
         op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
+        mutates_args=[],
         fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rmsnorm_fp8_group_quant",
+        op_func=rocm_aiter_rmsnorm_fp8_group_quant_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_rmsnorm_fp8_group_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
+        op_func=rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 51af40a11914..6c8f2ba7c2a9 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -76,9 +76,12 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    import aiter as rocm_aiter
-
-    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    # MI300's fp8nuz should be enough to detect if we call ck vs triton
+    if current_platform.is_fp8_fnuz():
+        from aiter import gemm_a8w8_blockscale
+    else:
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
 
 
 def rocm_aiter_gemm_w8a8_blockscale_fake(
@@ -101,16 +104,33 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
         op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
         fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
     )
-    if (
-        envs.VLLM_ROCM_USE_AITER
-        and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()
-    ):
+    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR:
         import aiter as rocm_aiter
         from aiter import get_hip_quant
 
         aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
 
+        def aiter_per1x128_quant_impl(
+            x: torch.Tensor,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            return aiter_per1x128_quant(x, quant_dtype=rocm_aiter.dtypes.fp8)
+
+        def aiter_per1x128_quant_fake(
+            x: torch.Tensor,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            shape, device = x.shape, x.device
+            y = torch.empty(shape, dtype=rocm_aiter.dtypes.fp8, device=device)
+            scale = torch.empty(
+                (*shape[:-1], shape[-1] // 128), dtype=torch.float32, device=device
+            )
+            return y, scale
+
+        direct_register_custom_op(
+            op_name="rocm_aiter_per1x128_quant",
+            op_func=aiter_per1x128_quant_impl,
+            fake_impl=aiter_per1x128_quant_fake,
+        )
+
 
 # TODO we should be able to change the type of block_size to GroupShape
 # after we resolve GroupShape compilation issue
@@ -352,8 +372,8 @@ def _run_aiter(
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
         assert self.act_quant_group_shape == GroupShape(1, 128)
-        q_input, input_scale = aiter_per1x128_quant(
-            input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
+        q_input, input_scale = torch.ops.vllm.rocm_aiter_per1x128_quant(
+            input_2d.contiguous()
         )
         return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
             q_input,
@@ -945,7 +965,6 @@ def check_aiter_fp8_linear_support() -> bool:
         current_platform.is_rocm()
         and envs.VLLM_ROCM_USE_AITER
         and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()
     )