@@ -17,11 +17,11 @@ def _apply_rotary_impl(x_l, x_h, cos_l, cos_h, sin_l, sin_h):
1717 # triton 3.4 would do fma 3 times to perform the above computation,
1818 # which causes higher numerical error. So we manually expand the
1919 # computation to avoid fma.
20- x_l_new = x_l * cos_l + 0
21- x_l_new -= x_h * sin_l + 0
22- x_h_new = x_h * cos_h + 0
23- x_h_new += x_l * sin_h + 0
24- return x_l_new, x_h_new
20+ x_l_new0 = x_l * cos_l + 0
21+ x_l_new1 = x_h * sin_l + 0
22+ x_h_new0 = x_h * cos_h + 0
23+ x_h_new1 = x_l * sin_h + 0
24+ return x_l_new0 - x_l_new1, x_h_new0 + x_h_new1
2525
2626
2727@triton.jit(do_not_specialize=('seq_len', ))
@@ -142,7 +142,6 @@ def apply_rotary_pos_emb(q: Tensor,
142142 k_embed = torch.empty_like(k)
143143
144144 seq_len = cos.numel() // cos.size(-1)
145- BLOCK = 16
146145
147146 if q.size(-1) == cos.size(-1):
148147 half_size = q.size(-1) // 2
@@ -156,9 +155,16 @@ def apply_rotary_pos_emb(q: Tensor,
156155 BLOCK_N = triton.next_power_of_2(half_size)
157156 num_heads_q = q.size(-2)
158157 num_heads_k = k.size(-2)
159- num_warps = 4
158+ num_warps = 2
160159 num_stages = 1
161160
161+ # compute best BLOCK size
162+ num_threads = num_warps * 32
163+ elem_size = q.dtype.itemsize
164+ elem_per_ldgv4 = 16 // elem_size
165+ BLOCK = num_threads * elem_per_ldgv4 // BLOCK_N
166+ BLOCK = max(1, BLOCK)
167+
162168 grid = (
163169 num_heads_q + num_heads_k,
164170 triton.cdiv(seq_len, BLOCK),
0 commit comments