
Commit 3f992aa

optimize quant fp8
1 parent adef903 commit 3f992aa

File tree

1 file changed (+39, -17 lines)

lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py

Lines changed: 39 additions & 17 deletions
@@ -18,6 +18,8 @@ def _quant_fp8_kernel(
     scale_ptr,
     M,
     M_out,
+    K: tl.constexpr,
+    num_groups_per_cta: tl.constexpr,
     fp8_min: tl.constexpr,
     fp8_max: tl.constexpr,
     stride_am,
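
The two new constexpr arguments let each program (CTA) cover several consecutive quantization groups along K instead of a single one. A rough host-side mirror of the new index math, using the hypothetical values GROUP_SIZE = 128 and num_groups_per_cta = 4, looks like this:

import torch

# Illustration only; the concrete values are assumptions, not taken from the commit.
GROUP_SIZE = 128
num_groups_per_cta = 4
GROUP_SIZE_CTA = GROUP_SIZE * num_groups_per_cta

pid0 = 2                                  # stands in for tl.program_id(0)
group_id = pid0 * num_groups_per_cta      # first group handled by this CTA
# element columns covered by the CTA, and the scale slots it writes
g_offs = group_id * GROUP_SIZE + torch.arange(GROUP_SIZE_CTA)
gs_offs = group_id + torch.arange(num_groups_per_cta)

print(g_offs.min().item(), g_offs.max().item())  # 1024 1535
print(gs_offs.tolist())                          # [8, 9, 10, 11]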
@@ -30,30 +32,43 @@ def _quant_fp8_kernel(
     NUM_STAGES: tl.constexpr,
 ):
     """Quant fp8 kernel."""
-    group_id = tl.program_id(0)
+    group_id = tl.program_id(0) * num_groups_per_cta
     m_id_start = tl.program_id(1)
     m_id_stride = tl.num_programs(1)
 
-    g_offs = group_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+    GROUP_SIZE_CTA: tl.constexpr = GROUP_SIZE * num_groups_per_cta
+    g_offs = group_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE_CTA)
     g_offs = tl.max_contiguous(tl.multiple_of(g_offs, GROUP_SIZE), GROUP_SIZE)
+    gs_offs = group_id + tl.arange(0, num_groups_per_cta)
     rfp8_max = 1 / fp8_max
 
     m_id = m_id_start
     a_ptrs = a_ptr + m_id * stride_am + g_offs * stride_ak
     o_ptrs = out_ptr + m_id * stride_om + g_offs * stride_ok
-    s_ptr = scale_ptr + m_id * stride_sm + group_id * stride_sg
+    s_ptr = scale_ptr + m_id * stride_sm + gs_offs * stride_sg
+    if K % GROUP_SIZE_CTA == 0:
+        mask_n = True
+        mask_s = True
+        mask_o = True
+    else:
+        mask_n = g_offs < K
+        mask_o = g_offs < K
+        mask_s = gs_offs < tl.cdiv(K, GROUP_SIZE)
 
     for m_id in tl.range(m_id_start, M_out, m_id_stride, num_stages=NUM_STAGES):
-
-        a = tl.load(a_ptrs, mask=m_id < M, other=0).to(tl.float32)
-        scale = tl.maximum(tl.max(tl.abs(a)), 1e-6) * rfp8_max
-        out = a / scale
+        a = tl.load(a_ptrs, mask=mask_n & (m_id < M), other=0)
+        a = a.reshape(num_groups_per_cta, GROUP_SIZE)
+        a_max = tl.max(tl.abs(a), axis=1)
+        a_max = tl.maximum(a_max, 1e-6).to(tl.float32)
+        scale = a_max * rfp8_max
+        rscale = fp8_max / a_max  # triton does not support rcp
+        out = a.to(tl.float32) * rscale[:, None]
 
         out = tl.clamp(out, fp8_min, fp8_max)
         out = out.to(out_ptr.dtype.element_ty)
-
-        tl.store(o_ptrs, out)
-        tl.store(s_ptr, scale)
+        out = out.reshape(GROUP_SIZE * num_groups_per_cta)
+        tl.store(o_ptrs, out, mask=mask_o)
+        tl.store(s_ptr, scale, mask=mask_s)
 
         a_ptrs += m_id_stride * stride_am
         o_ptrs += m_id_stride * stride_om
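
For context, the per-group math that the rewritten loop computes is equivalent to the following eager PyTorch reference (a minimal sketch, assuming a 2-D input with K divisible by group_size and a float8_e4m3fn output; the function name is illustrative, not part of the file):

import torch

def quant_fp8_reference(a: torch.Tensor, group_size: int,
                        dtype=torch.float8_e4m3fn):
    """Per-group symmetric FP8 quantization: one scale per (row, group)."""
    finfo = torch.finfo(dtype)
    M, K = a.shape
    g = a.view(M, K // group_size, group_size).float()
    a_max = g.abs().amax(dim=-1).clamp_min(1e-6)   # per-group amax
    scale = a_max / finfo.max                      # same as a_max * rfp8_max
    out = (g / scale[..., None]).clamp(finfo.min, finfo.max).to(dtype)
    return out.view(M, K), scale

The kernel multiplies by rscale = fp8_max / a_max instead of dividing by scale; that is the same operation written without a reciprocal, since Triton has no rcp primitive (as the in-line comment notes).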
@@ -63,30 +78,37 @@ def _quant_fp8_kernel(
 def _quant_fp8_launcher(A: Tensor, group_size: int, out: Tensor, scales: Tensor):
     """Quant online."""
     M, K = A.shape
-    num_groups = K // group_size
     M_out = out.size(0)
 
     dtype = out.dtype
     finfo = torch.finfo(dtype)
     fmin = finfo.min
     fmax = finfo.max
 
-    num_warps = 1
-
+    num_warps = 2
+    # every cp/ldg instruction can load 128 bit = 16 bytes of data,
+    # so each warp can read 512 bytes of data
+    elem_size = A.element_size()
+    num_groups_per_warp = 512 // (group_size * elem_size)
+    num_groups_per_cta = num_groups_per_warp * num_warps
+    grid_size0 = triton.cdiv(K, group_size * num_groups_per_cta)
     props = get_device_props(A.device.index)
     num_sm = props['multi_processor_count']
     warps_per_sm = props['warps_per_sm']
-    max_ctas = num_sm * warps_per_sm // num_warps
-    grid_size1 = min(M_out, max_ctas // num_groups)
+    blocks_per_sm = props['blocks_per_sm']
+    max_ctas = num_sm * min(blocks_per_sm, warps_per_sm // num_warps)
+    grid_size1 = min(M_out, max_ctas // grid_size0)
     assert grid_size1 < 65536
-    num_stages = min(5, max(1, triton.cdiv(M_out, grid_size1)))
-    grid = (num_groups, grid_size1)
+    num_stages = min(4, max(1, triton.cdiv(M_out, grid_size1)))
+    grid = (grid_size0, grid_size1)
     _quant_fp8_kernel[grid](
         A,
         out,
         scales,
         M,
         M_out,
+        K,
+        num_groups_per_cta=num_groups_per_cta,
         fp8_min=fmin,
         fp8_max=fmax,
         stride_am=A.stride(0),
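
The new launch heuristic sizes a CTA by memory transactions rather than by one group per program: each cp/ldg instruction moves 16 bytes per thread, so a 32-thread warp reads 512 bytes per instruction, and the number of groups packed into one CTA follows from the element size. A worked example with hypothetical inputs (group_size = 128, an fp16/bf16 tensor, K = 4096):

import triton  # only for triton.cdiv; math.ceil(K / ...) would give the same

group_size = 128
elem_size = 2                                            # bytes per fp16/bf16 element
num_warps = 2
num_groups_per_warp = 512 // (group_size * elem_size)    # 512 B / 256 B -> 2
num_groups_per_cta = num_groups_per_warp * num_warps     # 4 groups per CTA
K = 4096
grid_size0 = triton.cdiv(K, group_size * num_groups_per_cta)  # ceil(4096 / 512) = 8
print(num_groups_per_cta, grid_size0)                    # 4 8

grid_size1 is then capped by how many such CTAs the device can keep resident (max_ctas // grid_size0), and num_stages pipelines the rows each program iterates over.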
