
Conversation

oyazdanb
Contributor

Taking the sliding window out of ops.attention.

@oyazdanb marked this pull request as draft on September 19, 2025 at 20:54
Contributor

github-actions bot commented Sep 22, 2025

Coverage report


| File | Lines missing |
| --- | --- |
| sharktank/sharktank/layers/paged_attention.py | |
| sharktank/sharktank/ops/attention_impls.py | |
| sharktank/sharktank/ops/sharded_impls.py | 1001, 1003 |
| sharktank/tests/layers/paged_llama_attention_block_test.py | |
| sharktank/tests/ops/test_attention_ops.py | 309 |
| Project Total | |

This report was generated by python-coverage-comment-action

@oyazdanb marked this pull request as ready for review on September 22, 2025 at 22:02
mask = torch.triu(mask, diagonal=1)
return mask.to(device)

is_prefill = kv_size == n_tokens
Collaborator

Generating the mask should barely matter for prefill vs decode. The only thing that makes a difference is the offset.
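A minimal sketch of what a single offset-based mask builder could look like; the name, signature, and a floating-point dtype are illustrative assumptions, not the PR's actual API:

```python
import torch


def build_causal_mask(
    n_tokens: int, kv_size: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    # One builder for both prefill and decode: the only difference is the
    # offset of the query rows relative to the key columns.
    offset = kv_size - n_tokens  # 0 for prefill, kv_size - 1 for single-token decode
    q_pos = torch.arange(n_tokens, device=device) + offset
    k_pos = torch.arange(kv_size, device=device)
    mask = torch.zeros(n_tokens, kv_size, dtype=dtype, device=device)
    mask[k_pos.unsqueeze(0) > q_pos.unsqueeze(1)] = float("-inf")
    return mask
```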


def build_mask(
self,
mask: Optional[torch.Tensor],
Collaborator

The handling of this mask is inconsistent: sometimes you return it, other times you ignore it. Rethink what is actually needed as arguments.

sink: torch.Tensor, bs: int, n_heads: int, n_tokens: int
) -> torch.Tensor:
"""Prepare sink tensor for attention: [sink_size, n_heads] -> [bs, n_heads, n_tokens, sink_size]"""
if sink.dim() == 1:
Collaborator

Just make sure the input is dim == 2, and if it's the same per head, materialize a unary dimension.
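A hypothetical sketch of the suggested normalization; the helper name is made up, and the [sink_size, n_heads] layout is taken from the docstring quoted above:

```python
import torch


def normalize_sink(sink: torch.Tensor) -> torch.Tensor:
    # Accept [sink_size, n_heads]; if the sink is shared across heads,
    # materialize a unary head dimension instead of special-casing 1-D inputs.
    if sink.dim() == 1:
        sink = sink.unsqueeze(-1)  # [sink_size] -> [sink_size, 1]
    assert sink.dim() == 2, f"expected [sink_size, n_heads], got {tuple(sink.shape)}"
    return sink
```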

n_tokens: int,
dtype: torch.dtype,
device: torch.device,
):
Collaborator

This appears to just be a copy of the component above; replicated behavior doesn't really help for testing.

@oyazdanb force-pushed the users/oyazdanb/attn_sliding_window branch from ab56847 to a7d6d53 on September 24, 2025 at 19:12
oyazdanb and others added 4 commits on September 24, 2025 at 22:33
@oyazdanb force-pushed the users/oyazdanb/attn_sliding_window branch 3 times, most recently from 237e2ce to ff80fcb on September 24, 2025 at 23:21
@oyazdanb force-pushed the users/oyazdanb/attn_sliding_window branch from ff80fcb to 51b12ee on September 24, 2025 at 23:22
@oyazdanb requested a review from rsuderman on September 26, 2025 at 18:10
v_planes, quantizer=cache_quantizer, dtype=self.attn_dtype
)

effective_mask = self.build_mask(
Contributor

There is already mask construction here, and this is a parent function down the call stack. I am curious what the reason is for splitting it like that.

Considering that the other method is called paged_attention, it probably does make sense to move the mask construction out of there and put it all here, since the construction is not related to the paging details.

This move can probably happen in another PR.

Contributor

start_positions does not get propagated through the call chain. Maybe they should stay separate.

# Sink weight
sink_expanded = sink.reshape(1, 1, 1, 1).expand(1, 1, 2, 1)
attn_with_sink = torch.cat([attn_weights, sink_expanded], dim=-1)
sink_weights_full = torch.softmax(attn_with_sink, dim=-1)
Collaborator

Can we replace all the torch ops used in this test with sharktank ops? E.g., torch.softmax -> ops.softmax.
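A hedged sketch of the requested swap, applied to the snippet above; it only replaces the op the comment names, assumes sharktank.ops exposes softmax with a torch-compatible signature (as the comment implies), and uses toy shapes so it runs standalone:

```python
import torch
from sharktank import ops  # assumed: ops.softmax mirrors torch.softmax here

# Toy shapes just for illustration: [bs=1, n_heads=1, n_tokens=2, kv_size=3]
attn_weights = torch.randn(1, 1, 2, 3)
sink = torch.tensor(0.5)

sink_expanded = sink.reshape(1, 1, 1, 1).expand(1, 1, 2, 1)
attn_with_sink = torch.cat([attn_weights, sink_expanded], dim=-1)
sink_weights_full = ops.softmax(attn_with_sink, dim=-1)  # was torch.softmax
```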

sliding_window_tensor = torch.full_like(global_q_pos, sliding_window - 1)
first_allowed_k_pos = (global_q_pos - sliding_window_tensor).clamp_min(0)
too_old = kv_positions.unsqueeze(0) < first_allowed_k_pos.unsqueeze(1)
invalid = future | too_old
Collaborator

Can we rename future and too_old to future_ctx and initial_ctx / previous_ctx, accordingly?
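To make the request concrete, a hedged sketch of the rename applied to the snippet above, using previous_ctx for too_old. The definition of future_ctx is an assumption (the original future is computed outside the quoted lines), the sliding_window_tensor is collapsed to a scalar for brevity, and the shapes are illustrative only:

```python
import torch

# Illustrative shapes: 8 cached key positions, queries at absolute positions 4..7
kv_positions = torch.arange(8)
global_q_pos = torch.arange(4, 8)
sliding_window = 4

future_ctx = kv_positions.unsqueeze(0) > global_q_pos.unsqueeze(1)
first_allowed_k_pos = (global_q_pos - (sliding_window - 1)).clamp_min(0)
previous_ctx = kv_positions.unsqueeze(0) < first_allowed_k_pos.unsqueeze(1)
invalid = future_ctx | previous_ctx
```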

fake_quant: Optional[bool],
softcap: Optional[float] = None,
scale: Optional[torch.Tensor | ReplicatedTensor] = None,
mask: Optional[torch.Tensor | ReplicatedTensor] = None,
Collaborator

Shouldn't all these types be included in the mask type hint in build_mask?

if sink is not None:
sink = sink.to(q.dtype)
sink = sink.reshape(1, -1, 1, 1).expand(bs, -1, n_tokens, 1)
sink = sink.to(q.dtype).to(q.device)
Collaborator

Shouldn't this be taken care of before sdpa, where the sink is generated?
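One way the suggestion could look, as a hypothetical helper that runs where the sink is produced, before sdpa is called. The name prepare_sink and its call site are assumptions, not part of the PR; the reshape mirrors the one in the diff above:

```python
import torch


def prepare_sink(
    sink: torch.Tensor, q: torch.Tensor, bs: int, n_tokens: int
) -> torch.Tensor:
    # Normalize dtype/device and shape once, where the sink is generated,
    # so the sdpa path can assume a ready-to-use tensor.
    sink = sink.to(dtype=q.dtype, device=q.device)
    # [n_heads] -> [bs, n_heads, n_tokens, 1], matching the reshape in the diff
    return sink.reshape(1, -1, 1, 1).expand(bs, -1, n_tokens, 1)
```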

Comment on lines +69 to +70
# Sink should match [bs, n_heads, n_tokens, sink_size] to concat with attn_weights [bs, n_heads, n_tokens, kv_size]
sink = sink.reshape(1, n_heads, 1, 1).expand(bs, n_heads, n_tokens, 1)
Collaborator

Same as above: if possible, this should be done before the sink is fed to sdpa.

# write(). With a fixed stride=16 and seqlen=8 we tried to unflatten an
# 8-length dimension into (1,16) causing the RuntimeError observed:
# unflatten: Provided sizes [1, 16] don't multiply up to size 8
# Setting stride=min(16, seqlen) ensures partial (short) sequences map
Collaborator

Is there a reason we mention this but don't enforce it here?
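If the intent is to enforce what the quoted comment describes rather than just document it, a minimal sketch could look like the following; the helper name and the default of 16 are illustrative, not code from the PR:

```python
def effective_stride(seqlen: int, block_stride: int = 16) -> int:
    # Clamp the stride for short (partial) sequences so the unflatten in
    # write() cannot fail with "sizes don't multiply up to size seqlen".
    return min(block_stride, seqlen)
```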

)


def test__invoke_golden_mask_cases():
Collaborator

Double underscore after "test".
