49 changes: 46 additions & 3 deletions sharktank/sharktank/layers/paged_attention.py
@@ -738,6 +738,41 @@ def write(
start_positions=start_positions,
)

def build_mask(
self,
mask: Optional[torch.Tensor],
Collaborator: Your behavior for this mask is inconsistent: sometimes you return it, other times you ignore it. Rethink what arguments are actually needed.

sliding_window: Optional[int],
kv_size: int,
n_tokens: int,
dtype: torch.dtype,
device: Optional[torch.device] = None,
):
"""
Returns a causal (and optional sliding-window) mask of shape [n_tokens, kv_size].
Future positions (k > current global query pos) are -inf; if sliding_window is set,
keys older than (current_pos - sliding_window + 1) are also -inf.
"""
neg_inf = float("-inf")
q_positions = torch.arange(n_tokens, device=device)
kv_positions = torch.arange(kv_size, device=device)
offset_tensor = torch.full_like(q_positions, kv_size - n_tokens)
global_q_pos = q_positions + offset_tensor
future = kv_positions.unsqueeze(0) > global_q_pos.unsqueeze(1)

if mask is None:
mask = torch.zeros(n_tokens, kv_size, dtype=dtype, device=device)

if sliding_window and sliding_window > 0:
sliding_window_tensor = torch.full_like(global_q_pos, sliding_window - 1)
first_allowed_k_pos = (global_q_pos - sliding_window_tensor).clamp_min(0)
too_old = kv_positions.unsqueeze(0) < first_allowed_k_pos.unsqueeze(1)
invalid = future | too_old
Collaborator: Can we rename future and too_old to future_ctx, initial_ctx, previous_ctx, accordingly?

else:
invalid = future

mask = mask.masked_fill(invalid, neg_inf)
return mask
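
A minimal standalone sketch of the rule described in the docstring above (not part of the diff; the sizes are arbitrary and the snippet only mirrors build_mask's logic):

    import torch

    n_tokens, kv_size, sliding_window = 3, 5, 3
    q_pos = torch.arange(n_tokens) + (kv_size - n_tokens)  # global query positions: 2, 3, 4
    k_pos = torch.arange(kv_size)
    future = k_pos[None, :] > q_pos[:, None]  # queries may not attend to future keys
    too_old = k_pos[None, :] < (q_pos[:, None] - (sliding_window - 1)).clamp_min(0)
    mask = torch.zeros(n_tokens, kv_size).masked_fill(future | too_old, float("-inf"))
    # Each query row sees at most `sliding_window` keys, ending at its own position.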

def attention(
self,
*,
@@ -773,17 +808,25 @@ def attention(
v_planes, quantizer=cache_quantizer, dtype=self.attn_dtype
)

effective_mask = self.build_mask(
Contributor: There is already mask construction in a parent function down the call stack. I am curious what the reason is to split it like that. Considering that the other method is called paged_attention, it probably does make sense to move the mask construction out of there and put all of it here, since the construction is not related to the paging details. That move can probably happen in another PR.

Contributor: start_positions does not get propagated in the call chain. Maybe they should stay separate.

mask,
sliding_window,
k.shape[-2],
q.shape[-2],
self.attn_dtype,
mask.device if mask is not None else None,
)

return ops.scaled_dot_product_attention(
q=q, # [bs, ..., sl, dim]
k=k, # [bs, ..., sl, dim]
v=v, # [bs, ..., sl, dim]
a=mask, # [bs, ..., sl, sl] or None
is_causal=mask is None, # assumes causal masking when true
a=effective_mask, # [bs, ..., sl, sl] or None
is_causal=effective_mask is None,
scale=scale, # defaults to 1/sqrt(dim)
softcap=softcap,
impl=attention_kernel, # if none, automatically select a kernel
sink=sink,
sliding_window=sliding_window,
)

def forward_decode(
117 changes: 26 additions & 91 deletions sharktank/sharktank/ops/attention_impls.py
@@ -25,67 +25,6 @@
from ._registry import AnyType


def build_causal_and_sw_prefill(mask_prefill, n_tokens, sliding_window, dtype, device):
if mask_prefill is None:
mask_prefill = torch.triu(
torch.full((n_tokens, n_tokens), -float("inf"), dtype=dtype, device=device),
diagonal=1,
)

if sliding_window > 0:
mask_prefill += torch.tril(
torch.full((n_tokens, n_tokens), -float("inf"), dtype=dtype, device=device),
diagonal=-sliding_window,
)
return mask_prefill


def create_mask_sliding_window(
a, attn_weights, sliding_window, n_tokens, kv_size, dtype, device
):
if sliding_window is None or sliding_window <= 0:
if a is not None:
attn_weights = attn_weights + a
return attn_weights

is_prefill = kv_size == n_tokens
if is_prefill:
# prefill path: causal mask within sliding window
a = build_causal_and_sw_prefill(
mask_prefill=a,
n_tokens=n_tokens,
sliding_window=(sliding_window or 0),
device=device,
dtype=dtype,
)

else:
# decode path
if sliding_window > 0 and kv_size > sliding_window:
start_idx = kv_size - sliding_window
neg_inf = float("-inf")
a[..., :start_idx] = neg_inf

if a is not None:
attn_weights = attn_weights + a
return attn_weights


def create_mask(a, attn_weights, is_causal):
if a is not None:
attn_weights = attn_weights + a
elif is_causal:
mask = torch.full(
(attn_weights.shape[2], attn_weights.shape[3]),
float("-inf"),
dtype=attn_weights.dtype,
device=attn_weights.device,
)
mask = torch.triu(mask, diagonal=1)[None, None, :, :]
attn_weights = attn_weights + mask
return attn_weights


# These two versions should be preserved in this order
@scaled_dot_product_attention.override(
AnyTensor,
@@ -95,7 +34,7 @@ def create_mask(a, attn_weights, is_causal):
impl_name="decomposed",
)
def scaled_dot_product_attention_decomposed(
q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
q, k, v, a, sink, is_causal, scale, softcap, impl
):

if scale is None:
@@ -105,42 +44,39 @@ def scaled_dot_product_attention_decomposed(
k = unbox_tensor(k)
v = unbox_tensor(v)
bs, n_heads, n_tokens, head_dim = q.shape
kv_size = k.shape[-2]

attn_weights = torch.matmul(q, k.transpose(-2, -1))
attn_weights = attn_weights * scale
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
if softcap is not None:
attn_weights = softcap * torch.tanh(attn_weights / softcap)

use_sink_path = (sink is not None) or (sliding_window is not None)
if not use_sink_path:
# standard causal/masked attention
attn_weights = create_mask(a, attn_weights, is_causal)
attn_weights = ops.softmax(attn_weights, dim=-1)
out = torch.matmul(unbox_tensor(attn_weights), v)
return out.to(q.dtype)

# sliding-window (and optional sink) path
attn_weights = create_mask_sliding_window(
a,
attn_weights=attn_weights,
n_tokens=n_tokens,
kv_size=kv_size,
sliding_window=sliding_window,
dtype=q.dtype,
device=attn_weights.device,
)
if a is not None:
attn_weights = attn_weights + a
elif is_causal:
seq_len = attn_weights.shape[-1]
causal = torch.triu(
torch.full(
(seq_len, seq_len),
float("-inf"),
device=attn_weights.device,
dtype=attn_weights.dtype,
),
diagonal=1,
)
attn_weights = attn_weights + causal

if sink is not None:
sink = sink.to(q.dtype)
sink = sink.reshape(1, -1, 1, 1).expand(bs, -1, n_tokens, 1)
sink = sink.to(q.dtype).to(q.device)
Collaborator: Shouldn't this be taken care of before sdpa, where the sink is generated?

# Sink should match [bs, n_heads, n_tokens, sink_size] to concat with attn_weights [bs, n_heads, n_tokens, kv_size]
sink = sink.reshape(1, n_heads, 1, 1).expand(bs, n_heads, n_tokens, 1)
Collaborator (on lines +69 to +70): Same as above; if possible this has to be done before the sink is fed to sdpa.


attn_weights = ops.cat([attn_weights, sink], dim=-1)
attn_weights = ops.softmax(attn_weights, dim=-1)[..., :-1]
attn_weights = ops.softmax(attn_weights, dim=-1)[..., : -sink.shape[-1]]
else:
attn_weights = ops.softmax(attn_weights, dim=-1)

attn_weights = unbox_tensor(attn_weights)
out = torch.matmul(attn_weights, v)

return out.to(q.dtype)
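
A standalone sketch of the sink trick used in the decomposed path above: a learned per-head sink logit is appended as an extra column of the score matrix before softmax and dropped afterwards, so it only absorbs probability mass (shapes and the per-head layout are illustrative assumptions, not taken from the diff):

    import torch

    bs, n_heads, n_tokens, kv_size = 1, 2, 4, 4
    attn_weights = torch.randn(bs, n_heads, n_tokens, kv_size)
    sink = torch.randn(n_heads)  # assumed layout: one learned logit per head
    sink_col = sink.reshape(1, n_heads, 1, 1).expand(bs, n_heads, n_tokens, 1)
    w = torch.softmax(torch.cat([attn_weights, sink_col], dim=-1), dim=-1)[..., :-1]
    # Each row of w now sums to <= 1; the missing mass went to the dropped sink column.
    out = w @ torch.randn(bs, n_heads, kv_size, 8)  # -> (1, 2, 4, 8)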


@@ -162,9 +98,9 @@ def _extract_linear_scale(t):
impl_name="sharktank",
)
def scaled_dot_product_flash_attention_sharktank(
q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
q, k, v, a, sink, is_causal, scale, softcap, impl
):
if sliding_window is not None or sink is not None:
if sink is not None:
return NotImplemented
if softcap:
return NotImplemented
@@ -199,7 +135,6 @@ def scaled_dot_product_flash_attention_sharktank(
v = v.to(torch.float16)

if a is not None:
a = unbox_tensor(a)
if a.dim() == 4:
# TODO: Multiple tests are relying on inconsistent behavior of the attention mask.
# Attention mask ranks should be consistent.
@@ -217,9 +152,9 @@
AnyTensor, AnyTensor, AnyTensor, AnyType, impl_name="torch"
)
def scaled_dot_product_attention_torch(
q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
q, k, v, a, sink, is_causal, scale, softcap, impl
):
if sliding_window is not None or sink is not None:
if sink is not None:
return NotImplemented
if softcap is not None:
return NotImplemented
4 changes: 2 additions & 2 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -995,9 +995,9 @@ def matmul_split(
impl_name="sharded",
)
def scaled_dot_product_attention_sharded(
q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
q, k, v, a, sink, is_causal, scale, softcap, impl
) -> SplitPrimitiveTensor:
if sink is not None or sliding_window is not None:
if sink is not None:
return NotImplemented
if q.shard_count != k.shard_count or q.shard_count != v.shard_count:
raise ValueError("Incompatible number of shards for qkv")
1 change: 0 additions & 1 deletion sharktank/sharktank/ops/signatures.py
@@ -896,7 +896,6 @@ def scaled_dot_product_attention(
v: AnyTensor,
a: Optional[AnyTensor],
sink: Optional[AnyTensor] = None,
sliding_window: Optional[AnyTensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
softcap: Optional[float] = None,
27 changes: 20 additions & 7 deletions sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -168,8 +168,9 @@ def forward(self, h, seq_block_ids, cache_state: list[torch.Tensor]):

# Shapes: (bs, seq_len, n_heads, kv_heads, head_dim)
_SHAPE_CASES = [
(1, 64, 8, 1, 64),
(2, 128, 8, 2, 64),
(1, 64, 8, 1, 64), # 4 full blocks @ stride 16
(2, 128, 8, 2, 64), # 8 full blocks @ stride 16
(1, 16, 8, 1, 64), # 1 full block exactly matching stride
]
_CONTEXT_LEN = [2048]
_DT_CASES = [
@@ -185,6 +186,7 @@ def forward(self, h, seq_block_ids, cache_state: list[torch.Tensor]):
19,
0.25,
), # sink path enabled TODO: https://github.com/nod-ai/shark-ai/issues/2156
(4, 0.5),
]


@@ -208,18 +210,18 @@ def _reference_sink_batched(q, k, v, sink, mode, sliding_window):

sink_ = sink.reshape(n_kv_heads, q_mul, 1, 1).expand(-1, -1, n_tokens, -1)
sink_ = sink_.unsqueeze(0).expand(bs, -1, -1, -1, -1)

mask = torch.triu(q_.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
if sliding_window > 0:
if sliding_window is not None and sliding_window > 0:
mask += torch.tril(
mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
)

qk_ = torch.einsum("bqhmd,bkhmd->bhmqk", q_, k_) * sm_scale
qk_ = qk_ + mask[None, None, :, :]

# Concatenate sink column and apply softmax then drop sink logits
qk_ = torch.cat([qk_, sink_], dim=-1)
w = torch.softmax(qk_, dim=-1)[..., :-1] # drop sink column
w = torch.softmax(qk_, dim=-1)[..., :-1]

attn = torch.einsum("bhmqk,bkhmd->bqhmd", w, v_)
out = attn.reshape(bs, n_tokens, n_kv_heads * q_mul, -1).permute(0, 2, 1, 3)
@@ -258,7 +260,7 @@ def _reference_base(q, k, v, mode):

def _make_reference_for_case(q, k, v, mode, sliding_window, sink):
# Choose the correct reference implementation for this configuration.
if (sliding_window is not None) and (sink is not None):
if (sliding_window is not None) or (sink is not None):
return _reference_sink_batched(q, k, v, sink, mode, sliding_window)
else:
return _reference_base(q, k, v, mode)
@@ -434,11 +436,22 @@ def test_forward_sink_eager(
):
torch.manual_seed(1234)

# Use a dynamic stride so that very short sequences (< default stride)
# still form a single full block without requiring padding. The cache
# implementation currently expects the (block_seq_len * block_seq_stride)
# product to exactly match the flattened sequence dimension passed to
# write(). With a fixed stride=16 and seqlen=8 we tried to unflatten an
# 8-length dimension into (1,16) causing the RuntimeError observed:
# unflatten: Provided sizes [1, 16] don't multiply up to size 8
# Setting stride=min(16, seqlen) ensures partial (short) sequences map
# to (1, seqlen) which is valid.

Collaborator: Is there a reason we mention this but don't enforce it here?

block_seq_stride = 16

kv_cache = build_cache(
transformer_block_count=1,
attn_head_count=kv_heads,
attn_head_dim=head_dim,
block_seq_stride=16,
block_seq_stride=block_seq_stride,
cache_dtype=dtype,
)
pa = PagedGQAttention(
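
For reference, a standalone sketch of the unflatten failure mode that the block_seq_stride comment in test_forward_sink_eager describes (illustrative only; the exact PyTorch error text may differ):

    import torch

    seq = torch.zeros(2, 8, 64)        # (bs, seq_len=8, dim) with seq_len < stride
    # seq.unflatten(1, (1, 16))        # raises RuntimeError: sizes (1, 16) don't multiply up to 8
    blocks = seq.unflatten(1, (1, 8))  # stride matched to seq_len -> shape (2, 1, 8, 64)
    assert blocks.shape == (2, 1, 8, 64)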