@@ -54,21 +54,19 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
5454 q , k , v = map (lambda x : rearrange (x , "... (h d) -> ... h d" , d = self .head_dim ), [q , k , v ])
5555 if mode == "fused_chunk" :
5656 q , k = self .feature_map (q ), self .feature_map (k )
57- o , _ = fused_chunk_linear_attn (q , k , v , normalize = True , scale = 1 , head_first = False )
57+ o , _ = fused_chunk_linear_attn (q , k , v , normalize = True , scale = 1 )
5858 elif mode == 'chunk' :
5959 q , k = self .feature_map (q ), self .feature_map (k )
60- o , _ = chunk_linear_attn (q , k , v , normalize = True , scale = 1 , head_first = False )
60+ o , _ = chunk_linear_attn (q , k , v , normalize = True , scale = 1 )
6161 elif mode == 'parallel' :
6262 assert q .shape [- 1 ] <= 128
63- o = parallel_based (q , k , v , scale = 1 , use_norm = True , head_first = False )
63+ o = parallel_based (q , k , v , scale = 1 , use_norm = True )
6464 o = rearrange (o , 'b t h d -> b t (h d)' )
6565 o = self .o_proj (o )
6666 o = self .dropout (o )
6767 return o
6868
69- # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
70-
71- def forward_reference (self , hidden_states : torch .Tensor , filters : torch .Tensor = None , * args , ** kwargs ):
69+ def forward_reference (self , hidden_states : torch .Tensor , ** kwargs ):
7270 """
7371 x (torch.Tensor): tensor of shape (b, d, t)
7472 y (torch.Tensor): tensor of shape (b, d, t)
0 commit comments