[Sharktank] sliding_window out of ops.attention #2293
base: main
Conversation
mask = torch.triu(mask, diagonal=1)
return mask.to(device)

is_prefill = kv_size == n_tokens
Generating the mask should barely matter for prefill vs decode. The only thing that makes a difference is the offset.
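A minimal sketch of that point, with a hypothetical helper name and signature: the same mask builder covers both prefill and decode if the only varying input is the query offset into the KV cache.

import torch

def causal_mask(n_tokens: int, kv_size: int, offset: int,
                dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Prefill would use offset=0; decode would use offset = kv_size - n_tokens.
    q_pos = torch.arange(n_tokens, device=device) + offset   # absolute query positions
    k_pos = torch.arange(kv_size, device=device)             # absolute key positions
    allowed = k_pos.unsqueeze(0) <= q_pos.unsqueeze(1)       # [n_tokens, kv_size]
    mask = torch.zeros(n_tokens, kv_size, dtype=dtype, device=device)
    return mask.masked_fill(~allowed, float("-inf"))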
def build_mask(
    self,
    mask: Optional[torch.Tensor],
Your behavior for this mask is weird: sometimes you return it, other times you ignore it. Rethink which arguments are actually needed.
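A hypothetical sketch of one consistent behavior, assuming additive (-inf style) masks and reusing the causal_mask helper sketched above: always fold the caller's mask into the one being built, rather than sometimes returning it and sometimes dropping it.

from typing import Optional
import torch

def build_mask(self, mask: Optional[torch.Tensor], *, n_tokens: int, kv_size: int,
               offset: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    built = causal_mask(n_tokens, kv_size, offset, dtype, device)
    if mask is not None:
        # Additive masks compose; the caller's mask is never silently ignored.
        built = built + mask.to(dtype=dtype, device=device)
    return built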
    sink: torch.Tensor, bs: int, n_heads: int, n_tokens: int
) -> torch.Tensor:
    """Prepare sink tensor for attention: [sink_size, n_heads] -> [bs, n_heads, n_tokens, sink_size]"""
    if sink.dim() == 1:
Just make sure the input is dim == 2, and if it's the same per head, materialize a unary dimension.
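A short sketch of that contract (the helper and the call site are hypothetical): accept only a 2-D sink, and let the caller materialize the unary head dimension when one value is shared across heads.

import torch

def check_sink(sink: torch.Tensor) -> torch.Tensor:
    assert sink.dim() == 2, f"expected [sink_size, n_heads], got {tuple(sink.shape)}"
    return sink

shared = torch.zeros(4)              # one value per sink slot, shared across heads
check_sink(shared.unsqueeze(-1))     # caller materializes the unary head dim: [4, 1]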
    n_tokens: int,
    dtype: torch.dtype,
    device: torch.device,
):
This appears to just be a copy of the component above; replicated behavior doesn't really help for testing.
ab56847 to a7d6d53 (Compare)
Adding moe for gpt-oss. Also cleaned moe implementations and added numeric testing for pregather.
237e2ce to ff80fcb (Compare)
ff80fcb to 51b12ee (Compare)
    v_planes, quantizer=cache_quantizer, dtype=self.attn_dtype
)

effective_mask = self.build_mask(
There is already mask construction here, in a parent function down the call stack. I am curious what the reason is for splitting it like that. Considering that the other method is called paged_attention, it probably does make sense to move the mask construction out of there and put it all here, since the construction is not related to the paging details. That move can probably happen in another PR.
start_positions does not get propagated in the call chain. Maybe they should stay separate.
# Sink weight
sink_expanded = sink.reshape(1, 1, 1, 1).expand(1, 1, 2, 1)
attn_with_sink = torch.cat([attn_weights, sink_expanded], dim=-1)
sink_weights_full = torch.softmax(attn_with_sink, dim=-1)
Can we replace all the torch ops used in this test with sharktank ops? E.g. torch.softmax -> ops.softmax.
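A sketch of the requested substitution in the test, assuming sharktank's ops module exposes cat and softmax with torch-like signatures; the tensor names reuse the snippet above.

from sharktank import ops

attn_with_sink = ops.cat([attn_weights, sink_expanded], dim=-1)
sink_weights_full = ops.softmax(attn_with_sink, dim=-1)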
sliding_window_tensor = torch.full_like(global_q_pos, sliding_window - 1)
first_allowed_k_pos = (global_q_pos - sliding_window_tensor).clamp_min(0)
too_old = kv_positions.unsqueeze(0) < first_allowed_k_pos.unsqueeze(1)
invalid = future | too_old
Can we rename future, too_old to future_ctx, initial_ctx, previous_ctx accordingly?
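For illustration only, one possible reading of that rename (the exact mapping onto the three suggested names is a guess), applied to the last two lines of the snippet above:

previous_ctx = kv_positions.unsqueeze(0) < first_allowed_k_pos.unsqueeze(1)  # keys that fell out of the window
invalid = future_ctx | previous_ctx  # future_ctx: the renamed `future` computed earlier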
    fake_quant: Optional[bool],
    softcap: Optional[float] = None,
    scale: Optional[torch.Tensor | ReplicatedTensor] = None,
    mask: Optional[torch.Tensor | ReplicatedTensor] = None,
Shouldn't all these types be included in the mask type hint in build_mask?
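A sketch of the widened hint, assuming ReplicatedTensor is importable from sharktank.types; only the mask parameter is shown and the rest of the signature is elided.

from typing import Optional
import torch
from sharktank.types import ReplicatedTensor

def build_mask(
    self,
    mask: Optional[torch.Tensor | ReplicatedTensor],  # matches the attention signature above
    # ... remaining parameters unchanged ...
):
    ...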
if sink is not None:
    sink = sink.to(q.dtype)
    sink = sink.reshape(1, -1, 1, 1).expand(bs, -1, n_tokens, 1)
    sink = sink.to(q.dtype).to(q.device)
Shouldn't this be taken care of before sdpa, where the sink is generated?
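A sketch of that move, reusing the names from the snippet above: cast and expand the sink where it is generated, so sdpa receives it ready to use and needs no dtype/device fixups.

# Done once, at the point where the sink is created (before the sdpa call):
sink = sink.to(dtype=q.dtype, device=q.device)                # single cast
sink = sink.reshape(1, -1, 1, 1).expand(bs, -1, n_tokens, 1)  # [bs, n_heads, n_tokens, 1]
# ...then pass `sink` into sdpa unchanged.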
# Sink should match [bs, n_heads, n_tokens, sink_size] to concat with attn_weights [bs, n_heads, n_tokens, kv_size]
sink = sink.reshape(1, n_heads, 1, 1).expand(bs, n_heads, n_tokens, 1)
Same as above: if possible, this has to be done before sink is fed to sdpa.
# write(). With a fixed stride=16 and seqlen=8 we tried to unflatten an
# 8-length dimension into (1,16) causing the RuntimeError observed:
# unflatten: Provided sizes [1, 16] don't multiply up to size 8
# Setting stride=min(16, seqlen) ensures partial (short) sequences map
Is there a reason we mention this but don't enforce it here?
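A minimal sketch of enforcing what the comment describes; `stride` and `seqlen` mirror the comment's wording rather than the actual variable names in the code, and the divisibility check is an added assumption.

stride = min(16, seqlen)  # short sequences must not be unflattened into (1, 16)
assert seqlen % stride == 0, f"seqlen={seqlen} is not divisible by stride={stride}"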
)


def test__invoke_golden_mask_cases():
Double _ after "test".
Taking sliding_window out of ops.attention.