@@ -455,8 +455,13 @@ impl Qwen3Model {
455455 let causal_mask = Tensor :: from_slice ( & mask, ( seq_len, seq_len) , device) ?;
456456 let causal_mask = causal_mask. expand ( & [ bs, dim, seq_len, seq_len] ) ?;
457457
458+ let min_value = match self . dtype {
459+ DType :: F32 => f32:: MIN ,
460+ _ => -65504.0 , // f16::MIN — minimum finite f16 value; NOTE(review): this catch-all also covers BF16, whose minimum finite value is ≈ -3.39e38 — confirm intended for BF16 masking
461+ } ;
462+
458463 let negatives =
459- Tensor :: full ( f32 :: MIN , attention_bias. shape ( ) , device) ?. to_dtype ( self . dtype ) ?;
464+ Tensor :: full ( min_value , attention_bias. shape ( ) , device) ?. to_dtype ( self . dtype ) ?;
460465 let zeros = Tensor :: zeros_like ( & attention_bias) ?. to_dtype ( self . dtype ) ?;
461466
462467 let causal_mask = causal_mask
@@ -514,7 +519,8 @@ impl Qwen3Model {
514519
515520 let attention_bias = if masking {
516521 let attention_bias =
517- Tensor :: from_vec ( attention_bias, ( batch_size, 1 , 1 , max_length) , & self . device ) ?;
522+ Tensor :: from_vec ( attention_bias, ( batch_size, 1 , 1 , max_length) , & self . device ) ?
523+ . to_dtype ( self . dtype ) ?;
518524 // Broadcast once instead of at every layer
519525 let attention_bias = attention_bias
520526 . broadcast_as ( ( batch_size, self . num_attention_heads , max_length, max_length) ) ?
0 commit comments