
Commit 8d8b188

Add fused gdn gating
Signed-off-by: Ascendyh <[email protected]>
1 parent 51c8f60

1 file changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
import torch

import triton
import triton.language as tl


# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
@triton.jit
def fused_gdn_gating_kernel(
    g,
    A_log,
    a,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    NUM_BATCHES: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
    BLK_BATCHES: tl.constexpr
):
    # One program per (batch block, sequence position, head block).
    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
    batch_off = i_b * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
    # Flat offsets into the (batch, seq_len, num_heads) layout of `a` and `g`.
    off = batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :]
    head_mask = head_off < NUM_HEADS
    mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
    blk_A_log = tl.load(A_log + head_off, mask=head_mask)
    blk_a = tl.load(a + off, mask=mask)
    blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
    # If the model is loaded in fp16, without the .float() here, A might be -inf.
    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
    # Thresholded softplus: fall back to the identity for large inputs,
    # matching F.softplus(x, beta, threshold).
    softplus_x = tl.where(beta * x <= threshold,
                          (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)


def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    batch, num_heads = a.shape
    seq_len = 1
    # Small batches get one program per row; larger batches are grouped into
    # power-of-two blocks to cap the grid size at roughly 32 batch groups.
    NUM_BATCH_GROUPS = batch
    BLK_BATCHES = 1
    if batch > 40:
        BLK_BATCHES = triton.next_power_of_2(triton.cdiv(batch, 32))
        NUM_BATCH_GROUPS = triton.cdiv(batch, BLK_BATCHES)

    grid = (NUM_BATCH_GROUPS, seq_len, triton.cdiv(num_heads, 8))
    # The gate is computed and stored in fp32 regardless of the input dtype.
    g = torch.empty_like(a, dtype=torch.float32)
    fused_gdn_gating_kernel[grid](g,
                                  A_log,
                                  a,
                                  dt_bias,
                                  seq_len,
                                  num_heads,
                                  batch,
                                  beta,
                                  threshold,
                                  8,  # BLK_HEADS, matching the grid's head-block size
                                  BLK_BATCHES=BLK_BATCHES,
                                  num_warps=1)
    return g
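
For reference, a minimal sketch of how the kernel can be checked against the eager-mode formula quoted in the comment at the top of the file. The `ref_gdn_gating` helper, shapes, and tolerances below are assumed test scaffolding, not part of this commit:

import torch
import torch.nn.functional as F

def ref_gdn_gating(A_log, a, dt_bias, beta=1.0, threshold=20.0):
    # Eager reference: g = -exp(A_log) * softplus(a + dt_bias), computed in fp32.
    return -A_log.float().exp() * F.softplus(a.float() + dt_bias.float(), beta, threshold)

a = torch.randn(64, 16, dtype=torch.float16, device="cuda")
A_log = torch.randn(16, device="cuda")
dt_bias = torch.randn(16, device="cuda")
torch.testing.assert_close(fused_gdn_gating(A_log, a, dt_bias),
                           ref_gdn_gating(A_log, a, dt_bias),
                           rtol=1e-4, atol=1e-4)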
