Commit 7aaf24a

[fix] UB overflow bugfix
Signed-off-by: Ascendyh <[email protected]>
1 parent 8d8b188 commit 7aaf24a

File tree: 1 file changed, +56 −26 lines

vllm_ascend/ops/fused_gdn_gating.py (56 additions, 26 deletions)
@@ -2,6 +2,7 @@
 
 import triton
 import triton.language as tl
+import triton.runtime.driver as driver
 
 # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
 @triton.jit
@@ -16,24 +17,32 @@ def fused_gdn_gating_kernel(
     beta: tl.constexpr,
     threshold: tl.constexpr,
     BLK_HEADS: tl.constexpr,
-    BLK_BATCHES: tl.constexpr
+    COL_ITER: tl.constexpr,
+    BLK_BATCHES: tl.constexpr,
+    ROW_ITER: tl.constexpr,
 ):
-    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
-    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
-    batch_off = i_b * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
-    off = batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :]
-    head_mask = head_off < NUM_HEADS
-    mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
-    blk_A_log = tl.load(A_log + head_off, mask=head_mask)
-    blk_a = tl.load(a + off, mask=mask)
-    blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
-    # If the model is loaded in fp16, without the .float() here, A might be -inf
-    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
-    softplus_x = tl.where(beta * x <= threshold,
-                          (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
-    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
-    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+    # New impl
+    i_b, i_s = tl.program_id(0), tl.program_id(1)
+    for row_idx in range(0, ROW_ITER):
+        batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
 
+        for col_idx in range(0, COL_ITER):
+            head_off = col_idx * BLK_HEADS + tl.arange(0, BLK_HEADS)
+
+            off = batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :]
+            head_mask = head_off < NUM_HEADS
+            mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
+            blk_A_log = tl.load(A_log + head_off, mask=head_mask)
+            blk_a = tl.load(a + off, mask=mask)
+            blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
+
+            x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
+            softplus_x = tl.where(beta * x <= threshold,
+                                  (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
+
+            blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
+
+            tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
 
 def fused_gdn_gating(
     A_log: torch.Tensor,
@@ -44,14 +53,34 @@ def fused_gdn_gating(
 ) -> torch.Tensor:
     batch, num_heads = a.shape
     seq_len = 1
-    NUM_BATCH_GROUPS = batch
-    BLK_BATCHES = 1
-    if batch > 40:
-        BLK_BATCHES = triton.next_power_of_2(triton.cdiv(batch, 32))
-        NUM_BATCH_GROUPS = triton.cdiv(batch, BLK_BATCHES)
-
-    grid = (NUM_BATCH_GROUPS, seq_len, triton.cdiv(num_heads, 8))
+
+    num_cores = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"]
+
+    # a_log_size = A_log.element_size() * A_log.nelement()
+    # a_size = a.element_size() * a.nelement()
+    # dt_bias_size = dt_bias.element_size() * dt_bias.nelement()
+
+    # 1. Row
+    BLK_HEADS = 8  # TODO
+    COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
+
+    # 2. Col
+    if batch <= num_cores:
+        progs = batch
+        BLK_BATCHES = 1
+        ROW_ITER = 1
+    else:
+        progs = num_cores
+
+        factor = 64  # Black box ub factor
+        row_per_core = triton.cdiv(batch, num_cores)
+        BLK_BATCHES = triton.next_power_of_2(triton.cdiv(1572864, factor * BLK_HEADS) // a.element_size()) // 2
+        ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
+
+
     g = torch.empty_like(a, dtype=torch.float32)
+
+    grid = (progs, seq_len)
     fused_gdn_gating_kernel[grid](g,
                                   A_log,
                                   a,
@@ -61,7 +90,8 @@ def fused_gdn_gating(
                                   batch,
                                   beta,
                                   threshold,
-                                  8,
+                                  BLK_HEADS=BLK_HEADS,
+                                  COL_ITER=COL_ITER,
                                   BLK_BATCHES=BLK_BATCHES,
-                                  num_warps=1)
-    return g
+                                  ROW_ITER=ROW_ITER,)
+    return g
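For reference, the comment retained at the top of the kernel states the eager-mode formula the fused kernel reproduces: g = -A_log.float().exp() * F.softplus(a.float() + dt_bias). A minimal PyTorch sketch of that formula, handy for spot-checking the retiled kernel's output, could look like the following; the function name and the beta/threshold defaults are illustrative assumptions, not part of this commit.

import torch
import torch.nn.functional as F

def gdn_gating_reference(A_log: torch.Tensor,
                         a: torch.Tensor,
                         dt_bias: torch.Tensor,
                         beta: float = 1.0,
                         threshold: float = 20.0) -> torch.Tensor:
    # Eager-mode equivalent of the formula quoted in the kernel source;
    # everything is computed in fp32, mirroring the .to(tl.float32) casts.
    x = a.float() + dt_bias.float()
    return -A_log.float().exp() * F.softplus(x, beta=beta, threshold=threshold)

Comparing this against the g returned by fused_gdn_gating for batch sizes on both sides of the batch <= num_cores branch would exercise both the one-program-per-row path and the new tiled path.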

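The host-side tiling is where the UB overflow is addressed: instead of letting BLK_BATCHES grow with the batch size, the launch is capped at num_vectorcore programs and each program walks ROW_ITER x COL_ITER tiles whose batch extent is bounded. A standalone sketch of that arithmetic, with num_cores passed in explicitly so it can be inspected without an Ascend device attached (the helper name is hypothetical; the constants come from the patch), might be:

import triton

def plan_gdn_gating_launch(batch: int, num_heads: int, elem_size: int, num_cores: int):
    # Mirrors the launch-parameter computation in fused_gdn_gating() above.
    BLK_HEADS = 8
    COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
    if batch <= num_cores:
        progs, BLK_BATCHES, ROW_ITER = batch, 1, 1
    else:
        progs = num_cores
        factor = 64  # the "black box ub factor" from the patch
        row_per_core = triton.cdiv(batch, num_cores)
        # 1572864 is the fixed per-core budget the patch splits across the tile,
        # so BLK_BATCHES no longer depends on the overall batch size.
        BLK_BATCHES = triton.next_power_of_2(
            triton.cdiv(1572864, factor * BLK_HEADS) // elem_size) // 2
        ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
    return progs, BLK_HEADS, COL_ITER, BLK_BATCHES, ROW_ITER

Under the old code, BLK_BATCHES = triton.next_power_of_2(triton.cdiv(batch, 32)) grew with the batch, so the per-program working set could exceed the unified buffer at large batch sizes; fixing the tile size and iterating instead is presumably what resolves the overflow.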