@@ -198,13 +198,66 @@ impl GTELayer {
198198 }
199199}
200200
/// Classification head applied on top of pooled GTE embeddings.
///
/// Holds an optional pooler projection, the classifier projection, and the
/// tracing span that `forward` enters.
201+ pub struct GTEClassificationHead {
// Optional dense projection run before the classifier; `None` when no pooler was loaded.
202+ pooler : Option < Linear > ,
// Projection from `hidden_size` features to one output per class.
203+ classifier : Linear ,
// TRACE-level span named "classifier", entered for the duration of `forward`.
204+ span : tracing:: Span ,
205+ }
206+
207+ impl GTEClassificationHead {
208+ #[ allow( dead_code) ]
209+ pub ( crate ) fn load ( vb : VarBuilder , config : & GTEConfig ) -> Result < Self > {
210+ let n_classes = match & config. id2label {
211+ None => candle:: bail!( "`id2label` must be set for classifier models" ) ,
212+ Some ( id2label) => id2label. len ( ) ,
213+ } ;
214+
215+ let pooler = if let Ok ( pooler_weight) = vb
216+ . pp ( "pooler.dense" )
217+ . get ( ( config. hidden_size , config. hidden_size ) , "weight" )
218+ {
219+ let pooler_bias = vb. pp ( "pooler.dense" ) . get ( config. hidden_size , "bias" ) ?;
220+ Some ( Linear :: new ( pooler_weight, Some ( pooler_bias) , None ) )
221+ } else {
222+ None
223+ } ;
224+
225+ let classifier_weight = vb
226+ . pp ( "classifier" )
227+ . get ( ( n_classes, config. hidden_size ) , "weight" ) ?;
228+ let classifier_bias = vb. pp ( "classifier" ) . get ( n_classes, "bias" ) ?;
229+ let classifier = Linear :: new ( classifier_weight, Some ( classifier_bias) , None ) ;
230+
231+ Ok ( Self {
232+ classifier,
233+ pooler,
234+ span : tracing:: span!( tracing:: Level :: TRACE , "classifier" ) ,
235+ } )
236+ }
237+
238+ pub ( crate ) fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > {
239+ let _enter = self . span . enter ( ) ;
240+
241+ let mut hidden_states = hidden_states. unsqueeze ( 1 ) ?;
242+ if let Some ( pooler) = self . pooler . as_ref ( ) {
243+ hidden_states = pooler. forward ( & hidden_states) ?;
244+ hidden_states = hidden_states. tanh ( ) ?;
245+ }
246+
247+ let hidden_states = self . classifier . forward ( & hidden_states) ?;
248+ let hidden_states = hidden_states. squeeze ( 1 ) ?;
249+ Ok ( hidden_states)
250+ }
251+ }
252+
201253pub struct FlashGTEModel {
202254 word_embeddings : Embedding ,
203255 token_type_embeddings : Option < Embedding > ,
204256 layers : Vec < GTELayer > ,
205257 embeddings_norm : LayerNorm ,
206258 cos_cache : Tensor ,
207259 sin_cache : Tensor ,
260+ classifier : Option < GTEClassificationHead > ,
208261 pool : Pool ,
209262 pub device : Device ,
210263
@@ -233,11 +286,14 @@ impl FlashGTEModel {
233286 candle:: bail!( "Only `PositionEmbeddingType::Rope` is supported" ) ;
234287 }
235288
236- let pool = match model_type {
289+ let ( pool, classifier ) = match model_type {
237290 ModelType :: Classifier => {
238- candle:: bail!( "`classifier` model type is not supported for GTE" )
291+ let pool = Pool :: Cls ;
292+
293+ let classifier = GTEClassificationHead :: load ( vb. clone ( ) , config) ?;
294+ ( pool, Some ( classifier) )
239295 }
240- ModelType :: Embedding ( pool) => pool,
296+ ModelType :: Embedding ( pool) => ( pool, None ) ,
241297 } ;
242298
243299 let word_embeddings = Embedding :: new (
@@ -292,6 +348,7 @@ impl FlashGTEModel {
292348 embeddings_norm,
293349 cos_cache,
294350 sin_cache,
351+ classifier,
295352 pool,
296353 device : vb. device ( ) . clone ( ) ,
297354 span : tracing:: span!( tracing:: Level :: TRACE , "model" ) ,
@@ -457,7 +514,20 @@ impl Model for FlashGTEModel {
/// Always returns `false` for this model.
// NOTE(review): presumably tells callers that batches are consumed unpadded —
// confirm against the `Model` trait definition.
457514 fn is_padded ( & self ) -> bool {
458515 false
459516 }
517+
/// Computes embeddings for `batch` by delegating to `self.forward`.
///
/// The pair of optional tensors returned by `forward` is passed through
/// unchanged, as is any error it produces.
460518 fn embed ( & self , batch : Batch ) -> Result < ( Option < Tensor > , Option < Tensor > ) > {
461519 self . forward ( batch)
462520 }
521+
522+ fn predict ( & self , batch : Batch ) -> Result < Tensor > {
523+ match & self . classifier {
524+ None => candle:: bail!( "`predict` is not implemented for this model" ) ,
525+ Some ( classifier) => {
526+ let ( pooled_embeddings, _raw_embeddings) = self . forward ( batch) ?;
527+ let pooled_embeddings =
528+ pooled_embeddings. expect ( "pooled_embeddings is empty. This is a bug." ) ;
529+ classifier. forward ( & pooled_embeddings)
530+ }
531+ }
532+ }
463533}
0 commit comments