
Commit 298dc89

committed
Add RMS Norm Quant fp8 Helion Kernel
Signed-off-by: Yanan Cao <[email protected]>
1 parent be15f51 commit 298dc89

File tree: 1 file changed, 291 additions and 0 deletions
@@ -0,0 +1,291 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Helion custom op for RMSNorm with FP8 quantization.
"""

import helion
import helion.language as hl
import torch

from vllm.compilation.helion.benchmark import KernelBenchmark
from vllm.compilation.helion.custom_op import HelionCustomOp
from vllm.model_executor.custom_op import CustomOp

@torch.library.custom_op(
    "my_helion_lib::rms_norm_fp8",
    mutates_args=(),
    device_types="cuda",
)
@helion.kernel(
    config=helion.Config(
        block_sizes=[1],
        indexing=[
            "tensor_descriptor",
            "pointer",
            "pointer",
            "pointer",
            "pointer",
            "tensor_descriptor",
            "pointer",
            "pointer",
        ],
        load_eviction_policies=["", "first", "", "", "first", "last"],
        num_stages=7,
        num_warps=8,
        pid_type="flat",
        range_flattens=[None],
        range_multi_buffers=[None],
        range_num_stages=[0],
        range_unroll_factors=[0],
        range_warp_specializes=[],
        reduction_loops=[None],
    ),
    static_shapes=False,
)
def _rms_norm_fp8_helion_kernel(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> torch.Tensor:
    """
    Helion kernel for RMSNorm with FP8 quantization.

    Operation: quantize_fp8(RMSNorm(input, weight, epsilon))

    Algorithm (matching the CUDA reference exactly):
        1. variance = sum(x^2) / hidden_size (per token/row)
        2. norm_factor = rsqrt(variance + epsilon)
        3. normalized = (input * norm_factor).to(input.dtype) * weight
        4. quantized = normalized * (1 / scale)

    Args:
        input (Tensor): Input tensor with shape [batch, hidden_size]
        weight (Tensor): Weight tensor with shape [hidden_size]
        scale (Tensor): Scalar scale factor for FP8 quantization
        epsilon (float): Epsilon value for numerical stability

    Returns:
        Tensor: Output tensor with the same shape as input and dtype
        float8_e4m3fn
    """
    m, n = input.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert scale.numel() == 1, "Scale must be a scalar Tensor"

    out = torch.empty_like(input, dtype=torch.float8_e4m3fn)

    # Tile over the batch dimension only (following the Helion rms_norm example)
    for tile_m in hl.tile(m):
        scale_val = hl.load(scale, [0])
        inv_scale = 1.0 / scale_val

        input_row = input[tile_m, :].to(torch.float32)

        # variance = sum(x^2) / hidden_size, computed in fp32
        x_squared = input_row * input_row
        variance = torch.mean(x_squared, dim=-1)

        # normalization factor
        inv_rms = torch.rsqrt(variance + epsilon)

        # out_norm = ((scalar_t)(x * s_variance)) * src2.val[j];
        normalized = (input_row * inv_rms[:, None]).to(input.dtype)  # fp32 → bf16
        weighted = (normalized * weight[:]).to(torch.float32)  # bf16*bf16 → fp32

        # Quantize to FP8
        result_scaled = weighted * inv_scale
        out[tile_m, :] = result_scaled.to(out.dtype)

    return out
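
# For illustration only: a minimal eager-mode reference for the same four
# steps above, handy for sanity-checking the kernel's numerics. The helper
# name is hypothetical and not part of the op registration; it assumes a
# 2-D input.
def _rms_norm_fp8_reference(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> torch.Tensor:
    # Steps 1-2: per-row variance and normalization factor, in fp32
    x = input.to(torch.float32)
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    # Step 3: cast back to the input dtype before applying the weight
    weighted = ((x * inv_rms).to(input.dtype) * weight).to(torch.float32)
    # Step 4: static-scale quantization to FP8
    return (weighted / scale).to(torch.float8_e4m3fn)
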
@_rms_norm_fp8_helion_kernel.register_fake
def _rms_norm_fp8_helion_kernel_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> torch.Tensor:
    """
    Fake/meta implementation for the rms_norm_fp8 Helion kernel.
    Defines the input/output shape relationship without actual computation.

    Shape contract:
        - input: [..., hidden_size]
        - weight: [hidden_size]
        - scale: scalar (numel == 1)
        - epsilon: float
        - returns: [..., hidden_size] with dtype float8_e4m3fn
    """
    return torch.empty_like(input, dtype=torch.float8_e4m3fn)

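# What the fake impl buys us, sketched as a hypothetical tracing snippet
# (illustrative only): shape and dtype propagation works without launching
# the CUDA kernel.
#
#     with torch._subclasses.fake_tensor.FakeTensorMode():
#         x = torch.empty(16, 4096, dtype=torch.bfloat16, device="cuda")
#         w = torch.empty(4096, dtype=torch.bfloat16, device="cuda")
#         s = torch.empty(1, dtype=torch.float32, device="cuda")
#         y = torch.ops.my_helion_lib.rms_norm_fp8(x, w, s, 1e-5)
#         assert y.shape == x.shape and y.dtype == torch.float8_e4m3fn
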
# Now define the vLLM CustomOp wrapper
@CustomOp.register("rms_norm_fp8_helion")
class RMSNormFp8Helion(HelionCustomOp):
    """
    RMSNorm with FP8 quantization using Helion.

    This operation computes:
        quantize_fp8(RMSNorm(input, weight, epsilon))

    The operation combines:
    1. Compute the RMS (root mean square): rsqrt(mean(x^2) + epsilon)
    2. Normalize the input by the RMS
    3. Apply elementwise multiplication with weight
    4. Quantize the result to FP8 format

    Shapes:
        input: (num_tokens, hidden_size)
        weight: (hidden_size,)
        scale: (1,) - scalar scale factor for FP8 quantization
        output: (num_tokens, hidden_size) with dtype float8_e4m3fn
    """

    def forward_helion(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        scale: torch.Tensor,
        epsilon: float = 1e-5,
    ) -> torch.Tensor:
        """
        Helion kernel implementation.

        Args:
            input: Input tensor with shape (num_tokens, hidden_size)
            weight: Weight tensor with shape (hidden_size,)
            scale: Scale tensor (scalar) for FP8 quantization
            epsilon: Epsilon for numerical stability

        Returns:
            Output tensor with shape (num_tokens, hidden_size) and dtype
            float8_e4m3fn
        """
        return torch.ops.my_helion_lib.rms_norm_fp8(input, weight, scale, epsilon)

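# Hypothetical usage sketch for the wrapper above (shapes are examples only):
#
#     op = RMSNormFp8Helion()
#     x = torch.randn(256, 4096, dtype=torch.bfloat16, device="cuda")
#     w = torch.randn(4096, dtype=torch.bfloat16, device="cuda")
#     s = torch.tensor([0.5], dtype=torch.float32, device="cuda")
#     y = op.forward_helion(x, w, s, epsilon=1e-5)  # -> float8_e4m3fn
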
class RMSNormFp8Benchmark(KernelBenchmark):
    """
    Benchmark harness for the RMSNorm-FP8 kernel.

    This class provides test configurations and benchmark utilities
    for the RMSNormFp8Helion custom op.
    """

    benchmark_name = "rms_norm_fp8"

    def __init__(self):
        """Initialize the benchmark."""
        self.op = RMSNormFp8Helion()
        self.epsilon = 1e-5

    def get_quick_test_shapes(self) -> list[tuple[list[tuple], torch.dtype]]:
        """
        Get test configurations for quick smoke testing.

        Returns:
            List of (shapes, dtype) tuples.
            Input shapes are (num_tokens, hidden_size).
        """
        return [
            (
                [
                    (1, 4096),
                    (256, 4096),
                    (1024, 4096),
                    (1, 8192),
                    (256, 8192),
                    (1024, 8192),
                ],
                torch.bfloat16,
            ),
        ]

    def get_full_test_shapes(self) -> list[tuple[list[tuple], torch.dtype]]:
        """
        Get test configurations for comprehensive benchmarking.

        Returns:
            List of (shapes, dtype) tuples.
            Input shapes are (num_tokens, hidden_size).
        """
        num_tokens_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
        hidden_sizes = [512, 1024, 2048, 4096, 5504, 6912, 7168, 8192, 14336, 16384]

        shapes_bf16 = []
        shapes_fp16 = []

        for num_tokens in num_tokens_list:
            for hidden_size in hidden_sizes:
                shape = (num_tokens, hidden_size)
                shapes_bf16.append(shape)
                shapes_fp16.append(shape)

        return [
            (shapes_bf16, torch.bfloat16),
            (shapes_fp16, torch.float16),
        ]

    def create_inputs(
        self, dtype: torch.dtype, **shape_params
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Create input tensors for the rms_norm_fp8 kernel.

        Args:
            dtype: Data type for inputs
            **shape_params: Must contain 'shape' - a tuple specifying input shape

        Returns:
            Tuple of (input_tensor, weight, scale)
            - input_tensor has shape (num_tokens, hidden_size)
            - weight has shape (hidden_size,)
            - scale is a scalar tensor
        """
        shape = shape_params["shape"]
        hidden_size = shape[-1]

        input_tensor = torch.randn(*shape, dtype=dtype, device="cuda")
        weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
        return input_tensor, weight, scale

    def run_baseline(
        self, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        """
        Run the baseline reference kernel.

        This is the existing vLLM CUDA kernel that Helion is meant to
        replace or accelerate. Used for performance comparison in benchmarks.

        Args:
            input: Input tensor with shape (num_tokens, hidden_size)
            weight: Weight tensor with shape (hidden_size,)
            scale: Scale tensor (scalar)

        Returns:
            Output tensor from the baseline kernel
        """
        out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
        torch.ops._C.rms_norm_static_fp8_quant(out, input, weight, scale, self.epsilon)
        return out

    def run_helion(
        self, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        """
        Run the Helion kernel.

        Args:
            input: Input tensor with shape (num_tokens, hidden_size)
            weight: Weight tensor with shape (hidden_size,)
            scale: Scale tensor (scalar)

        Returns:
            Output tensor from the Helion kernel
        """
        return self.op.forward_helion(input, weight, scale, self.epsilon)
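A quick correctness check one might run with this harness (a sketch; it assumes a CUDA device and that the baseline op torch.ops._C.rms_norm_static_fp8_quant is available, and the tolerances are illustrative, chosen loosely for FP8):

    bench = RMSNormFp8Benchmark()
    x, w, s = bench.create_inputs(torch.bfloat16, shape=(256, 4096))
    ref = bench.run_baseline(x, w, s)
    out = bench.run_helion(x, w, s)
    # FP8 tensors are compared after upcasting to fp32
    torch.testing.assert_close(
        out.to(torch.float32), ref.to(torch.float32), atol=0.125, rtol=0.125
    )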
