@@ -68,6 +68,7 @@ def torch_chunk_gated_delta_rule(
6868 value ,
6969 g ,
7070 beta ,
71+ eye_constant ,
7172 chunk_size = 64 ,
7273 initial_state = None ,
7374 output_final_state = True ,
@@ -113,42 +114,44 @@ def torch_chunk_gated_delta_rule(
113114
114115 # chunk decay
115116 g = g .cumsum (dim = - 1 )
117+ g_exp = g .exp ()
116118 decay_mask = ((g .unsqueeze (- 1 ) -
117119 g .unsqueeze (- 2 )).tril ().exp ().float ()).tril ()
118120 attn = - ((torch .matmul (k_beta .contiguous (),
119121 key .transpose (- 1 , - 2 ).contiguous ())) *
120122 decay_mask ).masked_fill (mask , 0 )
121123 for i in range (1 , chunk_size ):
122- row = attn [..., i , :i ].clone ()
123- sub = attn [..., :i , :i ]. clone ()
124- attn [..., i , :i ] = row + (row .unsqueeze (- 1 ) * sub ).sum (- 2 )
125- attn = attn + torch . eye ( chunk_size , dtype = attn . dtype , device = attn . device )
124+ row = attn [..., i , :i ].contiguous ()
125+ sub = attn [..., :i , :]
126+ attn [..., i , :i ] = row + (row .unsqueeze (- 1 ) * sub ).sum (- 2 )[..., : i ]
127+ attn = attn + eye_constant
126128 value = attn @ v_beta
127- k_cumdecay = attn @ (k_beta * g . exp () .unsqueeze (- 1 ))
129+ k_cumdecay = attn @ (k_beta * g_exp .unsqueeze (- 1 ))
128130 last_recurrent_state = (torch .zeros (batch_size , num_heads , k_head_dim ,
129131 v_head_dim ).to (value ) if initial_state
130132 is None else initial_state .to (value ))
131133 core_attn_out = torch .zeros_like (value )
132- mask = torch .triu (torch .ones (chunk_size ,
134+ mask = torch .tril (torch .ones (chunk_size ,
133135 chunk_size ,
134136 dtype = torch .bool ,
135137 device = query .device ),
136- diagonal = 1 )
138+ diagonal = 0 )
139+ mask = mask .view (1 , 1 , 1 , chunk_size , chunk_size )
140+ attn = (query @ key .transpose (- 1 , - 2 )) * decay_mask * mask
141+ qg = query * g_exp [..., None ]
142+ delta_g_exp = (g [:, :, :, - 1 , None ] - g ).exp ()[..., None ]
143+ k_term = (key * delta_g_exp )
137144
138145 # for each chunk
139146 for i in range (0 , tot_len // chunk_size ):
140147 q_i , k_i , v_i = query [:, :, i ], key [:, :, i ], value [:, :, i ]
141- attn = (q_i @ k_i .transpose (- 1 , - 2 ) *
142- decay_mask [:, :, i ]).masked_fill_ (mask , 0 )
143148 v_prime = (k_cumdecay [:, :, i ]) @ last_recurrent_state
144149 v_new = v_i - v_prime
145- attn_inter = ( q_i * g [:, :, i , :, None ]. exp ()) @ last_recurrent_state
146- core_attn_out [:, :, i ] = attn_inter + attn @ v_new
150+ attn_inter = qg [:,:, i ] @ last_recurrent_state
151+ core_attn_out [:, :, i ] = attn_inter + attn [:,:, i ] @ v_new
147152 last_recurrent_state = (
148- last_recurrent_state * g [:, :, i , - 1 , None , None ].exp () +
149- (k_i *
150- (g [:, :, i , - 1 , None ] - g [:, :, i ]).exp ()[..., None ]).transpose (
151- - 1 , - 2 ) @ v_new )
153+ last_recurrent_state * g_exp [:, :, i , - 1 , None , None ] +
154+ k_term [:,:,i ].transpose (- 1 , - 2 ) @ v_new )
152155
153156 if not output_final_state :
154157 last_recurrent_state = None
@@ -456,6 +459,11 @@ def __init__(
456459 dtype = torch .float32 ,
457460 device = self .conv1d .weight .device )
458461
462+ self .chunk_size = 64
463+ self .eye_constant = torch .eye (self .chunk_size ,
464+ dtype = torch .float32 ,
465+ device = self .conv1d .weight .device )
466+
459467 # time step projection (discretization)
460468 # instantiate once and copy inv_dt in init_weights of PretrainedModel
461469 self .dt_bias = nn .Parameter (
@@ -660,6 +668,8 @@ def forward(
660668 value_non_spec ,
661669 g = g ,
662670 beta = beta ,
671+ eye_constant = self .eye_constant ,
672+ chunk_size = self .chunk_size ,
663673 initial_state = None ,
664674 output_final_state = True ,
665675 use_qk_l2norm_in_kernel = True ,