File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
onnxscript/function_libs/torch_lib/ops Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -2069,9 +2069,9 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
2069
2069
query_scaled = op .Mul (query , op .Sqrt (scale ))
2070
2070
key_transposed_scaled = op .Mul (key_transposed , op .Sqrt (scale ))
2071
2071
# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
2072
- attn_mask = op .Where (
2073
- attn_mask , op .Constant (value_float = 0.0 ), op . Constant ( value_float = - float ("inf" ))
2074
- )
2072
+ zero = op .Constant ( value = ir . tensor ( 0.0 , dtype = query . dtype ))
2073
+ neg_inf = op .Constant (value = ir . tensor ( - float ("inf" ), dtype = query . dtype ))
2074
+ attn_mask = op . Where ( attn_mask , zero , neg_inf )
2075
2075
attn_weight = op .Softmax (
2076
2076
op .Add (op .MatMul (query_scaled , key_transposed_scaled ), attn_mask ),
2077
2077
axis = - 1 ,
You can’t perform that action at this time.
0 commit comments