
Commit ff80fcb

committed
fixing the device issue for build_mask
1 parent bdc8948 commit ff80fcb

File tree

1 file changed: +7 -2 lines changed


sharktank/sharktank/layers/paged_attention.py

Lines changed: 7 additions & 2 deletions
@@ -745,7 +745,7 @@ def build_mask(
         kv_size: int,
         n_tokens: int,
         dtype: torch.dtype,
-        device: torch.device,
+        device: Optional[torch.device] = None,
     ):
         """
         Returns a causal (and optional sliding-window) mask of shape [n_tokens, kv_size].
@@ -809,7 +809,12 @@ def attention(
         )

        effective_mask = self.build_mask(
-            mask, sliding_window, k.shape[-2], q.shape[-2], self.attn_dtype, q.device
+            mask,
+            sliding_window,
+            k.shape[-2],
+            q.shape[-2],
+            self.attn_dtype,
+            mask.device if mask is not None else None,
        )

        return ops.scaled_dot_product_attention(
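For context, here is a minimal sketch of what a build_mask-style helper with an optional device argument could look like. This is not the sharktank implementation; the name build_mask_sketch and its body are hypothetical, and only the signature follows the diff above. It illustrates why device can default to None: when an explicit mask is supplied it already carries its own device (which the updated call site forwards via mask.device), and when no mask is supplied, torch factory functions simply fall back to the current default device.

# Hypothetical sketch only; assumes the signature shown in the diff above.
from typing import Optional

import torch


def build_mask_sketch(
    mask: Optional[torch.Tensor],
    sliding_window: Optional[int],
    kv_size: int,
    n_tokens: int,
    dtype: torch.dtype,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    # An explicit mask already lives on some device; just match the dtype.
    if mask is not None:
        return mask.to(dtype=dtype)

    # Causal mask of shape [n_tokens, kv_size]: allowed positions are 0,
    # masked positions are -inf. With device=None, torch uses the default
    # device, which is why the parameter can safely be optional.
    q_pos = torch.arange(n_tokens, device=device).unsqueeze(-1) + (kv_size - n_tokens)
    k_pos = torch.arange(kv_size, device=device).unsqueeze(0)
    allowed = k_pos <= q_pos
    if sliding_window is not None:
        allowed &= k_pos > (q_pos - sliding_window)

    out = torch.zeros(n_tokens, kv_size, dtype=dtype, device=device)
    return out.masked_fill(~allowed, float("-inf"))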
