diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py
index 96e1d593c6..9939095c43 100644
--- a/timm/models/fastvit.py
+++ b/timm/models/fastvit.py
@@ -12,9 +12,17 @@
 import torch
 import torch.nn as nn
 
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
-    ClassifierHead
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from timm.layers import (
+    DropPath,
+    trunc_normal_,
+    create_conv2d,
+    ConvNormAct,
+    SqueezeExcite,
+    use_fused_attn,
+    ClassifierHead,
+    LayerNorm2d,
+)
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
@@ -427,7 +435,8 @@ def convolutional_stem(
         in_chs: int,
         out_chs: int,
         act_layer: Type[nn.Module] = nn.GELU,
-        inference_mode: bool = False
+        inference_mode: bool = False,
+        use_scale_branch: bool = True,
 ) -> nn.Sequential:
     """Build convolutional stem with MobileOne blocks.
 
@@ -447,6 +456,7 @@ def convolutional_stem(
             stride=2,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
         MobileOneBlock(
             in_chs=out_chs,
@@ -456,6 +466,7 @@ def convolutional_stem(
             group_size=1,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
         MobileOneBlock(
             in_chs=out_chs,
@@ -464,6 +475,7 @@ def convolutional_stem(
             stride=1,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
     )
 
@@ -1118,6 +1130,7 @@ def __init__(
             drop_path_rate: float = 0.0,
             layer_scale_init_value: float = 1e-5,
             lkc_use_act: bool = False,
+            stem_use_scale_branch: bool = True,
             fork_feat: bool = False,
             cls_ratio: float = 2.0,
             global_pool: str = 'avg',
@@ -1137,6 +1150,7 @@ def __init__(
             embed_dims[0],
             act_layer,
             inference_mode,
+            use_scale_branch=stem_use_scale_branch,
         )
 
         # Build the main stages of the network architecture
@@ -1412,6 +1426,39 @@ def _cfg(url="", **kwargs):
         num_classes=512,  # CLIP proj dim
         mean=(0., 0., 0.), std=(1., 1., 1.)
     ),
+
+    "fastvit_mci0.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=1.0,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.),
+        license='apple-amlr'
+    ),
+    "fastvit_mci2.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.),
+        license='apple-amlr'
+    ),
+    "fastvit_mci3.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=768,  # CLIP proj dim
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        pool_size=(4, 4),
+        first_conv='stem.0.conv_kxk.0.conv',
+        license='apple-amlr'
+    ),
+    "fastvit_mci4.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=768,  # CLIP proj dim
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        pool_size=(4, 4),
+        first_conv='stem.0.conv_kxk.0.conv',
+        license='apple-amlr'
+    ),
 })
 
 
@@ -1420,6 +1467,9 @@ def checkpoint_filter_fn(state_dict, model):
     if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
 
+    if 'module.visual.trunk.stem.0.conv_kxk.0.conv.weight' in state_dict:
+        return {k.replace('module.visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('module.visual.trunk')}
+
     state_dict = state_dict.get('state_dict', state_dict)
     if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
         # remap MobileCLIP checkpoints
@@ -1632,3 +1682,54 @@ def fastvit_mci2(pretrained=False, **kwargs):
         lkc_use_act=True,
     )
     return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def fastvit_mci3(pretrained=False, **kwargs):
+    """Instantiate L model variant."""
+    model_args = dict(
+        layers=(2, 12, 24, 4, 2),
+        embed_dims=(96, 192, 384, 768, 1536),
+        mlp_ratios=(4, 4, 4, 4, 4),
+        se_downsamples=(False, False, False, False, False),
+        downsamples=(False, True, True, True, True),
+        pos_embs=(
+            None,
+            None,
+            None,
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7))
+        ),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
+        lkc_use_act=True,
+        norm_layer=partial(LayerNorm2d, eps=1e-5),
+        stem_use_scale_branch=False,
+    )
+    model = _create_fastvit('fastvit_mci3', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def fastvit_mci4(pretrained=False, **kwargs):
+    """Instantiate XL model variant."""
+    model_args = dict(
+        layers=(2, 12, 24, 4, 4),
+        embed_dims=(128, 256, 512, 1024, 2048),
+        mlp_ratios=(4, 4, 4, 4, 4),
+        se_downsamples=(False, False, False, False, False),
+        downsamples=(False, True, True, True, True),
+        pos_embs=(
+            None,
+            None,
+            None,
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7))
+        ),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
+        lkc_use_act=True,
+        norm_layer=partial(LayerNorm2d, eps=1e-5),
+        stem_use_scale_branch=False,
+    )
+
+    model = _create_fastvit('fastvit_mci4', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
\ No newline at end of file
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 498f3178c5..de682c6d65 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -1416,6 +1416,8 @@ def checkpoint_filter_fn(
             # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
             out_dict['head.weight'] = state_dict['visual.head.proj.weight']
             out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+    elif 'module.visual.trunk.pos_embed' in state_dict:
+        prefix = 'module.visual.trunk.'
     elif 'preprocessor.patchifier.proj.weight' in state_dict:
         state_dict = _convert_aimv2(state_dict, model)
 
@@ -2007,6 +2009,13 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
         input_size=(3, 336, 336), num_classes=768),
 
+    'vit_large_patch14_clip_224.apple_mclip2_dfndr2b': _cfg(
+        hf_hub_id='timm/',
+        num_classes=768,
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
+        license='apple-amlr'
+    ),
+
     # experimental (may be removed)
     'vit_base_patch32_plus_256.untrained': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index 0ff4823497..7290bb7697 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -230,6 +230,12 @@ def _cfg(url='', **kwargs):
         num_classes=512, mean=(0., 0., 0.), std=(1., 1., 1.),
         first_conv='patch_embed.backbone.0.conv',
     ),
+    'vit_base_mci_224.apple_mclip2_dfndr2b': _cfg(
+        hf_hub_id='timm/',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
+        license='apple-amlr'
+    ),
 })
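The new variants and pretrained tags above go through the standard `timm` factory. A minimal usage sketch, assuming only `timm.create_model` / `timm.data` helpers and the `fastvit_mci3.apple_mclip2_dfndr2b` tag defined in this diff; `pretrained=True` additionally assumes the corresponding weights have been published on the hub:

```python
import torch
import timm

# Instantiate the MCi3 image tower added in this PR. num_classes=0 drops the
# CLIP projection head and returns pooled backbone features instead.
model = timm.create_model('fastvit_mci3.apple_mclip2_dfndr2b', pretrained=False, num_classes=0)
model.eval()

# Resolve preprocessing from the pretrained cfg declared above
# (OPENAI_CLIP_MEAN/STD, crop_pct=0.95 for the mci3/mci4 tags).
data_cfg = timm.data.resolve_data_config({}, model=model)
transform = timm.data.create_transform(**data_cfg)

with torch.no_grad():
    feats = model(torch.randn(1, *data_cfg['input_size']))
print(feats.shape)
```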