[Sharktank] sliding_window out of ops.attention #2293
base: main
Changes from all commits: 01df949, 9b825aa, 4476102, bdc8948, 51b12ee, 1613f5b, 2756718
@@ -738,6 +738,41 @@ def write(
            start_positions=start_positions,
        )

    def build_mask(
        self,
        mask: Optional[torch.Tensor],
        sliding_window: Optional[int],
        kv_size: int,
        n_tokens: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ):
        """
        Returns a causal (and optional sliding-window) mask of shape [n_tokens, kv_size].
        Future positions (k > current global query pos) are -inf; if sliding_window is set,
        keys older than (current_pos - sliding_window + 1) are also -inf.
        """
        neg_inf = float("-inf")
        q_positions = torch.arange(n_tokens, device=device)
        kv_positions = torch.arange(kv_size, device=device)
        offset_tensor = torch.full_like(q_positions, kv_size - n_tokens)
        global_q_pos = q_positions + offset_tensor
        future = kv_positions.unsqueeze(0) > global_q_pos.unsqueeze(1)

        if mask is None:
            mask = torch.zeros(n_tokens, kv_size, dtype=dtype, device=device)

        if sliding_window and sliding_window > 0:
            sliding_window_tensor = torch.full_like(global_q_pos, sliding_window - 1)
            first_allowed_k_pos = (global_q_pos - sliding_window_tensor).clamp_min(0)
            too_old = kv_positions.unsqueeze(0) < first_allowed_k_pos.unsqueeze(1)
            invalid = future | too_old
Review comment: Can we rename future, too_old to future_ctx, initial_ctx, previous_ctx, accordingly?
        else:
            invalid = future

        mask = mask.masked_fill(invalid, neg_inf)
        return mask

    def attention(
        self,
        *,
@@ -773,17 +808,25 @@ def attention(
            v_planes, quantizer=cache_quantizer, dtype=self.attn_dtype
        )

        effective_mask = self.build_mask(
Review comment: There is already mask construction here. Considering that the other method is called … This move can probably happen in another PR.
            mask,
            sliding_window,
            k.shape[-2],
            q.shape[-2],
            self.attn_dtype,
            mask.device if mask is not None else None,
        )

        return ops.scaled_dot_product_attention(
            q=q,  # [bs, ..., sl, dim]
            k=k,  # [bs, ..., sl, dim]
            v=v,  # [bs, ..., sl, dim]
            a=mask,  # [bs, ..., sl, sl] or None
            is_causal=mask is None,  # assumes causal masking when true
            a=effective_mask,  # [bs, ..., sl, sl] or None
            is_causal=effective_mask is None,
            scale=scale,  # defaults to 1/sqrt(dim)
            softcap=softcap,
            impl=attention_kernel,  # if none, automatically select a kernel
            sink=sink,
            sliding_window=sliding_window,
        )

    def forward_decode(
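For reference, here is a minimal standalone sketch of the masking rule that build_mask encodes (plain torch only; the helper name and shapes are illustrative, not the sharktank implementation): a key is masked if it lies in the future of the query's global position or, when sliding_window is set, more than sliding_window - 1 positions behind it.

import torch

def causal_sliding_window_mask(n_tokens, kv_size, sliding_window=None, dtype=torch.float32):
    # Global query positions: the n_tokens queries sit at the end of the kv_size keys.
    q_pos = torch.arange(n_tokens) + (kv_size - n_tokens)
    kv_pos = torch.arange(kv_size)
    invalid = kv_pos[None, :] > q_pos[:, None]  # future keys
    if sliding_window and sliding_window > 0:
        first_allowed = (q_pos - (sliding_window - 1)).clamp_min(0)
        invalid |= kv_pos[None, :] < first_allowed[:, None]  # keys that fell out of the window
    return torch.zeros(n_tokens, kv_size, dtype=dtype).masked_fill(invalid, float("-inf"))

# Decode step: 1 new token against 6 cached keys with a window of 4 keeps
# only the last 4 key positions finite.
print(causal_sliding_window_mask(1, 6, sliding_window=4))
# tensor([[-inf, -inf, 0., 0., 0., 0.]])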
@@ -25,67 +25,6 @@
from ._registry import AnyType


def build_causal_and_sw_prefill(mask_prefill, n_tokens, sliding_window, dtype, device):
    if mask_prefill is None:
        mask_prefill = torch.triu(
            torch.full((n_tokens, n_tokens), -float("inf"), dtype=dtype, device=device),
            diagonal=1,
        )

    if sliding_window > 0:
        mask_prefill += torch.tril(
            torch.full((n_tokens, n_tokens), -float("inf"), dtype=dtype, device=device),
            diagonal=-sliding_window,
        )
    return mask_prefill


def create_mask_sliding_window(
    a, attn_weights, sliding_window, n_tokens, kv_size, dtype, device
):
    if sliding_window is None or sliding_window <= 0:
        if a is not None:
            attn_weights = attn_weights + a
        return attn_weights

    is_prefill = kv_size == n_tokens
    if is_prefill:
        # prefill path: causal mask within sliding window
        a = build_causal_and_sw_prefill(
            mask_prefill=a,
            n_tokens=n_tokens,
            sliding_window=(sliding_window or 0),
            device=device,
            dtype=dtype,
        )

    else:
        # decode path
        if sliding_window > 0 and kv_size > sliding_window:
            start_idx = kv_size - sliding_window
            neg_inf = float("-inf")
            a[..., :start_idx] = neg_inf

    if a is not None:
        attn_weights = attn_weights + a
    return attn_weights


def create_mask(a, attn_weights, is_causal):
    if a is not None:
        attn_weights = attn_weights + a
    elif is_causal:
        mask = torch.full(
            (attn_weights.shape[2], attn_weights.shape[3]),
            float("-inf"),
            dtype=attn_weights.dtype,
            device=attn_weights.device,
        )
        mask = torch.triu(mask, diagonal=1)[None, None, :, :]
        attn_weights = attn_weights + mask
    return attn_weights


# These two versions should be preserved in this order
@scaled_dot_product_attention.override(
    AnyTensor,
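The removed prefill helper composed its mask from two triangular fills; the pattern is easy to see in isolation. A small sketch with plain torch and illustrative sizes (not the sharktank code) follows.

import torch

n_tokens, sliding_window = 5, 2
neg_inf = float("-inf")

# triu(-inf, diagonal=1) hides future tokens (the causal part).
causal = torch.triu(torch.full((n_tokens, n_tokens), neg_inf), diagonal=1)
# tril(-inf, diagonal=-sliding_window) additionally hides tokens that fell out of the window.
windowed = causal + torch.tril(torch.full((n_tokens, n_tokens), neg_inf), diagonal=-sliding_window)

print(windowed)
# Each row keeps at most the current and previous column finite (a window of 2
# including the current token); everything else is -inf.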
@@ -95,7 +34,7 @@ def create_mask(a, attn_weights, is_causal):
    impl_name="decomposed",
)
def scaled_dot_product_attention_decomposed(
    q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
    q, k, v, a, sink, is_causal, scale, softcap, impl
):

    if scale is None:
@@ -105,42 +44,39 @@ def scaled_dot_product_attention_decomposed(
    k = unbox_tensor(k)
    v = unbox_tensor(v)
    bs, n_heads, n_tokens, head_dim = q.shape
    kv_size = k.shape[-2]

    attn_weights = torch.matmul(q, k.transpose(-2, -1))
    attn_weights = attn_weights * scale
    attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
    if softcap is not None:
        attn_weights = softcap * torch.tanh(attn_weights / softcap)

    use_sink_path = (sink is not None) or (sliding_window is not None)
    if not use_sink_path:
        # standard causal/masked attention
        attn_weights = create_mask(a, attn_weights, is_causal)
        attn_weights = ops.softmax(attn_weights, dim=-1)
        out = torch.matmul(unbox_tensor(attn_weights), v)
        return out.to(q.dtype)

    # sliding-window (and optional sink) path
    attn_weights = create_mask_sliding_window(
        a,
        attn_weights=attn_weights,
        n_tokens=n_tokens,
        kv_size=kv_size,
        sliding_window=sliding_window,
        dtype=q.dtype,
        device=attn_weights.device,
    )
    if a is not None:
        attn_weights = attn_weights + a
    elif is_causal:
        seq_len = attn_weights.shape[-1]
        causal = torch.triu(
            torch.full(
                (seq_len, seq_len),
                float("-inf"),
                device=attn_weights.device,
                dtype=attn_weights.dtype,
            ),
            diagonal=1,
        )
        attn_weights = attn_weights + causal

    if sink is not None:
        sink = sink.to(q.dtype)
        sink = sink.reshape(1, -1, 1, 1).expand(bs, -1, n_tokens, 1)
        sink = sink.to(q.dtype).to(q.device)
Review comment: Shouldn't this be taken care of before sdpa, where the sink is generated?
        # Sink should match [bs, n_heads, n_tokens, sink_size] to concat with attn_weights [bs, n_heads, n_tokens, kv_size]
        sink = sink.reshape(1, n_heads, 1, 1).expand(bs, n_heads, n_tokens, 1)
Review comment on lines +69 to +70: Same as above, if possible this has to be done before.
        attn_weights = ops.cat([attn_weights, sink], dim=-1)
        attn_weights = ops.softmax(attn_weights, dim=-1)[..., :-1]
        attn_weights = ops.softmax(attn_weights, dim=-1)[..., : -sink.shape[-1]]
    else:
        attn_weights = ops.softmax(attn_weights, dim=-1)

    attn_weights = unbox_tensor(attn_weights)
    out = torch.matmul(attn_weights, v)

    return out.to(q.dtype)
@@ -162,9 +98,9 @@ def _extract_linear_scale(t):
    impl_name="sharktank",
)
def scaled_dot_product_flash_attention_sharktank(
    q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
    q, k, v, a, sink, is_causal, scale, softcap, impl
):
    if sliding_window is not None or sink is not None:
    if sink is not None:
        return NotImplemented
    if softcap:
        return NotImplemented
@@ -199,7 +135,6 @@ def scaled_dot_product_flash_attention_sharktank(
    v = v.to(torch.float16)

    if a is not None:
        a = unbox_tensor(a)
        if a.dim() == 4:
            # TODO: Multiple tests are relying on inconsistent behavior of the attention mask.
            # Attention mask ranks should be consistent.
@@ -217,9 +152,9 @@
    AnyTensor, AnyTensor, AnyTensor, AnyType, impl_name="torch"
)
def scaled_dot_product_attention_torch(
    q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
    q, k, v, a, sink, is_causal, scale, softcap, impl
):
    if sliding_window is not None or sink is not None:
    if sink is not None:
        return NotImplemented
    if softcap is not None:
        return NotImplemented
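The sink handling retained in the decomposed path appends one extra logit column before the softmax and drops it afterwards, so the sink absorbs probability mass without contributing a value vector. A self-contained sketch with plain torch and illustrative shapes (not the sharktank dispatcher) is below.

import torch

bs, n_heads, n_tokens, kv_size, head_dim = 1, 2, 3, 3, 4
attn_weights = torch.randn(bs, n_heads, n_tokens, kv_size)
v = torch.randn(bs, n_heads, kv_size, head_dim)
sink = torch.zeros(n_heads)  # stands in for the per-head learned sink logit

# Broadcast the sink to one extra column per query position, softmax over the
# real keys plus the sink, then drop the sink column(s).
sink_col = sink.reshape(1, n_heads, 1, 1).expand(bs, n_heads, n_tokens, 1)
logits = torch.cat([attn_weights, sink_col], dim=-1)                  # [bs, h, q, kv + 1]
weights = torch.softmax(logits, dim=-1)[..., : -sink_col.shape[-1]]   # [bs, h, q, kv]
out = weights @ v

print(weights.sum(dim=-1))  # below 1 everywhere: the dropped sink column kept some mass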
@@ -168,8 +168,9 @@ def forward(self, h, seq_block_ids, cache_state: list[torch.Tensor]):

# Shapes: (bs, seq_len, n_heads, kv_heads, head_dim)
_SHAPE_CASES = [
    (1, 64, 8, 1, 64),
    (2, 128, 8, 2, 64),
    (1, 64, 8, 1, 64),  # 4 full blocks @ stride 16
    (2, 128, 8, 2, 64),  # 8 full blocks @ stride 16
    (1, 16, 8, 1, 64),  # 1 full block exactly matching stride
]
_CONTEXT_LEN = [2048]
_DT_CASES = [
@@ -185,6 +186,7 @@ def forward(self, h, seq_block_ids, cache_state: list[torch.Tensor]):
        19,
        0.25,
    ),  # sink path enabled TODO: https://github.com/nod-ai/shark-ai/issues/2156
    (4, 0.5),
]
@@ -208,18 +210,18 @@ def _reference_sink_batched(q, k, v, sink, mode, sliding_window):

    sink_ = sink.reshape(n_kv_heads, q_mul, 1, 1).expand(-1, -1, n_tokens, -1)
    sink_ = sink_.unsqueeze(0).expand(bs, -1, -1, -1, -1)

    mask = torch.triu(q_.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
    if sliding_window is not None and sliding_window > 0:
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )

    qk_ = torch.einsum("bqhmd,bkhmd->bhmqk", q_, k_) * sm_scale
    qk_ = qk_ + mask[None, None, :, :]

    # Concatenate sink column and apply softmax then drop sink logits
    qk_ = torch.cat([qk_, sink_], dim=-1)
    w = torch.softmax(qk_, dim=-1)[..., :-1]  # drop sink column
    w = torch.softmax(qk_, dim=-1)[..., :-1]

    attn = torch.einsum("bhmqk,bkhmd->bqhmd", w, v_)
    out = attn.reshape(bs, n_tokens, n_kv_heads * q_mul, -1).permute(0, 2, 1, 3)
@@ -258,7 +260,7 @@ def _reference_base(q, k, v, mode):

def _make_reference_for_case(q, k, v, mode, sliding_window, sink):
    # Choose the correct reference implementation for this configuration.
    if (sliding_window is not None) and (sink is not None):
    if (sliding_window is not None) or (sink is not None):
        return _reference_sink_batched(q, k, v, sink, mode, sliding_window)
    else:
        return _reference_base(q, k, v, mode)
@@ -434,11 +436,22 @@ def test_forward_sink_eager(
):
    torch.manual_seed(1234)

    # Use a dynamic stride so that very short sequences (< default stride)
    # still form a single full block without requiring padding. The cache
    # implementation currently expects the (block_seq_len * block_seq_stride)
    # product to exactly match the flattened sequence dimension passed to
    # write(). With a fixed stride=16 and seqlen=8 we tried to unflatten an
    # 8-length dimension into (1,16) causing the RuntimeError observed:
    #     unflatten: Provided sizes [1, 16] don't multiply up to size 8
    # Setting stride=min(16, seqlen) ensures partial (short) sequences map
Review comment: Is there a reason we mention this but don't enforce it here?
    # to (1, seqlen) which is valid.
    block_seq_stride = 16

    kv_cache = build_cache(
        transformer_block_count=1,
        attn_head_count=kv_heads,
        attn_head_dim=head_dim,
        block_seq_stride=16,
        block_seq_stride=block_seq_stride,
        cache_dtype=dtype,
    )
    pa = PagedGQAttention(
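The stride constraint described in the new test comment can be reproduced in isolation with torch.Tensor.unflatten; a tiny illustrative check (assumed sizes only, not the cache code) follows.

import torch

seq = torch.arange(8)  # stands in for a flattened sequence of length 8

try:
    # Mirrors block_seq_stride=16 with seqlen=8: 1 * 16 != 8, so this raises.
    seq.unflatten(0, (1, 16))
except RuntimeError as err:
    print(err)  # the "don't multiply up to size 8" error quoted in the comment above

# A stride that divides the sequence length works, e.g. min(16, seqlen) = 8.
print(seq.unflatten(0, (1, 8)).shape)  # torch.Size([1, 8])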
Review comment: Your behavior for this mask is weird. Sometimes you return it, other times you ignore it. Rethink what is needed for arguments.