
Commit b042f5b

Add condition to dropout and ref to isnan (#2482)
op.Dropout is now applied only when `dropout_p` is not 0, and a reference to the issue discussion explaining why op.Where and op.IsNaN are needed when the attention mask is boolean has been added.
1 parent: e2fe5e7 · commit: b042f5b

1 file changed: +7 −3 lines

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 7 additions & 3 deletions
@@ -2037,7 +2037,8 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
         op.MatMul(query_scaled, key_transposed_scaled),
         axis=-1,
     )
-    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    if dropout_p != 0:
+        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)

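For context on the pattern in this hunk: these torch_lib helpers are evidently traced in Python while the ONNX graph is built (otherwise a plain `if` on a Python float could not change the graph), so the new `if dropout_p != 0` guard is evaluated at export time and, when dropout is disabled, no Dropout node is emitted at all rather than a runtime no-op. A minimal sketch of the same guard as a hypothetical standalone helper (not code from this commit; `op` is assumed to be an onnxscript opset such as opset18):

def _apply_optional_dropout(op, attn_weight, dropout_p: float):
    # The Python `if` runs at graph-construction time, not at inference time,
    # so a Dropout node only appears in the exported graph when dropout_p != 0.
    if dropout_p != 0:
        # op.Dropout returns (output, mask); the mask output is unused here.
        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
    return attn_weight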
@@ -2080,8 +2081,10 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
     # due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output.
     # This is because there's no safe/masked softmax imp in ONNX, so we need to handle NaN values explicitly to match
     # the behavior of PyTorch with boolean masks.
+    # Reference: https://github.com/pytorch/pytorch/issues/103749
     attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
-    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    if dropout_p != 0:
+        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)

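A minimal repro of the failure mode the comments in this hunk describe (plain PyTorch, not part of this commit): when a boolean mask disallows every key position in a row, all pre-softmax scores become -inf and softmax computes 0/0.

import torch

scores = torch.tensor([[0.5, 1.0, -0.3]])
mask = torch.tensor([[False, False, False]])  # fully padded row: nothing may be attended to

# Boolean attention masks are applied by filling disallowed positions with -inf.
masked = scores.masked_fill(~mask, float("-inf"))
weights = torch.softmax(masked, dim=-1)
print(weights)  # tensor([[nan, nan, nan]]): every exp(-inf) is 0, so the sum is 0 and 0/0 = NaN

# Tensor-level analogue of the op.Where(op.IsNaN(...), zero, ...) line above:
weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
print(weights)  # tensor([[0., 0., 0.]])

Zeroing the NaN rows makes padded rows contribute nothing to the final op.MatMul(attn_weight, value), which is the behavior discussed in the referenced issue.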
@@ -2116,7 +2119,8 @@ def _aten_scaled_dot_product_attention_float_mask_onnx(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
         axis=-1,
     )
-    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    if dropout_p != 0:
+        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
