@@ -2,7 +2,7 @@ use crate::layers::{
     apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, Linear, RMSNorm,
 };
 use crate::models::Model;
-use candle::{Device, IndexOp, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
@@ -382,10 +382,12 @@ pub struct Qwen3Model {
     rotary_cache: (Tensor, Tensor),
     rotary_dim: usize,
     pool: Pool,
-    pub device: Device,
     num_attention_heads: usize,
     pad_token_id: u32,
 
+    dtype: DType,
+    device: Device,
+
     span: tracing::Span,
 }
@@ -435,30 +437,30 @@ impl Qwen3Model {
             rotary_dim,
             pool,
             pad_token_id: config.eos_token_id as u32,
-            device: vb.device().clone(),
             num_attention_heads: config.num_attention_heads,
+            dtype: vb.dtype(),
+            device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }
 
     fn get_causal_attention_bias(&self, attention_bias: Tensor) -> Result<Tensor> {
         let (bs, dim, seq_len, _) = attention_bias.dims4()?;
 
-        let device = attention_bias.device();
-
         let mask: Vec<u8> = (0..seq_len)
             .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
             .collect();
 
         let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;
         let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;
 
-        let negatives = Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?;
-        let zeros = Tensor::zeros_like(&attention_bias)?.to_device(&Device::Cpu)?;
+        let negatives =
+            Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?.to_dtype(self.dtype)?;
+        let zeros = Tensor::zeros_like(&attention_bias)?.to_dtype(self.dtype)?;
 
         let causal_mask = causal_mask
             .where_cond(&negatives, &zeros)?
-            .to_device(device)?;
+            .to_device(&self.device)?;
 
         attention_bias.broadcast_add(&causal_mask)
     }
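
The `get_causal_attention_bias` hunk casts both `where_cond` branches to the model dtype before moving the mask to the target device. Below is a minimal standalone sketch of the same pattern, assuming candle-core's public `Tensor` API; the `causal_mask` helper name and free-function shape are ours, not part of this PR:

```rust
use candle::{DType, Device, Result, Tensor};

// Build an additive causal mask in the model dtype so a later
// `broadcast_add` against F16/BF16 attention scores type-checks.
fn causal_mask(seq_len: usize, dtype: DType, device: &Device) -> Result<Tensor> {
    // 1 marks future positions (j > i), 0 marks visible ones.
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
        .collect();
    let cond = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;

    // Large negative where masked, zero elsewhere, both in the target dtype.
    // Note that f32::MIN overflows to -inf under an F16 cast, which is
    // acceptable here since masked logits should vanish under softmax.
    let negatives = Tensor::full(f32::MIN, (seq_len, seq_len), &Device::Cpu)?.to_dtype(dtype)?;
    let zeros = Tensor::zeros((seq_len, seq_len), dtype, &Device::Cpu)?;

    cond.where_cond(&negatives, &zeros)?.to_device(device)
}
```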
@@ -494,7 +496,7 @@ impl Qwen3Model {
             for _ in 0..padding {
                 input_ids.push(self.pad_token_id);
                 position_ids.push(0);
-                attention_bias.push(f32::MIN);
+                attention_bias.push(f32::NEG_INFINITY);
             }
         }
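
Pushing `f32::NEG_INFINITY` instead of `f32::MIN` for padded positions gives exactly zero attention weight after softmax, without relying on `f32::MIN` being "large enough" once values are cast down. A quick check; this snippet is ours, not from the PR:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Softmax over [0.0, -inf]: the -inf slot gets exactly zero weight,
    // since exp(-inf) == 0.0 and the remaining mass normalizes to 1.0.
    let logits = Tensor::new(&[0.0f32, f32::NEG_INFINITY], &Device::Cpu)?;
    let probs = candle_nn::ops::softmax(&logits, 0)?;
    assert_eq!(probs.to_vec1::<f32>()?, vec![1.0, 0.0]);
    Ok(())
}
```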
500502
@@ -539,7 +541,7 @@ impl Qwen3Model {
539541 // Create attention bias for causal masking even for single sequences
540542 let attention_bias = Tensor :: zeros (
541543 ( 1 , self . num_attention_heads , seq_len, seq_len) ,
542- candle :: DType :: F32 ,
544+ self . dtype ,
543545 & self . device ,
544546 ) ?;
545547
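
Together these hunks make the attention bias follow the checkpoint dtype instead of hardcoding `candle::DType::F32`. That matters because candle's binary ops do not implicitly promote mixed dtypes. A hypothetical repro of the failure mode this avoids, with shapes and names of our choosing:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // An F32 bias added to F16 attention scores fails with a dtype
    // mismatch error rather than being promoted implicitly.
    let bias = Tensor::zeros((1, 8, 4, 4), DType::F32, &device)?;
    let scores = Tensor::zeros((1, 8, 4, 4), DType::F16, &device)?;
    assert!(scores.broadcast_add(&bias).is_err());
    Ok(())
}
```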