
Commit 7e07947

update apply rotary num warps
1 parent e904207 commit 7e07947


lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py

Lines changed: 13 additions & 7 deletions
@@ -17,11 +17,11 @@ def _apply_rotary_impl(x_l, x_h, cos_l, cos_h, sin_l, sin_h):
     # triton 3.4 would do fma 3 times to perform the above computation,
     # which causes higher numerical error. So we manually expand the
     # computation to avoid fma.
-    x_l_new = x_l * cos_l + 0
-    x_l_new -= x_h * sin_l + 0
-    x_h_new = x_h * cos_h + 0
-    x_h_new += x_l * sin_h + 0
-    return x_l_new, x_h_new
+    x_l_new0 = x_l * cos_l + 0
+    x_l_new1 = x_h * sin_l + 0
+    x_h_new0 = x_h * cos_h + 0
+    x_h_new1 = x_l * sin_h + 0
+    return x_l_new0 - x_l_new1, x_h_new0 + x_h_new1


 @triton.jit(do_not_specialize=('seq_len', ))
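
Note on the hunk above: the old code accumulated into x_l_new / x_h_new in place, which Triton 3.4 contracts into a chain of FMAs (the in-kernel comment reports three of them) with a different rounding order; keeping each product in its own temporary and combining them only in the return expands the computation so no FMA is emitted. A minimal eager-mode PyTorch sketch of the same rotate-half math, useful as a reference when checking the kernel's low/high-half outputs (the helper name and the low/high split convention are assumptions for illustration, not part of the kernel):

    import torch

    def rotary_reference(x_l, x_h, cos_l, cos_h, sin_l, sin_h):
        # Materialize each product before combining, mirroring the expanded
        # (non-FMA) ordering used by _apply_rotary_impl after this commit.
        x_l_new0 = x_l * cos_l
        x_l_new1 = x_h * sin_l
        x_h_new0 = x_h * cos_h
        x_h_new1 = x_l * sin_h
        return x_l_new0 - x_l_new1, x_h_new0 + x_h_new1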
@@ -142,7 +142,6 @@ def apply_rotary_pos_emb(q: Tensor,
     k_embed = torch.empty_like(k)

     seq_len = cos.numel() // cos.size(-1)
-    BLOCK = 16

     if q.size(-1) == cos.size(-1):
         half_size = q.size(-1) // 2
@@ -156,9 +155,16 @@ def apply_rotary_pos_emb(q: Tensor,
     BLOCK_N = triton.next_power_of_2(half_size)
     num_heads_q = q.size(-2)
     num_heads_k = k.size(-2)
-    num_warps = 4
+    num_warps = 2
     num_stages = 1

+    # compute best BLOCK size
+    num_threads = num_warps * 32
+    elem_size = q.dtype.itemsize
+    elem_per_ldgv4 = 16 // elem_size
+    BLOCK = num_threads * elem_per_ldgv4 // BLOCK_N
+    BLOCK = max(1, BLOCK)
+
     grid = (
         num_heads_q + num_heads_k,
         triton.cdiv(seq_len, BLOCK),
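
On the new BLOCK heuristic: num_warps drops from 4 to 2, and instead of the previous fixed BLOCK = 16 the block size is now derived from how many elements the program's threads can cover with 16-byte (128-bit, ldg.v4-style) vectorized loads across one BLOCK_N-wide row. A quick back-of-the-envelope with assumed values (head_dim 128, so half_size = 64 and BLOCK_N = 64, fp16 inputs); the numbers are illustrative, the kernel computes them from q at runtime:

    import torch

    num_warps = 2
    num_threads = num_warps * 32           # 64 threads per program
    elem_size = torch.float16.itemsize     # 2 bytes per element
    elem_per_ldgv4 = 16 // elem_size       # 8 elements per 16-byte load
    BLOCK_N = 64                           # triton.next_power_of_2(half_size) for half_size = 64
    BLOCK = max(1, num_threads * elem_per_ldgv4 // BLOCK_N)
    print(BLOCK)                           # 8 sequence positions per program (was a fixed 16)

With fp32 inputs (itemsize 4) the same formula gives BLOCK = 4, so the grid's second dimension triton.cdiv(seq_len, BLOCK) grows accordingly.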
