Merged
Commits
20 commits
4a3ef98
Revert mask functions to utils from ops
archana-ramalingam Oct 14, 2025
fa26a60
Update tests
archana-ramalingam Oct 14, 2025
3f3e839
Merge branch 'main' into refactor-mask-utils
archana-ramalingam Oct 14, 2025
e724a47
Fix tests
archana-ramalingam Oct 14, 2025
8bc203a
Merge branch 'refactor-mask-utils' of https://github.com/nod-ai/shark…
archana-ramalingam Oct 14, 2025
263be91
Organize the mask fns
archana-ramalingam Oct 14, 2025
1f83657
Merge branch 'main' into refactor-mask-utils
archana-ramalingam Oct 14, 2025
e1a1400
Remove TODOs addressed in #2293 & #2430
archana-ramalingam Oct 14, 2025
9de1c1a
Merge branch 'refactor-mask-utils' of https://github.com/nod-ai/shark…
archana-ramalingam Oct 14, 2025
ea617b7
Make mask functions pipeline parallelism compatible
archana-ramalingam Oct 15, 2025
9311256
Allow trivially_replicable to be used outside of op overrides only
archana-ramalingam Oct 16, 2025
bbecaa6
Add issue link to TODOs for chunked attention mask
archana-ramalingam Oct 16, 2025
e42a590
Merge branch 'main' into refactor-mask-utils
archana-ramalingam Oct 16, 2025
351e8c8
Merge branch 'main' into refactor-mask-utils
Alex-Vasile Oct 16, 2025
f5e331f
Merge remote-tracking branch 'origin/main' into refactor-mask-utils
Alex-Vasile Oct 20, 2025
3759635
Remove redundant to(device)
Alex-Vasile Oct 20, 2025
92acce1
Remove refactors
archana-ramalingam Oct 20, 2025
bed907f
Merge branch 'refactor-mask-utils' of https://github.com/nod-ai/shark…
archana-ramalingam Oct 20, 2025
26641ae
Apply suggestion from @Alex-Vasile
Alex-Vasile Oct 20, 2025
64da616
Apply suggestion from @Alex-Vasile
Alex-Vasile Oct 20, 2025
11 changes: 6 additions & 5 deletions sharktank/sharktank/layers/paged_attention.py
@@ -29,6 +29,7 @@
TensorScaledLayout,
)
from sharktank import ops, kernels
from sharktank.utils.attention import *
from sharktank.kernels.mlir_kernel import *
from sharktank.types.tensors import AnyTensor, QuantizedTensor, ReplicatedTensor
from sharktank.types.quantizers import unpack_to_raw_tensor, pack_raw_tensor
@@ -908,8 +909,8 @@ def paged_attention(
if is_prefill:
source_len = seq_block_ids.shape[1] * self.block_seq_stride
target_len = q.shape[1]
input_mask = ops.input_mask(seq_lens, source_len)
mask = ops.attention_mask(
input_mask = create_input_mask(seq_lens, source_len)
mask = create_attention_mask(
input_mask,
start_positions,
source_len=source_len,
@@ -918,13 +919,13 @@
)
use_chunked_attention_mask = self.attention_chunk_size is not None
if use_chunked_attention_mask and self.use_rope:
mask = ops.chunked_attention_mask(mask, self.attention_chunk_size)
mask = create_chunked_attention_mask(mask, self.attention_chunk_size)
else:
input_mask = ops.input_mask(
input_mask = create_input_mask(
seq_lens,
seq_block_ids.shape[1] * self.block_seq_stride,
)
mask = ops.attention_mask_for_decode(
mask = create_attention_mask_for_decode(
input_mask, attention_dtype=self.activation_dtype
)
if self.attention_chunk_size is not None:
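Note on the new call sites above: `create_input_mask`, `create_attention_mask`, and `create_chunked_attention_mask` are now pulled in via the wildcard import from `sharktank.utils.attention` rather than going through the `ops` dispatch layer. The sketch below is a minimal, self-contained reconstruction of the prefill-mask semantics implied by the removed `ops` defaults (see the `default_impls.py` hunk below); the relocated helpers may differ in detail, e.g. float8 dtype handling and `start_positions` offsets.

```python
# Minimal sketch only: reconstructs the prefill mask semantics of the removed
# ops defaults in plain torch. The helpers in sharktank.utils.attention are
# assumed to be equivalent but may differ.
import torch

def make_input_mask(seq_lens: torch.Tensor, batch_seqlen: int) -> torch.Tensor:
    # Boolean [bs, batch_seqlen] mask, True at padded (masked) positions.
    positions = torch.arange(batch_seqlen, device=seq_lens.device)
    return positions[None, :] >= seq_lens[:, None]

def make_causal_attention_mask(
    input_mask: torch.Tensor, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    # Combine a causal mask with the padding mask into an additive
    # [bs, 1, sl, sl] mask: masked positions get the most negative value, else 0.
    sl = input_mask.shape[1]
    causal = torch.triu(
        torch.ones(sl, sl, dtype=torch.bool, device=input_mask.device), diagonal=1
    )
    boolean = torch.logical_or(causal, input_mask[:, None, None, :])
    numeric = torch.zeros(boolean.shape, dtype=dtype, device=input_mask.device)
    return numeric.masked_fill(boolean, torch.finfo(dtype).min)

seq_lens = torch.tensor([3, 5])
mask = make_causal_attention_mask(make_input_mask(seq_lens, batch_seqlen=5))
print(mask.shape)  # torch.Size([2, 1, 5, 5])
```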
108 changes: 0 additions & 108 deletions sharktank/sharktank/ops/default_impls.py
@@ -32,11 +32,6 @@

from sharktank.kernels.topk import iree_topk
from sharktank.ops.shape import normalize_negative_dim
from sharktank.utils.attention import (
create_boolean_chunked_attention_mask,
create_causal_context_mask,
max_negative_value,
)

from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType, AnyType
from .quantized_impls import quantized_tensor_layout_of_type
@@ -116,55 +111,6 @@ def _split_argmax(input_tensor, dim, keepdim: bool = False, chunk_size: int = 12
return final_index


def attention_mask_default(
boolean_input_mask: torch.Tensor,
start_positions: torch.Tensor | None,
*,
source_len: int,
target_len: int,
attention_dtype: torch.dtype,
) -> torch.Tensor:
device = boolean_input_mask.device

# Combine the causal context mask and input mask.
dtype = (
torch.float32 if attention_dtype == torch.float8_e4m3fnuz else attention_dtype
)
causal_mask = create_causal_context_mask(
src_len=source_len,
target_len=target_len,
start_positions=start_positions,
device=device,
)
boolean_mask = torch.logical_or(causal_mask, boolean_input_mask[:, None, None, :])
numeric_mask = torch.where(boolean_mask, max_negative_value(dtype, device), 0).to(
dtype
)
return numeric_mask.to(device)


attention_mask.override(Tensor, Tensor)(attention_mask_default)
attention_mask.override(Tensor)(attention_mask_default)


@attention_mask_for_decode.override(Tensor)
def attention_mask_for_decode_default(
boolean_input_mask: AnyTensor,
*,
attention_dtype: torch.dtype,
) -> torch.Tensor:
boolean_input_mask = unbox_tensor(boolean_input_mask)

device = boolean_input_mask.device
dtype = (
torch.float32 if attention_dtype == torch.float8_e4m3fnuz else attention_dtype
)
numeric_mask = torch.where(
boolean_input_mask, max_negative_value(dtype, device), 0
).to(dtype)
return numeric_mask.unsqueeze(1).unsqueeze(1).to(device)


@cat.override(AllOfType(Tensor, PrimitiveTensor))
def cat_default(tensors: Sequence[Tensor | PrimitiveTensor], dim: int):
result = torch.cat([unbox_tensor(t) for t in tensors], dim)
@@ -180,50 +126,6 @@ def chunk_default(
return torch.chunk(unbox_tensor(tensor), chunks, dim)


@chunked_attention_mask.override(Tensor)
def chunked_attention_mask_default(
attention_mask: torch.Tensor, attention_chunk_size: int
) -> torch.Tensor:
assert attention_mask.dim() == 4, "Attention mask must be 4-dimensional"
assert (
attention_mask.shape[1] == 1
), f"Attention mask shape[1] ({attention_mask.shape[1]}) must be 1"
s2 = attention_mask.shape[2]
s3 = attention_mask.shape[3]
assert (
s2 == s3
), f"Attention mask must be square in the last two dimensions ({s2} != {s3})"

sl = attention_mask.shape[2]
assert (
sl % attention_chunk_size == 0
), f"Sequence length ({sl}) must be divisible by attention chunk size ({attention_chunk_size})"

attention_mask = unbox_tensor(attention_mask)

device = attention_mask.device
batch_seq_len = attention_mask.shape[2]
# TODO: handle decode step
start_index = 0
end_index = batch_seq_len
chunked_boolean_attention_mask = create_boolean_chunked_attention_mask(
attention_chunk_size=attention_chunk_size,
# TODO: handle decode step
start_index=start_index,
end_index=end_index,
device=device,
)

return torch.where(
chunked_boolean_attention_mask,
attention_mask,
torch.tensor(
max_negative_value(attention_mask.dtype, device=device),
dtype=attention_mask.dtype,
),
)


# conv2d


@@ -657,16 +559,6 @@ def index_select_default(
return torch.index_select(unbox_tensor(tensor), dim, unbox_tensor(index))


@input_mask.override(Tensor)
def input_mask_default(seq_lens: torch.Tensor, batch_seqlen: int) -> torch.Tensor:
seq_lens = unbox_tensor(seq_lens)

range_vector = torch.arange(0, batch_seqlen, 1, device=seq_lens.device)
matrix = seq_lens.unsqueeze(dim=-1)
mask = range_vector >= matrix
return mask


@interpolate.override(Tensor)
def interpolate_default(
input: Tensor,
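For reference, the decode path that used to live in `attention_mask_for_decode_default` above reduces to broadcasting the boolean padding mask to `[bs, 1, 1, sl]` and filling masked positions with the most negative representable value. A tiny standalone illustration (hypothetical helper name, plain torch; not the actual relocated implementation):

```python
# Hypothetical illustration of the decode-mask semantics removed above.
import torch

def decode_mask(boolean_input_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # [bs, sl] boolean padding mask -> [bs, 1, 1, sl] additive mask.
    numeric = torch.zeros(boolean_input_mask.shape, dtype=dtype)
    numeric = numeric.masked_fill(boolean_input_mask, torch.finfo(dtype).min)
    return numeric[:, None, None, :]

seq_lens = torch.tensor([2, 4])
positions = torch.arange(5)
input_mask = positions[None, :] >= seq_lens[:, None]  # True = padded
print(decode_mask(input_mask, torch.float32)[0, 0, 0])
# tensor([ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38])
```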
95 changes: 0 additions & 95 deletions sharktank/sharktank/ops/signatures.py
@@ -37,12 +37,9 @@
"all_reduce",
"arange",
"argmax",
"attention_mask",
"attention_mask_for_decode",
"barrier_on_logical_device",
"cat",
"chunk",
"chunked_attention_mask",
"conv2d",
"conv3d",
"conv1d",
@@ -66,7 +63,6 @@
"index_copy_",
"index_put_",
"index_select",
"input_mask",
"interpolate",
"linear",
"masked_fill",
@@ -189,63 +185,6 @@ def argmax(
...


@overridable
def attention_mask(
boolean_input_mask: AnyTensor,
start_positions: AnyTensor | None = None,
*,
source_len: int,
target_len: int,
attention_dtype: torch.dtype,
) -> torch.Tensor:
"""
Generates a causal attention mask of [bs, 1, sl, sl] of activation dtype.

All masked positions are -inf and unmasked are 0.0.

The causal context mask will either be generated or use the initialization time buffer.
Since this is a bool tensor of context_length^2, different deployment
scenarios can benefit from managing this in different ways.
"""
...


@attention_mask.trampoline
def _attention_mask_trampoline(
d: SignatureDispatcher,
boolean_input_mask: AnyTensor,
start_positions: AnyTensor | None = None,
*,
source_len: int,
target_len: int,
attention_dtype: torch.dtype,
):
tensors = [boolean_input_mask]
if start_positions is not None:
tensors.append(start_positions)
for override in d.find_overrides(tensors):
result = override(
boolean_input_mask,
start_positions,
source_len=source_len,
target_len=target_len,
attention_dtype=attention_dtype,
)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)


@overridable(dispatch_args=(0,))
def attention_mask_for_decode(
boolean_input_mask: AnyTensor,
*,
attention_dtype: torch.dtype,
) -> torch.Tensor:
...


@overridable
def cat(tensors: Tuple[AnyTensor, ...] | List[AnyTensor], dim: int = 0) -> AnyTensor:
...
@@ -269,26 +208,6 @@ def chunk(tensor: AnyTensor, chunks: int, dim: int = 0) -> tuple[AnyTensor, ...]
...


@overridable(dispatch_args=(0,))
def chunked_attention_mask(
attention_mask: torch.Tensor, attention_chunk_size: int
) -> torch.Tensor:
"""
Apply a chunked attention mask onto a mask.

This is a convenience function that combines the creation of the boolean
chunked attention mask and its application to the provided attention mask.

Args:
attention_mask: The original attention mask of shape [bs, 1, sl, sl].
attention_chunk_size: The size of each attention chunk.

Returns:
A new attention mask with chunked masking applied.
"""
...


@overridable
def conv2d(
input: AnyTensor,
@@ -705,20 +624,6 @@ def index_select(tensor: AnyTensor, dim: int, index: AnyTensor) -> AnyTensor:
...


@overridable(dispatch_args=(0,))
def input_mask(seq_lens: AnyTensor, batch_seqlen: int) -> AnyTensor:
"""
Compute a boolean input mask for a batch of sequence lengths.

The mask will be [bs, batch_seqlen] with True at any position that is masked.

Args:
seq_lens: [bs] tensor of integers representing the sequence lengths.
batch_seqlen: The maximum sequence length in the batch.
"""
...


@overridable(dispatch_args=(0,))
def interpolate(
input: AnyTensor,
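The removed `chunked_attention_mask` signature describes applying a boolean chunked mask on top of an existing `[bs, 1, sl, sl]` additive attention mask. The sketch below shows one plausible construction in which positions may only attend to keys within their own chunk; the application step mirrors the removed `chunked_attention_mask_default`, while the boolean-mask construction is an assumption, not the actual `create_boolean_chunked_attention_mask` from `sharktank.utils.attention`, whose exact semantics (including the TODO'd decode-step start/end handling) may differ.

```python
# Plausible chunk-local masking sketch (assumption: queries attend only to keys
# in the same fixed-size chunk); not the sharktank.utils.attention code.
import torch

def boolean_chunked_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
    # True where attention is allowed: query and key share a chunk.
    idx = torch.arange(seq_len)
    return (idx[:, None] // chunk_size) == (idx[None, :] // chunk_size)

def apply_chunked_mask(attention_mask: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # attention_mask: additive [bs, 1, sl, sl] mask; cross-chunk positions are
    # overwritten with the most negative representable value.
    sl = attention_mask.shape[-1]
    neg = torch.tensor(
        torch.finfo(attention_mask.dtype).min, dtype=attention_mask.dtype
    )
    return torch.where(boolean_chunked_mask(sl, chunk_size), attention_mask, neg)

base = torch.zeros(1, 1, 8, 8)
chunked = apply_chunked_mask(base, chunk_size=4)
print((chunked == 0).int()[0, 0])  # two 4x4 blocks of ones on the diagonal
```

Applied after an ordinary causal mask (as in the prefill path of `paged_attention.py`), this restriction yields chunk-local causal attention.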