
Commit 60c36bc

Olmo2 Bug fix (quic#589)
Replaced the hardcoded -10000.0 attention-mask fill value with MIN_MASKED_ATTENTION_VALUE.

Signed-off-by: Dipankar Sarkar <[email protected]>
1 parent a379e6e commit 60c36bc

File tree

1 file changed: +4 additions, -1 deletion

QEfficient/transformers/models/olmo2/modeling_olmo2.py

Lines changed: 4 additions & 1 deletion
@@ -27,6 +27,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffOlmo2RotaryEmbedding(Olmo2RotaryEmbedding):
@@ -109,7 +110,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
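
For context, below is a minimal, self-contained sketch of the masked eager-attention pattern this commit touches: a boolean attention mask is filled with a large negative constant before softmax so masked positions receive near-zero probability. It is not the QEfficient implementation; the actual value of MIN_MASKED_ATTENTION_VALUE is defined in QEfficient.utils.constants, and the placeholder constant, function name, and usage below are assumptions for illustration only.

# Sketch only: standalone masked eager attention with an assumed mask constant.
import torch
import torch.nn as nn

MIN_MASKED_ATTENTION_VALUE = -1e4  # assumed placeholder; the real value comes from QEfficient.utils.constants

def masked_eager_attention(query, key_states, value_states, attention_mask, scaling):
    # query/key/value: (batch, heads, seq, head_dim); attention_mask: bool, True = position is masked out
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Replace masked scores with a large negative constant so softmax drives them toward zero
        attn_weights = torch.where(
            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
        )
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    return attn_output.transpose(1, 2).contiguous()

# Tiny usage example: a causal mask over 4 positions.
if __name__ == "__main__":
    b, h, s, d = 1, 2, 4, 8
    q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
    causal_mask = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)  # True above the diagonal = masked
    out = masked_eager_attention(q, k, v, causal_mask, scaling=d ** -0.5)
    print(out.shape)  # torch.Size([1, 4, 2, 8]) after the final transpose

Using a named constant instead of a literal -10000.0 keeps the mask fill value consistent across models and makes it easy to tune in one place.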
