@@ -121,6 +121,8 @@ def __init__(
             qk_norm: bool = False,
             scale_norm: bool = True,
             rotate_half: bool = False,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
@@ -139,6 +141,7 @@ def __init__(
             scale_norm: Enable normalization (scaling) of attention output with norm_layer
             rotate_half: Use half rotation layout instead of interleaved
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         if scale_norm or qk_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
@@ -154,25 +157,25 @@ def __init__(
         self.rotate_half = rotate_half
 
         if qkv_fused:
-            self.qkv = nn.Linear(dim, attn_dim * 3, bias=False)
+            self.qkv = nn.Linear(dim, attn_dim * 3, bias=False, **dd)
             self.q_proj = self.k_proj = self.v_proj = None
             if qkv_bias:
-                self.q_bias = nn.Parameter(torch.zeros(attn_dim))
-                self.register_buffer('k_bias', torch.zeros(attn_dim), persistent=False)
-                self.v_bias = nn.Parameter(torch.zeros(attn_dim))
+                self.q_bias = nn.Parameter(torch.zeros(attn_dim, **dd))
+                self.register_buffer('k_bias', torch.zeros(attn_dim, **dd), persistent=False)
+                self.v_bias = nn.Parameter(torch.zeros(attn_dim, **dd))
             else:
                 self.q_bias = self.k_bias = self.v_bias = None
         else:
-            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
-            self.k_proj = nn.Linear(dim, attn_dim, bias=False)
-            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
+            self.k_proj = nn.Linear(dim, attn_dim, bias=False, **dd)
+            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
             self.qkv = None
             self.q_bias = self.k_bias = self.v_bias = None
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(attn_dim, dim)
+        self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(attn_dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(
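Note on the fused path above: `q_bias` and `v_bias` are learnable parameters while `k_bias` is registered as a non-persistent zero buffer, following the BEiT/EVA convention of no bias on the key projection. As a hedged illustration only (not part of this diff; the forward-pass wiring is assumed), the fused projection is typically applied by concatenating the three biases:

```python
import torch
import torch.nn.functional as F

def fused_qkv_proj(x, qkv_weight, q_bias, k_bias, v_bias):
    # k_bias is an all-zero buffer, so only q and v effectively receive a learnable bias.
    qkv_bias = torch.cat((q_bias, k_bias, v_bias)) if q_bias is not None else None
    return F.linear(x, qkv_weight, qkv_bias)
```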
@@ -263,6 +266,8 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             attn_head_dim: Optional[int] = None,
+            device=None,
+            dtype=None,
             **kwargs,
     ):
         """ Initialize the EVA transformer block.
@@ -286,8 +291,10 @@ def __init__(
             norm_layer: Normalization layer constructor
             attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.norm1 = norm_layer(dim)
+
+        self.norm1 = norm_layer(dim, **dd)
         attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
         self.attn = attn_cls(
             dim,
@@ -301,11 +308,12 @@ def __init__(
             norm_layer=norm_layer,
             scale_norm=scale_attn_inner,
             rotate_half=rotate_half,
+            **dd,
         )
-        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd)) if init_values is not None else None
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         hidden_features = int(dim * mlp_ratio)
         if swiglu_mlp:
             if scale_mlp or swiglu_align_to:
@@ -316,6 +324,7 @@ def __init__(
                     norm_layer=norm_layer if scale_mlp else None,
                     drop=proj_drop,
                     align_to=swiglu_align_to,
+                    **dd,
                 )
             else:
                 # w/o any extra norm, an impl with packed weights is used
@@ -326,6 +335,7 @@ def __init__(
                     act_layer=nn.SiLU,
                     gate_last=False,
                     drop=proj_drop,
+                    **dd,
                 )
         else:
             self.mlp = Mlp(
@@ -334,8 +344,9 @@ def __init__(
                 act_layer=act_layer,
                 norm_layer=norm_layer if scale_mlp else None,
                 drop=proj_drop,
+                **dd,
             )
-        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd)) if init_values is not None else None
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(
@@ -376,6 +387,8 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = nn.LayerNorm,
             attn_head_dim: Optional[int] = None,
+            device=None,
+            dtype=None,
     ):
         """ Initialize the post-norm EVA transformer block.
 
@@ -398,7 +411,9 @@ def __init__(
             norm_layer: Normalization layer constructor
             attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
+
        attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
         self.attn = attn_cls(
             dim,
@@ -412,8 +427,9 @@ def __init__(
             norm_layer=norm_layer,
             scale_norm=scale_attn_inner,
             rotate_half=rotate_half,
+            **dd,
         )
-        self.norm1 = norm_layer(dim)
+        self.norm1 = norm_layer(dim, **dd)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         hidden_features = int(dim * mlp_ratio)
@@ -426,6 +442,7 @@ def __init__(
                     norm_layer=norm_layer if scale_mlp else None,
                     drop=proj_drop,
                     align_to=swiglu_align_to,
+                    **dd,
                 )
             else:
                 # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP
@@ -436,6 +453,7 @@ def __init__(
                     act_layer=nn.SiLU,
                     gate_last=False,
                     drop=proj_drop,
+                    **dd,
                 )
         else:
             self.mlp = Mlp(
@@ -444,8 +462,9 @@ def __init__(
                 act_layer=act_layer,
                 norm_layer=norm_layer if scale_mlp else None,
                 drop=proj_drop,
+                **dd,
             )
-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(
@@ -513,6 +532,8 @@ def __init__(
             dynamic_img_pad: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
+            device=None,
+            dtype=None,
     ):
         """Initialize the EVA Vision Transformer model.
 
@@ -562,6 +583,7 @@ def __init__(
             head_init_scale: Initialization scale for classification head weights
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -594,16 +616,17 @@ def __init__(
             dynamic_img_pad=dynamic_img_pad,
             bias=not use_pre_transformer_norm,
             **embed_args,
+            **dd,
         )
         num_patches = self.patch_embed.num_patches
         r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
 
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
-        self.reg_token = nn.Parameter(torch.zeros(1, num_reg_tokens, embed_dim)) if num_reg_tokens else None
+        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, **dd)) if class_token else None
+        self.reg_token = nn.Parameter(torch.empty(1, num_reg_tokens, embed_dim, **dd)) if num_reg_tokens else None
         self.cls_embed = class_token and self.reg_token is None
 
         num_pos_tokens = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_pos_tokens, embed_dim)) if use_abs_pos_emb else None
+        self.pos_embed = nn.Parameter(torch.empty(1, num_pos_tokens, embed_dim, **dd)) if use_abs_pos_emb else None
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
         if patch_drop_rate > 0:
             self.patch_drop = PatchDropoutWithIndices(patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens)
@@ -621,6 +644,7 @@ def __init__(
                 feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
                 temperature=rope_temperature,
                 grid_indexing=rope_grid_indexing,
+                **dd,
             )
             if rope_type == 'mixed':
                 rope_kwargs.update(dict(depth=depth))
@@ -636,7 +660,7 @@ def __init__(
         else:
             self.rope = None
 
-        self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity()
+        self.norm_pre = norm_layer(embed_dim, **dd) if activate_pre_norm else nn.Identity()
 
         dpr = calculate_drop_path_rates(drop_path_rate, depth)  # stochastic depth decay rule
         block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
@@ -659,12 +683,13 @@ def __init__(
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
                 init_values=init_values,
+                **dd,
             )
             for i in range(depth)])
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
 
-        self.norm = norm_layer(embed_dim) if activate_post_norm else nn.Identity()
+        self.norm = norm_layer(embed_dim, **dd) if activate_post_norm else nn.Identity()
 
         if global_pool == 'map':
             self.attn_pool = AttentionPoolLatent(
@@ -673,23 +698,26 @@ def __init__(
                 mlp_ratio=attn_pool_mlp_ratio or mlp_ratio,
                 norm_layer=norm_layer,
                 act_layer=nn.GELU,
+                **dd,
             )
         else:
             self.attn_pool = None
-        self.fc_norm = norm_layer(embed_dim) if activate_fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(embed_dim, **dd) if activate_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
-        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
 
+        self.init_weights(head_init_scale=head_init_scale)
+
+    def init_weights(self, head_init_scale=None):
         self.apply(self._init_weights)
         if self.pos_embed is not None:
             trunc_normal_(self.pos_embed, std=.02)
         if self.cls_token is not None:
             trunc_normal_(self.cls_token, std=.02)
         if self.reg_token is not None:
             trunc_normal_(self.reg_token, std=.02)
-
         self.fix_init_weight()
-        if isinstance(self.head, nn.Linear):
+        if head_init_scale and isinstance(self.head, nn.Linear):
             trunc_normal_(self.head.weight, std=.02)
             self.head.weight.data.mul_(head_init_scale)
             self.head.bias.data.mul_(head_init_scale)
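Taken together, the changes thread `device`/`dtype` factory kwargs (`**dd`) through every submodule and tensor the model constructs, and move parameter initialization behind an `init_weights()` entry point. A minimal usage sketch of the intent (illustrative only; the import path, keyword choices, and end-to-end meta-device support are assumptions, not guarantees from this diff):

```python
import torch
from timm.models.eva import Eva  # assumed import path for the class patched above

# Allocate parameters directly on the target device and in the target dtype,
# instead of building in float32 on CPU and casting with a post-hoc .to(...).
model = Eva(num_classes=1000, device='cuda', dtype=torch.bfloat16)

# Deferred allocation on the meta device is another motivating use case:
# build without real storage, materialize later, then re-run the new hook.
meta_model = Eva(num_classes=1000, device='meta')
meta_model = meta_model.to_empty(device='cpu')
meta_model.init_weights(head_init_scale=0.001)
```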