@@ -23,6 +23,7 @@ pub struct Qwen3Config {
     pub rope_theta: f32,
     pub sliding_window: Option<usize>,
     pub use_sliding_window: bool,
+    pub eos_token_id: usize,
 }
 
 struct Qwen3Attention {
@@ -164,8 +165,8 @@ impl Qwen3Attention {
             .concat(),
         )?;
 
-        let (q, _res) = self.q_norm.forward(&q, None)?;
-        let (k, _res) = self.k_norm.forward(&k, None)?;
+        let (q, _) = self.q_norm.forward(&q, None)?;
+        let (k, _) = self.k_norm.forward(&k, None)?;
 
         let q = q.transpose(1, 2)?;
         let k = k.transpose(1, 2)?;
@@ -355,16 +356,21 @@ impl Qwen3Layer {
     ) -> Result<Tensor> {
         let _enter = self.span.enter();
 
-        let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, None)?;
+        let (normed_hidden_states, residual) =
+            self.input_layer_norm.forward(hidden_states, None)?;
+
         let attn_output =
             self.attention
                 .forward(&normed_hidden_states, attention_bias, cos, sin)?;
+
         let (normed_attn_res_output, attn_res) = self
             .post_attention_layer_norm
-            .forward(&attn_output, Some(&res))?;
+            .forward(&attn_output, Some(&residual))?;
+
         let mlp_output = self.mlp.forward(&normed_attn_res_output)?;
 
         let output = (&mlp_output + &attn_res)?;
+
         Ok(output)
     }
 }
@@ -378,6 +384,7 @@ pub struct Qwen3Model {
     pool: Pool,
    pub device: Device,
     num_attention_heads: usize,
+    pad_token_id: u32,
 
     span: tracing::Span,
 }
@@ -427,12 +434,35 @@ impl Qwen3Model {
             rotary_cache,
             rotary_dim,
             pool,
+            pad_token_id: config.eos_token_id as u32,
             device: vb.device().clone(),
             num_attention_heads: config.num_attention_heads,
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }
 
+    fn get_causal_attention_bias(&self, attention_bias: Tensor) -> Result<Tensor> {
+        let (bs, dim, seq_len, _) = attention_bias.dims4()?;
+
+        let device = attention_bias.device();
+
+        let mask: Vec<u8> = (0..seq_len)
+            .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
+            .collect();
+
+        let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;
+        let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;
+
+        let negatives = Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?;
+        let zeros = Tensor::zeros_like(&attention_bias)?.to_device(&Device::Cpu)?;
+
+        let causal_mask = causal_mask
+            .where_cond(&negatives, &zeros)?
+            .to_device(device)?;
+
+        attention_bias.broadcast_add(&causal_mask)
+    }
+
     pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         let _enter = self.span.enter();
 
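The `(j > i)` flat map in the new helper is easy to misread, so here is a standalone sketch (plain std, no candle) of the pattern `get_causal_attention_bias` builds before expanding it to `(batch, heads, seq_len, seq_len)` and swapping the 1s for `f32::MIN`.

fn main() {
    let seq_len = 4;
    // Same flat (j > i) pattern as the new helper: 1 marks a future key
    // position that must be hidden from query position i.
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
        .collect();

    for row in mask.chunks(seq_len) {
        println!("{row:?}");
    }
    // [0, 1, 1, 1]
    // [0, 0, 1, 1]
    // [0, 0, 0, 1]
    // [0, 0, 0, 0]
}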
@@ -441,93 +471,77 @@ impl Qwen3Model {
 
         let shape = (batch_size, max_length);
 
-        let (input_ids, position_ids, input_lengths, attention_bias, _attention_mask) =
-            if batch_size > 1 {
-                // Prepare padded batch
-                let elems = batch_size * max_length;
-
-                let mut input_ids = Vec::with_capacity(elems);
-                let mut position_ids = Vec::with_capacity(elems);
-                let mut attention_mask = Vec::with_capacity(elems);
-                let mut attention_bias = Vec::with_capacity(elems);
-                let mut input_lengths = Vec::with_capacity(batch_size);
-                let mut masking = false;
-
-                for i in 0..batch_size {
-                    let start = batch.cumulative_seq_lengths[i] as usize;
-                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
-                    let seq_length = end - start;
-                    input_lengths.push(seq_length);
-
-                    // Input ids
-                    for j in start..end {
-                        input_ids.push(batch.input_ids[j]);
-                        position_ids.push(batch.position_ids[j]);
-                        attention_mask.push(1.0_f32);
-                        attention_bias.push(0.0);
-                    }
+        let (input_ids, position_ids, input_lengths, attention_bias) = if batch_size > 1 {
+            // Prepare padded batch
+            let elems = batch_size * max_length;
+
+            let mut input_ids = Vec::with_capacity(elems);
+            let mut position_ids = Vec::with_capacity(elems);
+            let mut attention_bias = Vec::with_capacity(elems);
+            let mut input_lengths = Vec::with_capacity(batch_size);
+            let mut masking = false;
+
+            for i in 0..batch_size {
+                let start = batch.cumulative_seq_lengths[i] as usize;
+                let end = batch.cumulative_seq_lengths[i + 1] as usize;
+                let seq_length = end - start;
+                input_lengths.push(seq_length);
+
+                for j in start..end {
+                    input_ids.push(batch.input_ids[j]);
+                    position_ids.push(batch.position_ids[j]);
+                    attention_bias.push(0.0);
+                }
 
-                    // Pad to max_length
-                    for _ in seq_length..max_length {
-                        input_ids.push(0);
-                        position_ids.push(0);
-                        attention_mask.push(0.0_f32);
-                        attention_bias.push(f32::NEG_INFINITY);
-                        masking = true;
+                let padding = max_length - seq_length;
+                if padding > 0 {
+                    masking = true;
+                    for _ in 0..padding {
+                        input_ids.insert(start, self.pad_token_id);
+                        position_ids.insert(start, 0);
+                        attention_bias.insert(start, f32::MIN);
                     }
                 }
+            }
 
-                let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
-                let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
-                let attention_mask = if masking {
-                    Some(Tensor::from_vec(attention_mask, shape, &self.device)?)
-                } else {
-                    None
-                };
-
-                let attention_bias = if masking {
-                    let attention_bias = Tensor::from_vec(
-                        attention_bias,
-                        (batch_size, 1, 1, max_length),
-                        &self.device,
-                    )?;
-                    // Broadcast once instead of at every layer
-                    let attention_bias = attention_bias
-                        .broadcast_as((
-                            batch_size,
-                            self.num_attention_heads,
-                            max_length,
-                            max_length,
-                        ))?
-                        .contiguous()?;
-                    Some(attention_bias)
-                } else {
-                    None
-                };
-
-                (
-                    input_ids,
-                    position_ids,
-                    input_lengths,
-                    attention_bias,
-                    attention_mask,
-                )
+            let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
+            let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
+
+            let attention_bias = if masking {
+                let attention_bias =
+                    Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?;
+                // Broadcast once instead of at every layer
+                let attention_bias = attention_bias
+                    .broadcast_as((batch_size, self.num_attention_heads, max_length, max_length))?
+                    .contiguous()?;
+                Some(attention_bias)
             } else {
-                let input_ids = Tensor::from_vec(
-                    batch.input_ids.clone(),
-                    (1, batch.input_ids.len()),
-                    &self.device,
-                )?;
-                let position_ids = Tensor::from_vec(
-                    batch.position_ids.clone(),
-                    (1, batch.position_ids.len()),
-                    &self.device,
-                )?;
-                let input_lengths = vec![batch.input_ids.len()];
-
-                (input_ids, position_ids, input_lengths, None, None)
+                None
             };
 
+            (input_ids, position_ids, input_lengths, attention_bias)
+        } else {
+            let input_ids = Tensor::from_vec(
+                batch.input_ids.clone(),
+                (1, batch.input_ids.len()),
+                &self.device,
+            )?;
+            let position_ids = Tensor::from_vec(
+                batch.position_ids.clone(),
+                (1, batch.position_ids.len()),
+                &self.device,
+            )?;
+            let input_lengths = vec![batch.input_ids.len()];
+
+            (input_ids, position_ids, input_lengths, None)
+        };
+
+        let attention_bias = if let Some(attn_bias) = attention_bias {
+            Some(self.get_causal_attention_bias(attn_bias)?)
+        } else {
+            None
+        };
+
         let mut hidden_states = self.embeddings.forward(&input_ids)?;
 
         let cos = self
@@ -583,7 +597,7 @@ impl Qwen3Model {
                     .iter()
                     .map(|&i| {
                         let i = i as usize;
-                        let last_token_idx = input_lengths[i] - 1;
+                        let last_token_idx = max_length - 1;
                         outputs.i((i, last_token_idx))?.unsqueeze(0)
                     })
                     .collect();
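The padding rework and this pooling change go together: each sequence is now left-padded, so the last real token of every row lands at column `max_length - 1`, and last-token pooling can use that fixed index instead of `input_lengths[i] - 1`. Below is a minimal standalone sketch of that layout (not code from the patch; the pad id is an assumed illustrative value).

fn main() {
    // Assumed pad/eos id, purely illustrative.
    let pad_token_id: u32 = 151643;
    let sequences: Vec<Vec<u32>> = vec![vec![11, 12, 13], vec![21, 22]];
    let max_length = sequences.iter().map(|s| s.len()).max().unwrap();

    let mut input_ids: Vec<u32> = Vec::new();
    for seq in &sequences {
        // Left padding: pads first, then the real tokens.
        let padding = max_length - seq.len();
        input_ids.extend(std::iter::repeat(pad_token_id).take(padding));
        input_ids.extend(seq.iter().copied());
    }

    for row in input_ids.chunks(max_length) {
        // Every row now ends with a real token, so pooling reads row[max_length - 1].
        println!("{row:?} -> pooled token {}", row[max_length - 1]);
    }
    // [11, 12, 13] -> pooled token 13
    // [151643, 21, 22] -> pooled token 22
}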