@@ -18,6 +18,8 @@ def _quant_fp8_kernel(
1818 scale_ptr ,
1919 M ,
2020 M_out ,
21+ K : tl .constexpr ,
22+ num_groups_per_cta : tl .constexpr ,
2123 fp8_min : tl .constexpr ,
2224 fp8_max : tl .constexpr ,
2325 stride_am ,
@@ -30,30 +32,43 @@ def _quant_fp8_kernel(
3032 NUM_STAGES : tl .constexpr ,
3133):
3234 """Quant fp8 kernel."""
33- group_id = tl .program_id (0 )
35+ group_id = tl .program_id (0 ) * num_groups_per_cta
3436 m_id_start = tl .program_id (1 )
3537 m_id_stride = tl .num_programs (1 )
3638
37- g_offs = group_id * GROUP_SIZE + tl .arange (0 , GROUP_SIZE )
39+ GROUP_SIZE_CTA : tl .constexpr = GROUP_SIZE * num_groups_per_cta
40+ g_offs = group_id * GROUP_SIZE + tl .arange (0 , GROUP_SIZE_CTA )
3841 g_offs = tl .max_contiguous (tl .multiple_of (g_offs , GROUP_SIZE ), GROUP_SIZE )
42+ gs_offs = group_id + tl .arange (0 , num_groups_per_cta )
3943 rfp8_max = 1 / fp8_max
4044
4145 m_id = m_id_start
4246 a_ptrs = a_ptr + m_id * stride_am + g_offs * stride_ak
4347 o_ptrs = out_ptr + m_id * stride_om + g_offs * stride_ok
44- s_ptr = scale_ptr + m_id * stride_sm + group_id * stride_sg
48+ s_ptr = scale_ptr + m_id * stride_sm + gs_offs * stride_sg
49+ if K % GROUP_SIZE_CTA == 0 :
50+ mask_n = True
51+ mask_s = True
52+ mask_o = True
53+ else :
54+ mask_n = g_offs < K
55+ mask_o = g_offs < K
56+ mask_s = gs_offs < tl .cdiv (K , GROUP_SIZE )
4557
4658 for m_id in tl .range (m_id_start , M_out , m_id_stride , num_stages = NUM_STAGES ):
47-
48- a = tl .load (a_ptrs , mask = m_id < M , other = 0 ).to (tl .float32 )
49- scale = tl .maximum (tl .max (tl .abs (a )), 1e-6 ) * rfp8_max
50- out = a / scale
59+ a = tl .load (a_ptrs , mask = mask_n & (m_id < M ), other = 0 )
60+ a = a .reshape (num_groups_per_cta , GROUP_SIZE )
61+ a_max = tl .max (tl .abs (a ), axis = 1 )
62+ a_max = tl .maximum (a_max , 1e-6 ).to (tl .float32 )
63+ scale = a_max * rfp8_max
64+ rscale = fp8_max / a_max # triton does not support rcp
65+ out = a .to (tl .float32 ) * rscale [:, None ]
5166
5267 out = tl .clamp (out , fp8_min , fp8_max )
5368 out = out .to (out_ptr .dtype .element_ty )
54-
55- tl .store (o_ptrs , out )
56- tl .store (s_ptr , scale )
69+ out = out . reshape ( GROUP_SIZE * num_groups_per_cta )
70+ tl .store (o_ptrs , out , mask = mask_o )
71+ tl .store (s_ptr , scale , mask = mask_s )
5772
5873 a_ptrs += m_id_stride * stride_am
5974 o_ptrs += m_id_stride * stride_om
@@ -63,30 +78,37 @@ def _quant_fp8_kernel(
6378def _quant_fp8_launcher (A : Tensor , group_size : int , out : Tensor , scales : Tensor ):
6479 """Quant online."""
6580 M , K = A .shape
66- num_groups = K // group_size
6781 M_out = out .size (0 )
6882
6983 dtype = out .dtype
7084 finfo = torch .finfo (dtype )
7185 fmin = finfo .min
7286 fmax = finfo .max
7387
74- num_warps = 1
75-
88+ num_warps = 2
89+ # every cp/ldg instruction can load 128 bits (16 bytes) of data
90+ # each warp can read 512 bytes of data
91+ elem_size = A .element_size ()
92+ num_groups_per_warp = 512 // (group_size * elem_size )
93+ num_groups_per_cta = num_groups_per_warp * num_warps
94+ grid_size0 = triton .cdiv (K , group_size * num_groups_per_cta )
7695 props = get_device_props (A .device .index )
7796 num_sm = props ['multi_processor_count' ]
7897 warps_per_sm = props ['warps_per_sm' ]
79- max_ctas = num_sm * warps_per_sm // num_warps
80- grid_size1 = min (M_out , max_ctas // num_groups )
98+ blocks_per_sm = props ['blocks_per_sm' ]
99+ max_ctas = num_sm * min (blocks_per_sm , warps_per_sm // num_warps )
100+ grid_size1 = min (M_out , max_ctas // grid_size0 )
81101 assert grid_size1 < 65536
82- num_stages = min (5 , max (1 , triton .cdiv (M_out , grid_size1 )))
83- grid = (num_groups , grid_size1 )
102+ num_stages = min (4 , max (1 , triton .cdiv (M_out , grid_size1 )))
103+ grid = (grid_size0 , grid_size1 )
84104 _quant_fp8_kernel [grid ](
85105 A ,
86106 out ,
87107 scales ,
88108 M ,
89109 M_out ,
110+ K ,
111+ num_groups_per_cta = num_groups_per_cta ,
90112 fp8_min = fmin ,
91113 fp8_max = fmax ,
92114 stride_am = A .stride (0 ),
0 commit comments