
Commit e904207
fix kernel numerical error
1 parent e1d4a47 commit e904207

File tree

4 files changed: 39 additions & 22 deletions

lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py

Lines changed: 23 additions & 4 deletions

@@ -5,6 +5,25 @@
 from torch import Tensor
 
 
+@triton.jit
+def _apply_rotary_impl(x_l, x_h, cos_l, cos_h, sin_l, sin_h):
+    """Apply rotary positional embedding implementation."""
+    # x_l, x_h: [BLOCK, BLOCK_N]
+    # cos_l, cos_h, sin_l, sin_h: [BLOCK, BLOCK_N]
+
+    # qe_l = q_l * cos_l - q_h * sin_l
+    # qe_h = q_h * cos_h + q_l * sin_h
+
+    # triton 3.4 would do fma 3 times to perform the above computation,
+    # which causes higher numerical error. So we manually expand the
+    # computation to avoid fma.
+    x_l_new = x_l * cos_l + 0
+    x_l_new -= x_h * sin_l + 0
+    x_h_new = x_h * cos_h + 0
+    x_h_new += x_l * sin_h + 0
+    return x_l_new, x_h_new
+
+
 @triton.jit(do_not_specialize=('seq_len', ))
 def apply_rotary_pos_emb_qk_kernel(
     Q,
@@ -67,8 +86,8 @@ def apply_rotary_pos_emb_qk_kernel(
 
     q_l = tl.load(ql_ptrs)
     q_h = tl.load(qh_ptrs)
-    qe_l = q_l * cos_l - q_h * sin_l
-    qe_h = q_h * cos_h + q_l * sin_h
+
+    qe_l, qe_h = _apply_rotary_impl(q_l, q_h, cos_l, cos_h, sin_l, sin_h)
 
     tl.store(qel_ptrs, qe_l, mask=seq_mask)
     tl.store(qeh_ptrs, qe_h, mask=seq_mask)
@@ -86,8 +105,8 @@ def apply_rotary_pos_emb_qk_kernel(
     keh_ptrs += head_id * stride_keh
     k_l = tl.load(kl_ptrs)
     k_h = tl.load(kh_ptrs)
-    ke_l = k_l * cos_l - k_h * sin_l
-    ke_h = k_h * cos_h + k_l * sin_h
+
+    ke_l, ke_h = _apply_rotary_impl(k_l, k_h, cos_l, cos_h, sin_l, sin_h)
 
     tl.store(kel_ptrs, ke_l, mask=seq_mask)
     tl.store(keh_ptrs, ke_h, mask=seq_mask)
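For context, the math the kernel computes is the standard half-split rotary embedding. Below is a minimal eager PyTorch sketch (the rotary_ref helper and the low/high-half layout are illustrative assumptions, not part of this patch); unlike the kernel, this version leaves the compiler free to contract the multiply-add chains into fma, which is exactly what the "+ 0" expansion above prevents.

import torch

def rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Split the feature dim into low/high halves, mirroring the kernel's
    # x_l / x_h operands.
    half = x.shape[-1] // 2
    x_l, x_h = x[..., :half], x[..., half:]
    cos_l, cos_h = cos[..., :half], cos[..., half:]
    sin_l, sin_h = sin[..., :half], sin[..., half:]
    # Same formula as _apply_rotary_impl, without the anti-fma expansion.
    out_l = x_l * cos_l - x_h * sin_l
    out_h = x_h * cos_h + x_l * sin_h
    return torch.cat([out_l, out_h], dim=-1)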

lmdeploy/pytorch/kernels/cuda/rms_norm.py

Lines changed: 3 additions & 3 deletions

@@ -14,7 +14,7 @@ def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
 
     var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
     out = xf * tl.math.rsqrt(var + eps)
-    out = (w * out).to(x.dtype)
+    out = w * out.to(x.dtype)
     return out
 
 
@@ -27,7 +27,7 @@ def rms_norm_kernel(input, weight, output, seq_len, input_row_stride: tl.constex
     offsets = tl.arange(0, BLOCK_N)
     mask = offsets < N_COLS
 
-    w = tl.load(weight + offsets, mask=mask).to(tl.float32)
+    w = tl.load(weight + offsets, mask=mask)
 
     x_ptr = input + prog_id * input_row_stride + offsets
     out_ptr = output + prog_id * input_row_stride + offsets
@@ -50,7 +50,7 @@ def add_rms_norm_kernel(input, weight, residual, output, out_residual, seq_len,
     offsets = tl.arange(0, BLOCK_N)
     mask = offsets < N_COLS
 
-    w = tl.load(weight + offsets, mask=mask).to(tl.float32)
+    w = tl.load(weight + offsets, mask=mask)
 
     x_ptr = input + prog_id * input_row_stride + offsets
     res_ptr = residual + prog_id * residual_row_stride + offsets
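The effect of this change, sketched in eager PyTorch (rms_norm_ref is a hypothetical helper for illustration, assuming the float32 accumulation visible in _compute_rms_norm): the weight now stays in its storage dtype and is applied after the normalized activation is cast back, instead of both operands being promoted to float32.

import torch

def rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    xf = x.float()  # accumulate the variance in float32, as the kernel does
    var = (xf * xf).mean(dim=-1, keepdim=True)
    out = xf * torch.rsqrt(var + eps)
    # After this commit: cast back to the input dtype first, then multiply by
    # the weight in its native dtype (before: out = (w.float() * out).to(x.dtype)).
    return w * out.to(x.dtype)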

tests/pytorch/kernel/test_apply_rotary.py

Lines changed: 5 additions & 11 deletions

@@ -35,7 +35,7 @@ def num_heads_k(self, request):
 
     @pytest.fixture
     def feature_dim(self):
-        yield 16
+        yield 128
 
     @pytest.fixture
     def seq_length(self, batch_size):
@@ -47,23 +47,23 @@ def max_seqlen(self, seq_length):
 
     @pytest.fixture
     def q_states(self, seq_length, num_heads_q, feature_dim, dtype):
-        yield torch.rand(seq_length.sum(), num_heads_q, feature_dim, dtype=dtype, device='cuda')
+        yield torch.randn(seq_length.sum(), num_heads_q, feature_dim, dtype=dtype, device='cuda')
 
     @pytest.fixture
     def k_states(self, seq_length, num_heads_k, feature_dim, dtype):
-        yield torch.rand(seq_length.sum(), num_heads_k, feature_dim, dtype=dtype, device='cuda')
+        yield torch.randn(seq_length.sum(), num_heads_k, feature_dim, dtype=dtype, device='cuda')
 
     @pytest.fixture
     def position_ids_1d(self, seq_length, max_seqlen):
         yield torch.randint(0, max_seqlen.item(), (seq_length.sum().item(), ), device='cuda')
 
     @pytest.fixture
     def cached_cos(self, max_seqlen, feature_dim, dtype):
-        yield torch.rand(max_seqlen, feature_dim, dtype=dtype, device='cuda')
+        yield torch.randn(max_seqlen, feature_dim, dtype=dtype, device='cuda')
 
     @pytest.fixture
     def cached_sin(self, max_seqlen, feature_dim, dtype):
-        yield torch.rand(max_seqlen, feature_dim, dtype=dtype, device='cuda')
+        yield torch.randn(max_seqlen, feature_dim, dtype=dtype, device='cuda')
 
     @pytest.fixture
     def cos(self, cached_cos, position_ids_1d):
@@ -91,11 +91,5 @@ def test_apply_rotary(self, q_states, k_states, cos, sin, gt):
 
         rtol = None
         atol = None
-        if q_states.dtype == torch.float16:
-            rtol = 1e-5
-            atol = 1e-3
-        elif q_states.dtype == torch.bfloat16:
-            rtol = 1e-5
-            atol = 1e-2
         torch.testing.assert_close(q_embed, q_gt, rtol=rtol, atol=atol)
         torch.testing.assert_close(k_embed, k_gt, rtol=rtol, atol=atol)
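With rtol and atol left as None, torch.testing.assert_close falls back to its dtype-based defaults (in current PyTorch releases, rtol=1e-3/atol=1e-5 for float16 and rtol=1.6e-2/atol=1e-5 for bfloat16), so the test now holds the kernel to the stock tolerances rather than the loosened atol values removed above.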

tests/pytorch/kernel/test_rms_norm.py

Lines changed: 8 additions & 4 deletions

@@ -15,12 +15,16 @@ def dtype(self, request):
         yield request.param
 
     @pytest.fixture(scope='class')
-    def input(self, dtype):
-        yield torch.rand(4, 8, dtype=dtype, device='cuda')
+    def hidden_size(self):
+        yield 4096
 
     @pytest.fixture(scope='class')
-    def weight(self, dtype):
-        yield torch.rand(8, dtype=dtype, device='cuda')
+    def input(self, dtype, hidden_size):
+        yield torch.randn(4, hidden_size, dtype=dtype, device='cuda')
+
+    @pytest.fixture(scope='class')
+    def weight(self, dtype, hidden_size):
+        yield torch.randn(hidden_size, dtype=dtype, device='cuda')
 
     @pytest.fixture(scope='class')
     def eps(self):
