Skip to content

Commit 3ddd6b4

Browse files
authored
[torchlib] Fix sdpa dtype in attn_mask (#2445)
Discovered in benchmarking that `op.Where` generates an fp32 output even when the whole model is set to fp16.
1 parent a998e5d commit 3ddd6b4

File tree

1 file changed

+3
-3
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+3
-3
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,9 +2069,9 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
20692069
query_scaled = op.Mul(query, op.Sqrt(scale))
20702070
key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
20712071
# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
2072-
attn_mask = op.Where(
2073-
attn_mask, op.Constant(value_float=0.0), op.Constant(value_float=-float("inf"))
2074-
)
2072+
zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
2073+
neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype))
2074+
attn_mask = op.Where(attn_mask, zero, neg_inf)
20752075
attn_weight = op.Softmax(
20762076
op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
20772077
axis=-1,

0 commit comments

Comments
 (0)