import torch

import triton
import triton.language as tl

# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
@triton.jit
def fused_gdn_gating_kernel(
    g,  # output gate, fp32, same layout as `a`
    A_log,  # per-head log decay, shape (num_heads,)
    a,  # input activations, shape (batch, seq_len, num_heads)
    dt_bias,  # per-head bias added before softplus, shape (num_heads,)
    seq_len,
    NUM_HEADS: tl.constexpr,
    NUM_BATCHES: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
    BLK_BATCHES: tl.constexpr,
):
    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
    batch_off = i_b * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
    off = batch_off[:, None] * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off[None, :]
    head_mask = head_off < NUM_HEADS
    mask = head_mask[None, :] & (batch_off[:, None] < NUM_BATCHES)
    blk_A_log = tl.load(A_log + head_off, mask=head_mask)
    blk_a = tl.load(a + off, mask=mask)
    blk_bias = tl.load(dt_bias + head_off, mask=head_mask)
    # If the model is loaded in fp16, without the .float() here, A might be -inf
    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)[None, :]
    # Thresholded softplus: fall back to the identity for large inputs to avoid exp overflow.
    softplus_x = tl.where(beta * x <= threshold,
                          (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)


def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    batch, num_heads = a.shape
    seq_len = 1
    # Each program handles BLK_BATCHES batches. Small batches get one batch per
    # program; larger ones are grouped into power-of-two blocks so the first
    # grid dimension stays capped around 32 groups.
    NUM_BATCH_GROUPS = batch
    BLK_BATCHES = 1
    if batch > 40:
        BLK_BATCHES = triton.next_power_of_2(triton.cdiv(batch, 32))
        NUM_BATCH_GROUPS = triton.cdiv(batch, BLK_BATCHES)

    grid = (NUM_BATCH_GROUPS, seq_len, triton.cdiv(num_heads, 8))
    g = torch.empty_like(a, dtype=torch.float32)
    fused_gdn_gating_kernel[grid](
        g,
        A_log,
        a,
        dt_bias,
        seq_len,
        num_heads,
        batch,
        beta,
        threshold,
        8,  # BLK_HEADS, matching the cdiv(num_heads, 8) grid dimension
        BLK_BATCHES=BLK_BATCHES,
        num_warps=1,
    )
    return g
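

# --- Sanity-check sketch (illustrative addition, not part of the original kernel) ---
# Compares the fused Triton kernel against the reference PyTorch expression quoted
# in the comment at the top of this file. Assumes a CUDA device is available; the
# batch/head sizes and dtypes below are arbitrary choices for the check.
if __name__ == "__main__":
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch, num_heads = 64, 16
    a = torch.randn(batch, num_heads, device="cuda", dtype=torch.float16)
    A_log = torch.randn(num_heads, device="cuda", dtype=torch.float32)
    dt_bias = torch.randn(num_heads, device="cuda", dtype=torch.float32)

    g_kernel = fused_gdn_gating(A_log, a, dt_bias)
    g_ref = -A_log.float().exp() * F.softplus(a.float() + dt_bias.float())
    torch.testing.assert_close(g_kernel, g_ref, rtol=1e-4, atol=1e-4)
    print("fused_gdn_gating matches the PyTorch reference")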