Skip to content

Commit 7c71c53

Browse files
authored
Add translation method for scaled dot product attention torch op (#1857)
* add translation method for scaled dot product attention torch op * refactor get_mask method out of sdpa op translation main method
1 parent 2edf48a commit 7c71c53

File tree

5 files changed

+387
-5
lines changed

5 files changed

+387
-5
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5779,3 +5779,114 @@ def torchvision_nms(context, node):
57795779
def tupleindex(context, node):
    """
    Translate a tuple-index node.

    Reads the two node inputs (a tuple and an index), picks the element of the
    tuple at the index's compile-time value (``.val``) and registers it in the
    context under the node's name.
    """
    container, position = _get_inputs(context, node, expected=2)
    selected = container[position.val]
    context.add(selected, node.name)
5782+
5783+
5784+
def _get_attn_mask(is_causal: Var, attn_mask: Var, query_var: Var, key_var: Var) -> Var:
    """
    Build the additive attention mask used by the scaled_dot_product_attention translation.

    The returned mask is meant to be added to the attention scores before the softmax.
    It holds 0 wherever attention is allowed and a large negative number (-3e4) wherever
    it is blocked, so blocked positions receive (numerically) zero weight after softmax.

    Parameters
    ----------
    is_causal: Var
        Its ``.val`` selects the causal branch when truthy.
    attn_mask: Var
        Bool or float mask. Only read when ``is_causal.val`` is falsy; the caller is
        expected to pass a non-None mask in that case.
    query_var: Var
        Query tensor; ``shape[-2]`` is taken as the target sequence length and its
        dtype is used as the dtype of a converted bool mask.
    key_var: Var
        Key tensor; ``shape[-2]`` is taken as the source sequence length.

    Returns
    -------
    Var
        - causal: mask of shape (target_seq, source_seq), 0 on and below the diagonal,
          -3e4 strictly above it.
        - bool ``attn_mask``: ``(1 - cast(attn_mask)) * -3e4`` in the query dtype.
        - any other ``attn_mask``: returned unchanged.

    Raises
    ------
    NotImplementedError
        If ``is_causal`` is set and either sequence length is symbolic.
    """
    if is_causal.val:
        target_seq = query_var.shape[-2]
        source_seq = key_var.shape[-2]
        if is_symbolic(target_seq) or is_symbolic(source_seq):
            raise NotImplementedError(
                "scaled_dot_product_attention op: "
                "is_causal flag not handled when sequence length is symbolic"
            )

        all_negative_inf = mb.fill(value=-3e4, shape=(target_seq, source_seq))
        # Keep the diagonal and everything above it; the strictly lower triangle becomes 0.
        negative_inf_upper = mb.band_part(x=all_negative_inf, lower=0, upper=-1)
        # Isolate the diagonal so it can be subtracted away: a position must be able
        # to attend to itself, hence the diagonal has to end up as 0 as well.
        negative_inf_diag_only = mb.band_part(x=negative_inf_upper, lower=0, upper=0)
        return mb.sub(x=negative_inf_upper, y=negative_inf_diag_only)

    if is_bool(attn_mask.dtype):
        # True means "attend" (bias 0), False means "block" (bias -3e4):
        #     float_mask = (1 - cast(bool_mask)) * -3e4
        # The constants are single-element arrays and rely on broadcasting, so no
        # full-size tensor of -3e4 has to be materialised.
        np_dtype = types.nptype_from_builtin(query_var.dtype)
        mask = mb.cast(x=attn_mask, dtype=types.builtin_to_string(query_var.dtype))
        complement_of_mask = mb.sub(x=_np.array([1.0]).astype(np_dtype), y=mask)
        return mb.mul(x=complement_of_mask, y=_np.array([-3e4]).astype(np_dtype))

    # A non-bool mask is already additive and is used as-is.
    return attn_mask
5827+
5828+
5829+
5830+
@register_torch_op
def scaled_dot_product_attention(context, node):
    """
    Translate the torch scaled_dot_product_attention op.

    Computes:

        output = softmax(scale * Q * K^transpose + mask) * V

    where scale = 1 / sqrt(d) and d is the last dimension of the query.

    Input shapes/types:
    - query : (target_seq, d) or (B, target_seq, d) or (B, h, target_seq, d) or (B,.., target_seq, d)
    - key : (source_seq, d) or (B, source_seq, d) or (B, h, source_seq, d) or (B,.., source_seq, d)
    - value: (source_seq, d_v) or (B, source_seq, d_v) or (B, h, source_seq, d_v) or (B,.., source_seq, d_v)
    - attn_mask : (target_seq, source_seq) or (B, target_seq, source_seq) or (B, h, target_seq, source_seq) or
                  (B, ..., target_seq, source_seq)
    - is_causal : bool

    Output shape: (target_seq, d_v) or (B,...,target_seq, d_v)

    The dropout input is unpacked but not used by this translation.

    Raises ValueError when both attn_mask and is_causal are given, when the rank of
    key or value differs from the rank of query, or when the last dimension of
    query is symbolic.

    See details at:
    https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    """
    q, k, v, attn_mask, dropout, is_causal = _get_inputs(context, node, expected=6)

    # an explicit mask and the causal flag are mutually exclusive
    if attn_mask is not None and is_causal.val:
        raise ValueError(
            "scaled_dot_product_attention op: attn_mask cannot be provided when is_causal is set to True."
        )

    # key and value must have the same rank as query
    if q.rank != k.rank:
        raise ValueError(
            "Rank of query and key do not match in scaled_dot_product_attention torch op"
        )
    if q.rank != v.rank:
        raise ValueError(
            "Rank of query and value do not match in scaled_dot_product_attention torch op"
        )

    # None means "nothing to add to the attention scores"
    mask = None
    if is_causal.val or attn_mask is not None:
        mask = _get_attn_mask(is_causal, attn_mask, q, k)

    # the scale factor must be a compile-time constant, so d has to be known
    head_dim = q.shape[-1]
    if is_symbolic(head_dim):
        raise ValueError(
            "The embedding size, i.e. last dimension of the shape of query tensor"
            " cannot be symbolic, in scaled_dot_product_attention op"
        )
    scaled_q = mb.mul(x=q, y=1 / _math.sqrt(head_dim))

    # scores: (target_seq, source_seq) or (B,...,target_seq, source_seq)
    scores = mb.matmul(x=scaled_q, y=k, transpose_y=True)
    if mask is not None:
        scores = mb.add(x=scores, y=mask)

    # normalize over the source sequence axis and apply the weights to value
    probabilities = mb.softmax(x=scores, axis=-1)
    context.add(mb.matmul(x=probabilities, y=v, name=node.name))

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8612,3 +8612,217 @@ def forward(self, x):
86128612
self.run_compare_torch(x, OuterModel(),
86138613
input_as_shape=False, use_scripting=True,
86148614
backend=backend, compute_unit=compute_unit)
8615+
8616+
class TestScaledDotProductAttention(TorchBaseTest):
    """
    Tests for torch.nn.functional.scaled_dot_product_attention op
    (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
    """

    @pytest.mark.parametrize(
        "compute_unit, backend, rank",
        itertools.product(
            compute_units,
            backends,
            [2, 3, 4, 5],
        ),
    )
    def test_different_input_ranks_no_mask(self, compute_unit, backend, rank):
        """
        The query/key/value inputs can be any rank 2 or greater.
        """
        batch_size, seq_len, n_heads_1, n_heads_2, d = 2, 10, 3, 4, 7
        if rank == 2:
            input_shape = (seq_len, d)
        elif rank == 3:
            input_shape = (batch_size, seq_len, d)
        elif rank == 4:
            input_shape = (batch_size, n_heads_1, seq_len, d)
        elif rank == 5:
            # two distinct head dimensions, so that mixing them up would be caught
            input_shape = (batch_size, n_heads_1, n_heads_2, seq_len, d)
        else:
            raise ValueError("invalid rank")

        model = ModuleWrapper(
            function=nn.functional.scaled_dot_product_attention,
            kwargs={
                "attn_mask": None,
                "dropout_p": 0.0,
                "is_causal": False,
            },
        )

        # query, key and value all share the same shape in this test
        self.run_compare_torch(
            [input_shape] * 3,
            model,
            backend=backend,
            compute_unit=compute_unit,
        )

    @pytest.mark.parametrize(
        "compute_unit, backend, seq_lengths, include_heads",
        itertools.product(
            compute_units,
            backends,
            [(5, 5), (5, 7), (6, 4)],
            [False, True],
        ),
    )
    def test_is_causal_flag(self, compute_unit, backend, seq_lengths, include_heads):
        """
        With is_causal=True and no explicit mask, the converted model must match torch,
        for equal and unequal source/target sequence lengths.
        """
        source_seq_len, target_seq_len = seq_lengths
        query_shape = (2, 2, target_seq_len, 7) if include_heads else (2, target_seq_len, 7)
        key_shape = (2, 2, source_seq_len, 7) if include_heads else (2, source_seq_len, 7)
        value_shape = key_shape

        model = ModuleWrapper(
            function=nn.functional.scaled_dot_product_attention,
            kwargs={
                "attn_mask": None,
                "is_causal": True,
            },
        )
        res = self.run_compare_torch(
            [query_shape, key_shape, value_shape],
            model,
            backend=backend,
            compute_unit=compute_unit,
        )
        # check that "fill" and "band_part" ops, which are needed to compute mask, have been constant folded
        mil_prog = res[1]._get_mil_internal()
        # assert that no "fill" or "band_part" ops are left in the mil program
        assert len(mil_prog.find_ops(op_type="fill")) == 0
        assert len(mil_prog.find_ops(op_type="band_part")) == 0

    @pytest.mark.parametrize(
        "compute_unit, backend, seq_lengths, bool_mask",
        itertools.product(
            compute_units,
            backends,
            [(5, 5), (7, 5)],
            [False, True],
        ),
    )
    def test_attn_mask(self, compute_unit, backend, seq_lengths, bool_mask):
        """
        An explicit (target_seq, source_seq) mask, either boolean or float, is passed
        as a model input alongside query/key/value.
        """
        source_seq_len, target_seq_len = seq_lengths
        query_shape = (2, 3, target_seq_len, 7)
        key_shape = (2, 3, source_seq_len, 7)
        value_shape = key_shape
        mask_shape = (target_seq_len, source_seq_len)

        query = generate_input_data(query_shape)
        key = generate_input_data(key_shape)
        value = generate_input_data(value_shape)
        if bool_mask:
            mask = torch.rand(mask_shape) > 0.5
            mask = mask.bool()
        else:
            mask = generate_input_data(mask_shape)

        model = ModuleWrapper(function=nn.functional.scaled_dot_product_attention)
        self.run_compare_torch(
            (query, key, value, mask),
            model,
            backend=backend,
            compute_unit=compute_unit,
            input_as_shape=False,
        )

    @pytest.mark.parametrize(
        "compute_unit, backend, mask_as_input",
        itertools.product(
            compute_units,
            backends,
            [True, False],
        ),
    )
    def test_toy_xformer_with_sdpa(self, compute_unit, backend, mask_as_input):
        """
        End-to-end check on a small transformer built around scaled_dot_product_attention.
        When mask_as_input is True the mask is a second model input; otherwise the
        attention block falls back to is_causal=True.
        """
        embedding_size = 32
        seq_length = 16
        n_heads = 4
        batch_size = 2
        num_blocks = 3

        class AttentionBlock(nn.Module):
            def __init__(self, embed_dim=embedding_size, n_head=n_heads):
                super().__init__()
                self.query_proj_op = nn.Linear(embed_dim, embed_dim)
                self.key_proj_op = nn.Linear(embed_dim, embed_dim)
                self.value_proj_op = nn.Linear(embed_dim, embed_dim)
                self.out_proj_op = nn.Linear(embed_dim, embed_dim)
                self.n_head = n_head

            def forward(self, x, mask=None):
                # in comments below for shapes, using following notation:
                # B: batch_size, S: seq_length, E: embedding_size, h: n_heads
                # x: (B,S,E)
                # mask: (S,S)
                batch_size, seq_len, dim = x.shape
                query_proj = self.query_proj_op(x)  # (B,S,E)
                key_proj = self.key_proj_op(x)  # (B,S,E)
                value_proj = self.value_proj_op(x)  # (B,S,E)
                # reshape to (B, h, S, E/h)
                query_proj = query_proj.reshape(
                    batch_size, seq_len, self.n_head, dim // self.n_head
                ).permute(
                    0, 2, 1, 3
                )  # (B, h, S, E/h)
                key_proj = key_proj.reshape(
                    batch_size, seq_len, self.n_head, dim // self.n_head
                ).permute(
                    0, 2, 1, 3
                )  # (B, h, S, E/h)
                value_proj = value_proj.reshape(
                    batch_size, seq_len, self.n_head, dim // self.n_head
                ).permute(
                    0, 2, 1, 3
                )  # (B, h, S, E/h)
                # now do scaled dot product attention
                if mask is None:
                    out = nn.functional.scaled_dot_product_attention(
                        query_proj, key_proj, value_proj, is_causal=True
                    )  # (B, h, S, E/h)
                else:
                    out = nn.functional.scaled_dot_product_attention(
                        query_proj, key_proj, value_proj, mask
                    )  # (B, h, S, E/h)
                # reshape back to (B, S, E)
                out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)  # (B, S, E)
                return self.out_proj_op(out)

        class MLPBlock(nn.Module):
            def __init__(self, embed_dim=embedding_size):
                super().__init__()
                self.fc1 = nn.Linear(embed_dim, embed_dim)
                self.activation = nn.GELU()
                self.fc2 = nn.Linear(embed_dim, embed_dim)

            def forward(self, x):
                x = self.fc1(x)
                x = self.activation(x)
                return self.fc2(x)

        class ToyTransformer(nn.Module):
            def __init__(self, n_blocks=num_blocks, embed_dim=embedding_size):
                super().__init__()
                self.attn_block = AttentionBlock(embed_dim=embed_dim)
                self.mlp = MLPBlock(embed_dim=embed_dim)
                self.n_blocks = n_blocks
                self.lnorm = nn.LayerNorm(embed_dim)

            def forward(self, x, mask=None):
                # the same attention/MLP/layer-norm modules are reused in every iteration
                for _ in range(self.n_blocks):
                    x = self.attn_block(x, mask) + x
                    x = self.lnorm(x)
                    x = self.mlp(x) + x
                    x = self.lnorm(x)
                return x

        model = ToyTransformer()
        self.run_compare_torch(
            [(batch_size, seq_length, embedding_size), (seq_length, seq_length)]
            if mask_as_input
            else [(batch_size, seq_length, embedding_size)],
            model,
            backend=backend,
            compute_unit=compute_unit,
        )

coremltools/converters/mil/mil/ops/defs/iOS15/tensor_operation.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@
3939
class band_part(Operation):
4040
"""
4141
Returns a tensor setting everything outside a center band to zeros for the innermost
42-
matrix. Special cases:
42+
matrix. That is,
43+
band(m, n) = (lower < 0 || (m-n) <= lower) && (upper < 0 || (n-m) <= upper)
44+
output[i, j, k, ..., m, n] = band(m, n) * input[i, j, k, ..., m, n]
45+
46+
Special cases:
4347
4448
- ``band_part(x, 0, -1)`` returns upper triangular part.
4549
- ``band_part(x, -1, 0)`` returns lower triangular part.
@@ -86,6 +90,19 @@ def default_inputs(self):
8690
def type_inference(self):
8791
return self.x.sym_type
8892

93+
@precondition(allow=VALUE)
def value_inference(self):
    """
    Compute the band_part result directly from the known values of the inputs.

    For the innermost (M, N) matrix an entry (m, n) is kept when

        (lower < 0 or (m - n) <= lower) and (upper < 0 or (n - m) <= upper)

    and zeroed otherwise. The (M, N) band broadcasts over any leading dimensions
    of ``x``. The result keeps the dtype of ``self.x.val``.
    """
    x_val = self.x.val
    M, N = x_val.shape[-2:]
    num_lower = self.lower.val
    num_upper = self.upper.val

    # offset[m, n] = m - n: positive below the diagonal, negative above it
    offset = np.arange(M).reshape(M, 1) - np.arange(N).reshape(1, N)

    band = np.ones((M, N), dtype=bool)
    if num_lower >= 0:
        # a negative `lower` keeps the whole lower triangle, so only filter otherwise
        band &= offset <= num_lower
    if num_upper >= 0:
        # a negative `upper` keeps the whole upper triangle, so only filter otherwise
        band &= -offset <= num_upper

    # cast the 0/1 band to the input's own dtype so the product does not change dtype
    return np.multiply(band.astype(x_val.dtype), x_val)
105+
89106

90107
@register_op
91108
class cumsum(Operation):

coremltools/converters/mil/mil/ops/defs/iOS16/tensor_operation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
@register_op(opset_version=_IOS16_TARGET)
2121
class fill_like(Operation):
2222
"""
23-
Returns a tensor with the same size as the input tensor filled with a constant value.
23+
Returns a tensor with the same shape as the input tensor filled with a constant value.
2424
2525
Parameters
2626
----------
@@ -45,7 +45,7 @@ class fill_like(Operation):
4545
ref_tensor=TensorInputType(type_domain="T"),
4646
value=TensorInputType(const=True, optional=True, type_domain="U"),
4747
)
48-
48+
4949
type_domains = {
5050
"T": (types.fp16, types.fp32, types.int32, types.bool),
5151
"U": (types.fp16, types.fp32, types.int32, types.bool),

0 commit comments

Comments
 (0)