22
33import triton
44import triton .language as tl
5+ import triton .runtime .driver as driver
56
# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
@triton.jit
def fused_gdn_gating_kernel(
    # NOTE(review): the parameters above `beta` were elided by the diff hunk;
    # they are reconstructed here from the launch site in fused_gdn_gating()
    # below — confirm names and tl.constexpr qualifiers against the full file.
    g,
    A_log,
    a,
    dt_bias,
    seq_len: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    NUM_BATCHES,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
    COL_ITER: tl.constexpr,
    BLK_BATCHES: tl.constexpr,
    ROW_ITER: tl.constexpr,
):
    """Tiled computation of g = -exp(A_log) * softplus(a + dt_bias).

    Grid is (num_programs, seq_len). Each program walks ROW_ITER tiles of
    BLK_BATCHES batch rows and, per row tile, COL_ITER tiles of BLK_HEADS
    head columns, masking ragged edges against NUM_BATCHES / NUM_HEADS.
    """
    prog_b, prog_s = tl.program_id(0), tl.program_id(1)
    for row_iter in range(0, ROW_ITER):
        rows = (prog_b * ROW_ITER * BLK_BATCHES
                + row_iter * BLK_BATCHES
                + tl.arange(0, BLK_BATCHES))

        for col_iter in range(0, COL_ITER):
            cols = col_iter * BLK_HEADS + tl.arange(0, BLK_HEADS)

            # Flattened (batch, seq, head) offsets for this 2-D tile.
            offs = (rows[:, None] * seq_len * NUM_HEADS
                    + prog_s * NUM_HEADS
                    + cols[None, :])
            col_ok = cols < NUM_HEADS
            tile_ok = col_ok[None, :] & (rows[:, None] < NUM_BATCHES)

            tile_A_log = tl.load(A_log + cols, mask=col_ok)
            tile_a = tl.load(a + offs, mask=tile_ok)
            tile_bias = tl.load(dt_bias + cols, mask=col_ok)

            # Accumulate in fp32: with fp16 inputs, exp()/softplus below can
            # overflow to inf without the up-cast.
            x = tile_a.to(tl.float32) + tile_bias.to(tl.float32)[None, :]
            # Numerically stable softplus: linearize once beta*x > threshold.
            softplus_x = tl.where(
                beta * x <= threshold,
                (1 / beta) * tl.log(1 + tl.exp(beta * x)),
                x)

            tile_g = -tl.exp(tile_A_log.to(tl.float32)) * softplus_x

            tl.store(g + offs, tile_g.to(g.dtype.element_ty), mask=tile_ok)
3746
def fused_gdn_gating(
    A_log: torch.Tensor,
    # NOTE(review): the parameters between A_log and the return annotation were
    # elided by the diff hunk; they are reconstructed from the kernel launch
    # below — confirm names and default values against the full file.
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """Launch the fused gating kernel: g = -exp(A_log) * softplus(a + dt_bias).

    Args:
        A_log: per-head log decay, shape (num_heads,).
        a: gate input, shape (batch, num_heads); this path assumes seq_len == 1.
        dt_bias: per-head bias added to ``a``, shape (num_heads,).
        beta: softplus sharpness.
        threshold: softplus linearization threshold (numerical stability).

    Returns:
        fp32 tensor with the same shape as ``a``.
    """
    batch, num_heads = a.shape
    seq_len = 1  # decode-only path: one token per batch row

    # Vector-core count of the current NPU; at most one program per core.
    num_cores = driver.active.utils.get_device_properties(
        torch.npu.current_device())["num_vectorcore"]

    # Column (head) tiling: fixed tile of 8 heads per inner iteration.
    BLK_HEADS = 8  # TODO: tune per hardware
    COL_ITER = triton.cdiv(num_heads, BLK_HEADS)

    # Row (batch) tiling: if every row gets its own core, no row loop is
    # needed; otherwise size the row tile from the unified-buffer budget and
    # let each program iterate ROW_ITER times over its assigned rows.
    if batch <= num_cores:
        progs = batch
        BLK_BATCHES = 1
        ROW_ITER = 1
    else:
        progs = num_cores

        # Empirically tuned ("black box") unified-buffer budget, in bytes,
        # scaled by `factor` per BLK_HEADS column tile; the trailing //2
        # halves the power-of-two tile for headroom.
        ub_budget_bytes = 1572864
        factor = 64
        rows_per_core = triton.cdiv(batch, num_cores)
        BLK_BATCHES = triton.next_power_of_2(
            triton.cdiv(ub_budget_bytes, factor * BLK_HEADS)
            // a.element_size()) // 2
        ROW_ITER = triton.cdiv(rows_per_core, BLK_BATCHES)

    g = torch.empty_like(a, dtype=torch.float32)

    grid = (progs, seq_len)
    # NOTE(review): the positional args between `a` and `batch` were elided by
    # the diff hunk; dt_bias/seq_len/num_heads reconstructed to match the
    # kernel signature — confirm against the full file.
    fused_gdn_gating_kernel[grid](
        g,
        A_log,
        a,
        dt_bias,
        seq_len,
        num_heads,
        batch,
        beta,
        threshold,
        BLK_HEADS=BLK_HEADS,
        COL_ITER=COL_ITER,
        BLK_BATCHES=BLK_BATCHES,
        ROW_ITER=ROW_ITER,
    )
    return g
0 commit comments