
Commit 6bfd5e6

[Token Mixing] Remove the head_first arg from token mixing layers (#347)
1 parent ea3972e commit 6bfd5e6

16 files changed: 4 additions & 36 deletions
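Every removed line below is a `head_first=False` keyword: with the argument gone, the kernels take tensors in [B, T, H, D] (batch, sequence, heads, head dim) order unconditionally, and call sites simply drop the flag. A minimal before/after sketch, assuming `fused_chunk_linear_attn` keeps the import path and the signature visible at the fla/layers/based.py call site:

import torch
from fla.ops.linear_attn import fused_chunk_linear_attn  # import path assumed

B, T, H, D = 2, 1024, 4, 64
q, k, v = (torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16) for _ in range(3))

# Before: fused_chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
# After: the [B, T, H, D] layout is assumed, so the flag is gone.
o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)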

fla/layers/abc.py

Lines changed: 0 additions & 1 deletion

@@ -194,7 +194,6 @@ def forward(
             s=s,
             initial_state=recurrent_state,
             output_final_state=use_cache,
-            head_first=False
         )
         if past_key_values is not None:
             past_key_values.update(

fla/layers/based.py

Lines changed: 4 additions & 6 deletions

@@ -54,21 +54,19 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
         if mode == "fused_chunk":
             q, k = self.feature_map(q), self.feature_map(k)
-            o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
+            o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
         elif mode == 'chunk':
             q, k = self.feature_map(q), self.feature_map(k)
-            o, _ = chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
+            o, _ = chunk_linear_attn(q, k, v, normalize=True, scale=1)
         elif mode == 'parallel':
             assert q.shape[-1] <= 128
-            o = parallel_based(q, k, v, scale=1, use_norm=True, head_first=False)
+            o = parallel_based(q, k, v, scale=1, use_norm=True)
         o = rearrange(o, 'b t h d -> b t (h d)')
         o = self.o_proj(o)
         o = self.dropout(o)
         return o

-    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
-
-    def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
+    def forward_reference(self, hidden_states: torch.Tensor, **kwargs):
         """
         x (torch.Tensor): tensor of shape (b, d, t)
         y (torch.Tensor): tensor of shape (b, d, t)
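For completeness, a hedged usage sketch of the updated layer; the class name `BasedLinearAttention` and its constructor arguments are assumptions, not shown in this diff:

import torch
from fla.layers.based import BasedLinearAttention  # class name assumed

# hidden_size must factor into the num_heads * head_dim the layer expects
layer = BasedLinearAttention(hidden_size=1024, mode='fused_chunk').cuda().to(torch.bfloat16)
x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.bfloat16)
y = layer(x)  # [batch, seq, hidden]; no head_first plumbing anywhere
assert y.shape == x.shape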

fla/layers/delta_net.py

Lines changed: 0 additions & 2 deletions

@@ -249,7 +249,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_first=False,
                 use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
             )
         elif mode == 'chunk':
@@ -261,7 +260,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_first=False,
                 use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
             )
         else:
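The same drop applies to the DeltaNet kernels. A minimal sketch, assuming the call sites above wrap fla.ops.delta_rule.chunk_delta_rule and that the remaining keywords match the diff:

import torch
from fla.ops.delta_rule import chunk_delta_rule  # import path assumed

B, T, H, D = 2, 1024, 4, 64
q, k, v = (torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16) for _ in range(3))
beta = torch.rand(B, T, H, device='cuda', dtype=torch.bfloat16)  # per-token update strength

o, recurrent_state = chunk_delta_rule(
    q, k, v, beta,
    initial_state=None,
    output_final_state=True,
    cu_seqlens=None,                # set for varlen (packed) batches
    use_qk_l2norm_in_kernel=True,   # fused q/k L2 norm, as in the 'l2' branch above
)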

fla/layers/gated_deltanet.py

Lines changed: 0 additions & 2 deletions

@@ -256,7 +256,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_first=False,
                 use_qk_l2norm_in_kernel=True
             )
         elif mode == 'fused_recurrent':
@@ -269,7 +268,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_first=False,
                 use_qk_l2norm_in_kernel=True
             )
         else:

fla/layers/gated_deltaproduct.py

Lines changed: 0 additions & 2 deletions

@@ -307,7 +307,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=offsets,
-                head_first=False,
                 use_qk_l2norm_in_kernel=True
             )
         else:
@@ -319,7 +318,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=offsets,
-                head_first=False,
                 use_qk_l2norm_in_kernel=True
             )
         else:

fla/layers/gla.py

Lines changed: 0 additions & 3 deletions

@@ -235,7 +235,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_first=False
             )
         elif mode == 'fused_chunk':
             o, recurrent_state = fused_chunk_gla(
@@ -245,7 +244,6 @@ def forward(
                 g=gk,
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
-                head_first=False
             )
         elif mode == 'chunk':
             o, recurrent_state = chunk_gla(
@@ -256,7 +254,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_first=False
             )
         else:
             raise NotImplementedError(f"Not supported mode `{mode}`.")
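And likewise for GLA. A sketch of the chunk call with the log-space forget gate g, assuming chunk_gla is exported from fla.ops.gla with the keywords shown above:

import torch
import torch.nn.functional as F
from fla.ops.gla import chunk_gla  # import path assumed

B, T, H, D = 2, 1024, 4, 64
q, k, v = (torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16) for _ in range(3))
# log-space forget gates over the key dimension (values <= 0)
gk = F.logsigmoid(torch.randn(B, T, H, D, device='cuda')).to(torch.bfloat16)

o, recurrent_state = chunk_gla(
    q, k, v,
    g=gk,
    initial_state=None,
    output_final_state=False,
    cu_seqlens=None,
)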

fla/layers/gsa.py

Lines changed: 0 and 2 deletions

@@ -190,7 +190,6 @@ def forward(
                 output_final_state=use_cache,
                 scale=self.scale,
                 cu_seqlens=cu_seqlens,
-                head_first=False
             )
         elif mode == 'chunk':
             o, recurrent_state = chunk_gsa(
@@ -203,7 +202,6 @@ def forward(
                 output_final_state=use_cache,
                 scale=self.scale,
                 cu_seqlens=cu_seqlens,
-                head_first=False
             )
         else:
             raise NotImplementedError(f"Not supported mode `{mode}`.")

fla/layers/hgrn2.py

Lines changed: 0 additions & 3 deletions

@@ -162,7 +162,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_first=False
             )
         elif mode == 'fused_chunk':
             o, recurrent_state = fused_chunk_gla(
@@ -172,7 +171,6 @@ def forward(
                 g=g,
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
-                head_first=False
             )
         elif mode == 'chunk':
             o, recurrent_state = chunk_gla(
@@ -183,7 +181,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_first=False
             )
         else:
             raise NotImplementedError(f"Not supported mode `{mode}`.")

fla/layers/lightnet.py

Lines changed: 0 additions & 2 deletions

@@ -168,7 +168,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_first=False
             )
         elif mode == 'chunk':
             o, recurrent_state = chunk_gla(
@@ -179,7 +178,6 @@ def forward(
                 initial_state=recurrent_state,
                 output_final_state=use_cache,
                 cu_seqlens=cu_seqlens,
-                head_first=False
             )
         else:
             raise NotImplementedError(f"Not supported mode `{mode}`.")

fla/layers/linear_attn.py

Lines changed: 0 additions & 1 deletion

@@ -142,7 +142,6 @@ def forward(
                 k=k,
                 v=v,
                 normalize=self.do_feature_map_norm,
-                head_first=False
             )
         elif mode == 'fused_chunk':
             o, final_state = fused_chunk_linear_attn(