diff --git a/sharktank/sharktank/layers/paged_attention.py b/sharktank/sharktank/layers/paged_attention.py
index ac4f4d13da..4983dc8d77 100644
--- a/sharktank/sharktank/layers/paged_attention.py
+++ b/sharktank/sharktank/layers/paged_attention.py
@@ -738,6 +738,41 @@ def write(
             start_positions=start_positions,
         )
 
+    def build_mask(
+        self,
+        mask: Optional[torch.Tensor],
+        sliding_window: Optional[int],
+        kv_size: int,
+        n_tokens: int,
+        dtype: torch.dtype,
+        device: Optional[torch.device] = None,
+    ):
+        """
+        Returns a causal (and optional sliding-window) mask of shape [n_tokens, kv_size].
+        Future positions (k > current global query pos) are -inf; if sliding_window is set,
+        keys older than (current_pos - sliding_window + 1) are also -inf.
+        """
+        neg_inf = float("-inf")
+        q_positions = torch.arange(n_tokens, device=device)
+        kv_positions = torch.arange(kv_size, device=device)
+        offset_tensor = torch.full_like(q_positions, kv_size - n_tokens)
+        global_q_pos = q_positions + offset_tensor
+        future = kv_positions.unsqueeze(0) > global_q_pos.unsqueeze(1)
+
+        if mask is None:
+            mask = torch.zeros(n_tokens, kv_size, dtype=dtype, device=device)
+
+        if sliding_window and sliding_window > 0:
+            sliding_window_tensor = torch.full_like(global_q_pos, sliding_window - 1)
+            first_allowed_k_pos = (global_q_pos - sliding_window_tensor).clamp_min(0)
+            too_old = kv_positions.unsqueeze(0) < first_allowed_k_pos.unsqueeze(1)
+            invalid = future | too_old
+        else:
+            invalid = future
+
+        mask = mask.masked_fill(invalid, neg_inf)
+        return mask
+
     def attention(
         self,
         *,
@@ -773,17 +808,25 @@ def attention(
             v_planes, quantizer=cache_quantizer, dtype=self.attn_dtype
         )
 
+        effective_mask = self.build_mask(
+            mask,
+            sliding_window,
+            k.shape[-2],
+            q.shape[-2],
+            self.attn_dtype,
+            mask.device if mask is not None else None,
+        )
+
         return ops.scaled_dot_product_attention(
             q=q,  # [bs, ..., sl, dim]
             k=k,  # [bs, ..., sl, dim]
             v=v,  # [bs, ..., sl, dim]
-            a=mask,  # [bs, ..., sl, sl] or None
-            is_causal=mask is None,  # assumes causal masking when true
+            a=effective_mask,  # [bs, ..., sl, sl] or None
+            is_causal=effective_mask is None,
             scale=scale,  # defaults to 1/sqrt(dim)
             softcap=softcap,
             impl=attention_kernel,  # if none, automatically select a kernel
             sink=sink,
-            sliding_window=sliding_window,
        )
 
     def forward_decode(
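Reviewer note on the hunk above: `build_mask` always emits an additive `[n_tokens, kv_size]` mask in which the queries are treated as the newest `n_tokens` positions of the cached history. The sketch below is not part of the patch; it re-derives the same rule with plain torch under that assumption (the helper name is invented for illustration) so the decode-style alignment is easy to check by eye.

```python
import torch

def causal_sliding_mask_sketch(n_tokens: int, kv_size: int, sliding_window=None):
    # Query i sits at global position (kv_size - n_tokens + i), i.e. the queries
    # are the last n_tokens entries of a kv_size-long history.
    q_pos = torch.arange(n_tokens) + (kv_size - n_tokens)
    k_pos = torch.arange(kv_size)
    invalid = k_pos[None, :] > q_pos[:, None]  # future keys
    if sliding_window and sliding_window > 0:
        first_allowed = (q_pos - (sliding_window - 1)).clamp_min(0)
        invalid |= k_pos[None, :] < first_allowed[:, None]  # keys older than the window
    return torch.zeros(n_tokens, kv_size).masked_fill(invalid, float("-inf"))

# Decode-style step: 2 new queries over 6 cached keys with a window of 3.
print(causal_sliding_mask_sketch(n_tokens=2, kv_size=6, sliding_window=3))
```

With those inputs, row 0 keeps keys 2-4 and row 1 keeps keys 3-5, which is the behavior the prefill/decode special-casing removed below used to implement by hand.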
diff --git a/sharktank/sharktank/ops/attention_impls.py b/sharktank/sharktank/ops/attention_impls.py
index f48a29b3ec..0714d331c3 100644
--- a/sharktank/sharktank/ops/attention_impls.py
+++ b/sharktank/sharktank/ops/attention_impls.py
@@ -25,67 +25,6 @@
 from ._registry import AnyType
 
 
-def build_causal_and_sw_prefill(mask_prefill, n_tokens, sliding_window, dtype, device):
-    if mask_prefill is None:
-        mask_prefill = torch.triu(
-            torch.full((n_tokens, n_tokens), -float("inf"), dtype=dtype, device=device),
-            diagonal=1,
-        )
-
-    if sliding_window > 0:
-        mask_prefill += torch.tril(
-            torch.full((n_tokens, n_tokens), -float("inf"), dtype=dtype, device=device),
-            diagonal=-sliding_window,
-        )
-    return mask_prefill
-
-
-def create_mask_sliding_window(
-    a, attn_weights, sliding_window, n_tokens, kv_size, dtype, device
-):
-    if sliding_window is None or sliding_window <= 0:
-        if a is not None:
-            attn_weights = attn_weights + a
-        return attn_weights
-
-    is_prefill = kv_size == n_tokens
-    if is_prefill:
-        # prefill path: causal mask within sliding window
-        a = build_causal_and_sw_prefill(
-            mask_prefill=a,
-            n_tokens=n_tokens,
-            sliding_window=(sliding_window or 0),
-            device=device,
-            dtype=dtype,
-        )
-
-    else:
-        # decode path
-        if sliding_window > 0 and kv_size > sliding_window:
-            start_idx = kv_size - sliding_window
-            neg_inf = float("-inf")
-            a[..., :start_idx] = neg_inf
-
-    if a is not None:
-        attn_weights = attn_weights + a
-    return attn_weights
-
-
-def create_mask(a, attn_weights, is_causal):
-    if a is not None:
-        attn_weights = attn_weights + a
-    elif is_causal:
-        mask = torch.full(
-            (attn_weights.shape[2], attn_weights.shape[3]),
-            float("-inf"),
-            dtype=attn_weights.dtype,
-            device=attn_weights.device,
-        )
-        mask = torch.triu(mask, diagonal=1)[None, None, :, :]
-        attn_weights = attn_weights + mask
-    return attn_weights
-
-
 # These two versions should be preserved in this order
 @scaled_dot_product_attention.override(
     AnyTensor,
@@ -95,7 +34,7 @@ def create_mask(a, attn_weights, is_causal):
     impl_name="decomposed",
 )
 def scaled_dot_product_attention_decomposed(
-    q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
+    q, k, v, a, sink, is_causal, scale, softcap, impl
 ):
 
     if scale is None:
@@ -105,42 +44,39 @@ def scaled_dot_product_attention_decomposed(
     k = unbox_tensor(k)
     v = unbox_tensor(v)
     bs, n_heads, n_tokens, head_dim = q.shape
-    kv_size = k.shape[-2]
-    attn_weights = torch.matmul(q, k.transpose(-2, -1))
-    attn_weights = attn_weights * scale
+    attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
 
     if softcap is not None:
         attn_weights = softcap * torch.tanh(attn_weights / softcap)
 
-    use_sink_path = (sink is not None) or (sliding_window is not None)
-    if not use_sink_path:
-        # standard causal/masked attention
-        attn_weights = create_mask(a, attn_weights, is_causal)
-        attn_weights = ops.softmax(attn_weights, dim=-1)
-        out = torch.matmul(unbox_tensor(attn_weights), v)
-        return out.to(q.dtype)
-
-    # sliding-window (and optional sink) path
-    attn_weights = create_mask_sliding_window(
-        a,
-        attn_weights=attn_weights,
-        n_tokens=n_tokens,
-        kv_size=kv_size,
-        sliding_window=sliding_window,
-        dtype=q.dtype,
-        device=attn_weights.device,
-    )
+    if a is not None:
+        attn_weights = attn_weights + a
+    elif is_causal:
+        seq_len = attn_weights.shape[-1]
+        causal = torch.triu(
+            torch.full(
+                (seq_len, seq_len),
+                float("-inf"),
+                device=attn_weights.device,
+                dtype=attn_weights.dtype,
+            ),
+            diagonal=1,
+        )
+        attn_weights = attn_weights + causal
 
     if sink is not None:
-        sink = sink.to(q.dtype)
-        sink = sink.reshape(1, -1, 1, 1).expand(bs, -1, n_tokens, 1)
+        sink = sink.to(q.dtype).to(q.device)
+        # Sink should match [bs, n_heads, n_tokens, sink_size] to concat with attn_weights [bs, n_heads, n_tokens, kv_size]
+        sink = sink.reshape(1, n_heads, 1, 1).expand(bs, n_heads, n_tokens, 1)
+
         attn_weights = ops.cat([attn_weights, sink], dim=-1)
-        attn_weights = ops.softmax(attn_weights, dim=-1)[..., :-1]
+        attn_weights = ops.softmax(attn_weights, dim=-1)[..., : -sink.shape[-1]]
     else:
         attn_weights = ops.softmax(attn_weights, dim=-1)
 
     attn_weights = unbox_tensor(attn_weights)
     out = torch.matmul(attn_weights, v)
+
     return out.to(q.dtype)
 
 
@@ -162,9 +98,9 @@ def _extract_linear_scale(t):
     impl_name="sharktank",
 )
 def scaled_dot_product_flash_attention_sharktank(
-    q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
+    q, k, v, a, sink, is_causal, scale, softcap, impl
 ):
-    if sliding_window is not None or sink is not None:
+    if sink is not None:
         return NotImplemented
     if softcap:
         return NotImplemented
@@ -199,7 +135,6 @@ def scaled_dot_product_flash_attention_sharktank(
         v = v.to(torch.float16)
 
     if a is not None:
-        a = unbox_tensor(a)
         if a.dim() == 4:
             # TODO: Multiple tests are relying on inconsistent behavior of the attention mask.
             # Attention mask ranks should be consistent.
@@ -217,9 +152,9 @@
     AnyTensor, AnyTensor, AnyTensor, AnyType, impl_name="torch"
 )
 def scaled_dot_product_attention_torch(
-    q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
+    q, k, v, a, sink, is_causal, scale, softcap, impl
 ):
-    if sliding_window is not None or sink is not None:
+    if sink is not None:
         return NotImplemented
     if softcap is not None:
         return NotImplemented
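The decomposed implementation above now handles the sink by widening the softmax rather than through a dedicated masking path. A minimal standalone sketch of that trick for anyone reviewing the numerics (illustrative names only, no sharktank imports assumed):

```python
import torch

def sdpa_with_sink_sketch(q, k, v, sink, mask=None):
    # Append one sink logit per head, softmax over [keys + sink], then drop the
    # sink column so its probability mass is discarded rather than attended to.
    bs, n_heads, n_tokens, head_dim = q.shape
    w = (q @ k.transpose(-2, -1)) / head_dim**0.5
    if mask is not None:
        w = w + mask
    sink_col = sink.to(w.dtype).reshape(1, n_heads, 1, 1).expand(bs, n_heads, n_tokens, 1)
    w = torch.softmax(torch.cat([w, sink_col], dim=-1), dim=-1)[..., :-1]
    return w @ v  # rows of w now sum to < 1; the remainder went to the sink

torch.manual_seed(0)
q = k = v = torch.randn(1, 2, 4, 8)
causal = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
print(sdpa_with_sink_sketch(q, k, v, sink=torch.tensor([0.5, -1.0]), mask=causal).shape)
```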
diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py
index 8576ab5e20..2ba2c7a625 100644
--- a/sharktank/sharktank/ops/sharded_impls.py
+++ b/sharktank/sharktank/ops/sharded_impls.py
@@ -995,9 +995,9 @@ def matmul_split(
     impl_name="sharded",
 )
 def scaled_dot_product_attention_sharded(
-    q, k, v, a, sink, sliding_window, is_causal, scale, softcap, impl
+    q, k, v, a, sink, is_causal, scale, softcap, impl
 ) -> SplitPrimitiveTensor:
-    if sink is not None or sliding_window is not None:
+    if sink is not None:
         return NotImplemented
     if q.shard_count != k.shard_count or q.shard_count != v.shard_count:
         raise ValueError("Incompatible number of shards for qkv")
diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py
index 59ea9266ff..ba1c0c41aa 100644
--- a/sharktank/sharktank/ops/signatures.py
+++ b/sharktank/sharktank/ops/signatures.py
@@ -896,7 +896,6 @@ def scaled_dot_product_attention(
     v: AnyTensor,
     a: Optional[AnyTensor],
     sink: Optional[AnyTensor] = None,
-    sliding_window: Optional[AnyTensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     softcap: Optional[float] = None,
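With `sliding_window` gone from the signature, callers of `ops.scaled_dot_product_attention` express windowing purely through the additive mask `a`. A hedged sketch of the resulting calling convention, using torch's built-in SDPA as a stand-in for the sharktank op (the `scale=` kwarg assumes torch >= 2.1):

```python
import torch
import torch.nn.functional as F

def call_like_signatures_py(q, k, v, a=None, sink=None, is_causal=False, scale=None):
    # Stand-in for ops.scaled_dot_product_attention covering only the plain path.
    assert sink is None, "sink handling lives in the decomposed impl, not shown here"
    return F.scaled_dot_product_attention(q, k, v, attn_mask=a, is_causal=is_causal, scale=scale)

q = k = v = torch.randn(2, 8, 16, 64)
a = torch.zeros(16, 16)
a.masked_fill_(torch.triu(torch.ones(16, 16, dtype=torch.bool), 1), float("-inf"))   # future keys
a.masked_fill_(torch.tril(torch.ones(16, 16, dtype=torch.bool), -4), float("-inf"))  # older than window=4
out = call_like_signatures_py(q, k, v, a=a)  # no sliding_window argument anymore
```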
diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py
index bae9f64983..c44079f156 100644
--- a/sharktank/tests/layers/paged_llama_attention_block_test.py
+++ b/sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -168,8 +168,9 @@ def forward(self, h, seq_block_ids, cache_state: list[torch.Tensor]):
 
 # Shapes: (bs, seq_len, n_heads, kv_heads, head_dim)
 _SHAPE_CASES = [
-    (1, 64, 8, 1, 64),
-    (2, 128, 8, 2, 64),
+    (1, 64, 8, 1, 64),  # 4 full blocks @ stride 16
+    (2, 128, 8, 2, 64),  # 8 full blocks @ stride 16
+    (1, 16, 8, 1, 64),  # 1 full block exactly matching stride
 ]
 _CONTEXT_LEN = [2048]
 _DT_CASES = [
@@ -185,6 +186,7 @@
         19,
         0.25,
     ),  # sink path enabled TODO: https://github.com/nod-ai/shark-ai/issues/2156
+    (4, 0.5),
 ]
 
 
@@ -208,9 +210,8 @@ def _reference_sink_batched(q, k, v, sink, mode, sliding_window):
     sink_ = sink.reshape(n_kv_heads, q_mul, 1, 1).expand(-1, -1, n_tokens, -1)
     sink_ = sink_.unsqueeze(0).expand(bs, -1, -1, -1, -1)
-
     mask = torch.triu(q_.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
-    if sliding_window > 0:
+    if sliding_window is not None and sliding_window > 0:
         mask += torch.tril(
             mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
         )
@@ -218,8 +219,9 @@ def _reference_sink_batched(q, k, v, sink, mode, sliding_window):
 
     qk_ = torch.einsum("bqhmd,bkhmd->bhmqk", q_, k_) * sm_scale
     qk_ = qk_ + mask[None, None, :, :]
+    # Concatenate the sink column, apply softmax, then drop the sink logits
     qk_ = torch.cat([qk_, sink_], dim=-1)
-    w = torch.softmax(qk_, dim=-1)[..., :-1]  # drop sink column
+    w = torch.softmax(qk_, dim=-1)[..., :-1]
     attn = torch.einsum("bhmqk,bkhmd->bqhmd", w, v_)
     out = attn.reshape(bs, n_tokens, n_kv_heads * q_mul, -1).permute(0, 2, 1, 3)
@@ -258,7 +260,7 @@ def _reference_base(q, k, v, mode):
 
 def _make_reference_for_case(q, k, v, mode, sliding_window, sink):
     # Choose the correct reference implementation for this configuration.
-    if (sliding_window is not None) and (sink is not None):
+    if (sliding_window is not None) or (sink is not None):
        return _reference_sink_batched(q, k, v, sink, mode, sliding_window)
     else:
         return _reference_base(q, k, v, mode)
@@ -434,11 +436,22 @@ def test_forward_sink_eager(
     ):
         torch.manual_seed(1234)
 
+        # The cache write() path unflattens the flattened sequence dimension
+        # into (block_seq_len, block_seq_stride), so the stride must divide the
+        # sequence length exactly. With stride=16 and seqlen=8 that used to fail:
+        #   unflatten: Provided sizes [1, 16] don't multiply up to size 8
+        # Every _SHAPE_CASES sequence length is a multiple of 16, so a fixed
+        # stride of 16 always forms whole blocks without padding.
+        block_seq_stride = 16
+
         kv_cache = build_cache(
             transformer_block_count=1,
             attn_head_count=kv_heads,
             attn_head_dim=head_dim,
-            block_seq_stride=16,
+            block_seq_stride=block_seq_stride,
             cache_dtype=dtype,
         )
         pa = PagedGQAttention(
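The stride comment in `test_forward_sink_eager` boils down to an unflatten constraint. A tiny illustration, independent of the sharktank cache code, that reproduces the quoted error:

```python
import torch

# write() recovers pages with unflatten(seq -> (blocks, stride)), which requires
# block_seq_stride to divide the flattened sequence length exactly.
x = torch.randn(1, 64, 8)
print(x.unflatten(1, (64 // 16, 16)).shape)  # torch.Size([1, 4, 16, 8])

try:
    torch.randn(1, 8, 8).unflatten(1, (1, 16))  # stride 16 but seq_len 8
except RuntimeError as err:
    print(err)  # "unflatten: Provided sizes [1, 16] don't multiply up to size 8 ..."
```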
diff --git a/sharktank/tests/ops/test_attention_ops.py b/sharktank/tests/ops/test_attention_ops.py
index 3d8fb770ca..63444c6fd4 100644
--- a/sharktank/tests/ops/test_attention_ops.py
+++ b/sharktank/tests/ops/test_attention_ops.py
@@ -9,10 +9,11 @@
 import unittest
 import torch
 from parameterized import parameterized
-
 from sharktank import ops
 from sharktank.ops import attention_impls
 from sharktank.utils.testing import OpComparisonTestBase, OpTestConfig
+from sharktank.layers.paged_attention import PagedMHAttention
+import math
 
 
 class TestScaledDotProductAttention(OpComparisonTestBase):
@@ -21,18 +22,16 @@ class TestScaledDotProductAttention(OpComparisonTestBase):
     @parameterized.expand(
         [
             # No causal, no mask
-            (2, 8, 128, 64, torch.float16, False, False, None, None, None, None),
-            (2, 8, 128, 64, torch.float32, False, False, None, None, None, None),
+            (2, 8, 128, 64, torch.float16, False, False, None, None),
+            (2, 8, 128, 64, torch.float32, False, False, None, None),
             # Test causal attention
-            (2, 8, 128, 64, torch.float16, True, False, None, None, None, None),
-            (2, 8, 128, 64, torch.float16, True, False, 0.125, None, None, None),
-            # Test explicit masking
-            (2, 8, 128, 64, torch.float16, False, True, None, None, None, None),
-            (2, 8, 256, 64, torch.float32, False, True, None, None, None, None),
-            # Test softcap
-            (1, 4, 64, 32, torch.float32, False, False, None, 50.0, None, None),
-            # Test Sink and Sliding Window
-            (2, 8, 128, 64, torch.bfloat16, True, False, None, None, 0.25, 19),
+            (2, 8, 128, 64, torch.float16, True, False, None, None),
+            (2, 8, 128, 64, torch.float16, True, False, 0.125, None),
+            # Test explicit masking (full causal mask passed explicitly)
+            (2, 8, 128, 64, torch.float16, False, True, None, None),
+            (2, 8, 256, 64, torch.float32, False, True, None, None),
+            # Test softcap (no causal, no mask)
+            (1, 4, 64, 32, torch.float32, False, False, None, 50.0),
         ]
     )
     def test_attention_variants(
@@ -46,8 +45,6 @@ def test_attention_variants(
         self,
         batch,
         heads,
         seq_len,
         head_dim,
         dtype,
         is_causal,
         has_mask,
         scale,
         softcap,
-        sink_scale,
-        sliding_window,
     ):
         """Test attention with various configurations."""
         torch.manual_seed(42)
 
         q = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype)
         k = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype)
         v = torch.randn(batch, heads, seq_len, head_dim, dtype=dtype)
 
         if has_mask:
-            # Create a simple attention mask with shape [1, 1, seq_len, seq_len]
-            # This broadcasts across all batches and heads
+            # Explicit full causal mask (no sliding window) for regression against torch impls
             mask = torch.triu(torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1)
-            mask = mask.unsqueeze(0).unsqueeze(0)
-            a = mask.to(dtype)
+            a = mask.unsqueeze(0).unsqueeze(0).to(dtype)
+
         else:
             a = None
 
-        unsupported = (
-            (softcap is not None)
-            or (sink_scale is not None)
-            or (sliding_window is not None)
-        )
+        unsupported = softcap is not None
         fail_on_not_implemented = not unsupported
 
-        sink = (
-            torch.full((1, heads), sink_scale, dtype=q.dtype)
-            if sink_scale is not None
-            else None
-        )
-
         if dtype in (torch.float16, torch.bfloat16):
             atol, rtol = 3e-2, 3e-2
         else:
@@ -86,14 +72,13 @@
         config = OpTestConfig(
             op=ops.scaled_dot_product_attention,
             reference_impl=attention_impls.scaled_dot_product_attention_decomposed,
             test_impls="all",
-            args=[q, k, v, a],
+            # Provide placeholder sink=None as positional argument expected by reference impl
+            args=[q, k, v, a, None],
             kwargs={
                 "is_causal": is_causal,
                 "scale": scale,
                 "softcap": softcap,
                 "impl": None,
-                "sink": sink,
-                "sliding_window": sliding_window,
             },
             atol=atol,
             rtol=rtol,
@@ -102,5 +87,223 @@
         )
         self.compare_implementations(config)
 
+
+class TestSlidingWindowMaskGolden(unittest.TestCase):
+    def test_causal_mask(self):
+
+        mask = PagedMHAttention.build_mask(
+            PagedMHAttention,
+            None,
+            None,
+            kv_size=4,
+            n_tokens=4,
+            dtype=torch.float32,
+            device=torch.device("cpu"),
+        )
+
+        # Each query can only see keys up to its own position
+        expected_finite_keys = [
+            [0],  # query 0 sees key 0
+            [0, 1],  # query 1 sees keys 0,1
+            [0, 1, 2],  # query 2 sees keys 0,1,2
+            [0, 1, 2, 3],  # query 3 sees keys 0,1,2,3
+        ]
+
+        self._check_mask_pattern(mask, expected_finite_keys)
+
+    def test_sliding_window_mask(self):
+        """Test sliding window masking where kv_size > sliding_window."""
+        mask = PagedMHAttention.build_mask(
+            PagedMHAttention,
+            None,
+            sliding_window=2,
+            kv_size=5,
+            n_tokens=5,
+            dtype=torch.float32,
+            device=torch.device("cpu"),
+        )
+
+        # Each query sees at most sliding_window=2 keys, ending at its own position.
+        # kv_size=5 > sliding_window=2, so the window is what limits visibility here.
+        expected_finite_keys = [
+            [0],  # query 0: window starts at 0
+            [0, 1],  # query 1: window covers 0,1
+            [1, 2],  # query 2: window covers 1,2
+            [2, 3],  # query 3: window covers 2,3
+            [3, 4],  # query 4: window covers 3,4
+        ]
+
+        self._check_mask_pattern(mask, expected_finite_keys)
+
+    def test_sliding_window_larger_than_kv(self):
+
+        mask = PagedMHAttention.build_mask(
+            PagedMHAttention,
+            None,
+            sliding_window=6,
+            kv_size=4,
+            n_tokens=4,
+            dtype=torch.float32,
+            device=torch.device("cpu"),
+        )
+
+        # When sliding_window > kv_size the window never excludes any key, so
+        # this behaves exactly like plain causal masking.
+        expected_finite_keys = [
+            [0],  # query 0 sees key 0
+            [0, 1],  # query 1 sees keys 0,1
+            [0, 1, 2],  # query 2 sees keys 0,1,2
+            [0, 1, 2, 3],  # query 3 sees keys 0,1,2,3
+        ]
+
+        self._check_mask_pattern(mask, expected_finite_keys)
+
+    def _check_mask_pattern(self, mask, expected_finite_keys):
+        for qi, expected_keys in enumerate(expected_finite_keys):
+            row = mask[qi]
+            finite_keys = (row > float("-inf")).nonzero(as_tuple=True)[0].tolist()
+            self.assertEqual(
+                finite_keys,
+                expected_keys,
+                f"Query {qi}: expected keys {expected_keys}, got {finite_keys}",
+            )
+
+
+class TestSinkAttentionGolden(unittest.TestCase):
+    def test_sink_vs_no_sink_difference(self):
+        torch.manual_seed(42)
+
+        # 4D tensors: (batch=1, n_heads=1, n_tokens=2, head_dim=2)
+        q = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
+        k = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
+        v = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
+
+        # Regular causal attention
+        regular_result = ops.scaled_dot_product_attention(q, k, v, None, is_causal=True)
+
+        # Sink attention
+        sink = torch.tensor([0.5])
+        sink_result = ops.scaled_dot_product_attention(
+            q, k, v, None, is_causal=True, sink=sink
+        )
+
+        # Results should be different when sink is applied
+        self.assertFalse(torch.allclose(regular_result, sink_result, atol=1e-6))
+
+        # Both should have same shape
+        self.assertEqual(regular_result.shape, sink_result.shape)
+
+    def test_sink_softmax_behavior(self):
+        """
+        SINK ATTENTION DIFFERENCE:
+        - Regular: softmax(attn_weights) -> each row sums to 1.0
+        - Sink: softmax(cat([attn_weights, sink_value])) then slice off sink portion
+          -> each row sums to LESS than 1.0 (sink absorbed probability mass)
+
+        VISUAL EXAMPLE with 2x2 attention matrix:
+        Regular attention weights after softmax:
+            Query 0: [1.0, 0.0]   <- causal mask hides key 1
+            Query 1: [0.33, 0.67] <- can see both keys, sums to 1.0
+
+        Sink attention (with sink=0.5):
+            Step 1: Concat sink -> [[weights, 0.5], [weights, 0.5]]
+            Step 2: Softmax entire matrix -> normalizes including sink
+            Step 3: Slice off sink column -> weights now sum < 1.0
+            Query 0: [0.55, 0.0]  <- less than 1.0! sink absorbed 0.45
+            Query 1: [0.21, 0.43] <- less than 1.0! sink absorbed ~0.35
+
+        Simple explicit calculation (Query 0 only):
+            Scaled logits (key0, key1(masked), sink) = [0.7071, -inf, 0.5]
+            exp = [exp(0.7071)=2.0281, exp(-inf)=0, exp(0.5)=1.6487]
+            Denominator = 2.0281 + 0 + 1.6487 = 3.6768
+            Softmax (before slicing) = [2.0281/3.6768=0.5516, 0, 1.6487/3.6768=0.4484]
+            After slicing off sink column -> retained = [0.5516, 0.0]; sink absorbed 0.4484 (~0.45)
+        """
+
+        # 4D tensors: (batch=1, n_heads=1, n_tokens=2, head_dim=2)
+        q = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32)
+        k = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32)
+        v = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32)
+        sink = torch.tensor([0.5], dtype=torch.float32)
+
+        # Extract tensor dimensions
+        bs, n_heads, n_tokens, head_dim = q.shape
+
+        # Manual computation to verify sink softmax behavior
+        scale = 1.0 / math.sqrt(head_dim)
+        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
+        causal_mask = torch.triu(
+            torch.full((n_tokens, n_tokens), float("-inf")), diagonal=1
+        )
+        attn_weights = attn_weights + causal_mask
+
+        # Now test the key difference in sink attention:
+        # Regular softmax (without sink)
+        regular_weights = torch.softmax(attn_weights, dim=-1)
+
+        # Sink weight
+        sink_expanded = sink.reshape(1, 1, 1, 1).expand(1, 1, 2, 1)
+        attn_with_sink = torch.cat([attn_weights, sink_expanded], dim=-1)
+        sink_weights_full = torch.softmax(attn_with_sink, dim=-1)
+        sink_weights = sink_weights_full[..., :-1]  # slice off sink portion
+
+        # Verify shapes
+        self.assertEqual(regular_weights.shape, (1, 1, 2, 2))
+        self.assertEqual(sink_weights.shape, (1, 1, 2, 2))
+
+        # The sink softmax should produce different attention weights
+        self.assertFalse(torch.allclose(regular_weights, sink_weights, atol=1e-6))
+
+        # Regular: With causal mask, query 0 can only see key 0, so gets weight 1.0
+        # Sink: The sink value competes in softmax, reducing the weight for key 0
+        # NOTE on tensor shape: attention weights are 4D = (batch, head, query_index, key_index).
+        # In this tiny example batch=head=1, so we index as [0, 0, q, k] to pick the scalar for (query q -> key k).
+        query_0_key_0_regular = regular_weights[0, 0, 0, 0].item()  # Query 0 -> Key 0
+        query_0_key_1_regular = regular_weights[
+            0, 0, 0, 1
+        ].item()  # Query 0 -> Key 1 (masked)
+
+        query_0_key_0_sink = sink_weights[
+            0, 0, 0, 0
+        ].item()  # Query 0 -> Key 0 with sink
+        query_0_key_1_sink = sink_weights[
+            0, 0, 0, 1
+        ].item()  # Query 0 -> Key 1 (still masked)
+
+        # Regular: query 0 gives full attention (1.0) to key 0, zero to masked key 1
+        self.assertAlmostEqual(query_0_key_0_regular, 1.0, places=6)
+        self.assertAlmostEqual(query_0_key_1_regular, 0.0, places=6)
+
+        # Sink: query 0 gives LESS attention to key 0 (sink absorbed some probability)
+        self.assertLess(
+            query_0_key_0_sink, 1.0
+        )  # Less than 1.0 due to sink competition
+        self.assertAlmostEqual(
+            query_0_key_1_sink, 0.0, places=6
+        )  # Still masked (causal)
+
+        # Check sdpa with explicit causal mask
+        mask = torch.triu(torch.full((n_tokens, n_tokens), float("-inf")), diagonal=1)
+        mask = mask.unsqueeze(0).unsqueeze(0)
+
+        production_with_mask = ops.scaled_dot_product_attention(
+            q, k, v, mask, is_causal=False, sink=sink
+        )
+        expected_result = torch.matmul(sink_weights, v)
+        self.assertTrue(
+            torch.allclose(production_with_mask, expected_result, atol=1e-5)
+        )
+
+
+def test__invoke_golden_mask_cases():
+    """Bridge test so pytest can invoke golden mask checks explicitly."""
+    g = TestSlidingWindowMaskGolden()
+    g.test_causal_mask()
+    g.test_sliding_window_mask()
+    g.test_sliding_window_larger_than_kv()
+
+    s = TestSinkAttentionGolden()
+    s.test_sink_vs_no_sink_difference()
+    s.test_sink_softmax_behavior()
+
+
 if __name__ == "__main__":
     unittest.main()
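For reviewers cross-checking `TestSlidingWindowMaskGolden`: the expected finite-key sets follow directly from the window rule and can be regenerated without torch at all. A small sketch for the kv_size = n_tokens = 5, sliding_window = 2 case:

```python
# Query i may see keys k with  max(0, i - (sliding_window - 1)) <= k <= i.
sliding_window, n = 2, 5
expected = [list(range(max(0, i - (sliding_window - 1)), i + 1)) for i in range(n)]
print(expected)  # [[0], [0, 1], [1, 2], [2, 3], [3, 4]], matching the golden data
```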
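Likewise, the worked numbers in the `test_sink_softmax_behavior` docstring can be reproduced in a couple of lines, under the same assumptions the docstring states (q0.k0 scaled by 1/sqrt(2), key 1 masked to -inf, sink logit 0.5):

```python
import math

logit_k0, logit_sink = 1.0 / math.sqrt(2), 0.5   # key 1 contributes exp(-inf) = 0
den = math.exp(logit_k0) + math.exp(logit_sink)
print(round(math.exp(logit_k0) / den, 4))    # 0.5516, weight retained by key 0
print(round(math.exp(logit_sink) / den, 4))  # 0.4484, probability mass absorbed by the sink
```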