109 changes: 105 additions & 4 deletions timm/models/fastvit.py
@@ -12,9 +12,17 @@
import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
ClassifierHead
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import (
DropPath,
trunc_normal_,
create_conv2d,
ConvNormAct,
SqueezeExcite,
use_fused_attn,
ClassifierHead,
LayerNorm2d,
)
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
@@ -427,7 +435,8 @@ def convolutional_stem(
in_chs: int,
out_chs: int,
act_layer: Type[nn.Module] = nn.GELU,
inference_mode: bool = False
inference_mode: bool = False,
use_scale_branch: bool = True,
) -> nn.Sequential:
"""Build convolutional stem with MobileOne blocks.

@@ -447,6 +456,7 @@
stride=2,
act_layer=act_layer,
inference_mode=inference_mode,
use_scale_branch=use_scale_branch,
),
MobileOneBlock(
in_chs=out_chs,
@@ -456,6 +466,7 @@
group_size=1,
act_layer=act_layer,
inference_mode=inference_mode,
use_scale_branch=use_scale_branch,
),
MobileOneBlock(
in_chs=out_chs,
@@ -464,6 +475,7 @@
stride=1,
act_layer=act_layer,
inference_mode=inference_mode,
use_scale_branch=use_scale_branch,
),
)

@@ -1118,6 +1130,7 @@ def __init__(
drop_path_rate: float = 0.0,
layer_scale_init_value: float = 1e-5,
lkc_use_act: bool = False,
stem_use_scale_branch: bool = True,
fork_feat: bool = False,
cls_ratio: float = 2.0,
global_pool: str = 'avg',
@@ -1137,6 +1150,7 @@ def __init__(
embed_dims[0],
act_layer,
inference_mode,
use_scale_branch=stem_use_scale_branch,
)

# Build the main stages of the network architecture
@@ -1412,6 +1426,39 @@ def _cfg(url="", **kwargs):
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.)
),

"fastvit_mci0.apple_mclip2_dfndr2b": _cfg(
hf_hub_id='timm/',
crop_pct=1.0,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.),
license='apple-amlr'
),
"fastvit_mci2.apple_mclip2_dfndr2b": _cfg(
hf_hub_id='timm/',
crop_pct=0.95,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.),
license='apple-amlr'
),
"fastvit_mci3.apple_mclip2_dfndr2b": _cfg(
hf_hub_id='timm/',
crop_pct=0.95,
num_classes=768, # CLIP proj dim
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
pool_size=(4, 4),
first_conv='stem.0.conv_kxk.0.conv',
license='apple-amlr'
),
"fastvit_mci4.apple_mclip2_dfndr2b": _cfg(
hf_hub_id='timm/',
crop_pct=0.95,
num_classes=768, # CLIP proj dim
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
pool_size=(4, 4),
first_conv='stem.0.conv_kxk.0.conv',
license='apple-amlr'
),
})


@@ -1420,6 +1467,9 @@ def checkpoint_filter_fn(state_dict, model):
if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
return state_dict # non-original checkpoint, no remapping needed

if 'module.visual.trunk.stem.0.conv_kxk.0.conv.weight' in state_dict:
# remap MobileCLIP-2 image encoder checkpoints, stripping the 'module.visual.trunk.' prefix
return {k.replace('module.visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('module.visual.trunk')}

state_dict = state_dict.get('state_dict', state_dict)
if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
# remap MobileCLIP checkpoints
@@ -1632,3 +1682,54 @@ def fastvit_mci2(pretrained=False, **kwargs):
lkc_use_act=True,
)
return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def fastvit_mci3(pretrained=False, **kwargs):
"""Instantiate L model variant."""
model_args = dict(
layers=(2, 12, 24, 4, 2),
embed_dims=(96, 192, 384, 768, 1536),
mlp_ratios=(4, 4, 4, 4, 4),
se_downsamples=(False, False, False, False, False),
downsamples=(False, True, True, True, True),
pos_embs=(
None,
None,
None,
partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
partial(RepConditionalPosEnc, spatial_shape=(7, 7))
),
token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
lkc_use_act=True,
norm_layer=partial(LayerNorm2d, eps=1e-5),
stem_use_scale_branch=False,
)
model = _create_fastvit('fastvit_mci3', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def fastvit_mci4(pretrained=False, **kwargs):
"""Instantiate XL model variant."""
model_args = dict(
layers=(2, 12, 24, 4, 4),
embed_dims=(128, 256, 512, 1024, 2048),
mlp_ratios=(4, 4, 4, 4, 4),
se_downsamples=(False, False, False, False, False),
downsamples=(False, True, True, True, True),
pos_embs=(
None,
None,
None,
partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
partial(RepConditionalPosEnc, spatial_shape=(7, 7))
),
token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
lkc_use_act=True,
norm_layer=partial(LayerNorm2d, eps=1e-5),
stem_use_scale_branch=False,
)

model = _create_fastvit('fastvit_mci4', pretrained=pretrained, **dict(model_args, **kwargs))
return model
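
Note: for reference, a minimal usage sketch for the new MobileCLIP2 FastViT image towers registered above. This assumes a timm build that includes this PR and the published timm/ Hub weights; the model/tag name and the expected 768-dim output come from the fastvit_mci4 cfg added above.

import torch
import timm

# fastvit_mci4 cfg above: num_classes=768 (CLIP proj dim), OPENAI_CLIP mean/std, crop_pct=0.95
model = timm.create_model('fastvit_mci4.apple_mclip2_dfndr2b', pretrained=True)  # set False to skip the Hub download
model = model.eval()

data_cfg = timm.data.resolve_model_data_config(model)                  # mean/std/input_size resolved from the cfg
transform = timm.data.create_transform(**data_cfg, is_training=False)  # what you would apply to a PIL image

with torch.no_grad():
    emb = model(torch.randn(1, *data_cfg['input_size']))  # random tensor stands in for a preprocessed image
print(emb.shape)  # expected: torch.Size([1, 768]), the CLIP image embedding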
9 changes: 9 additions & 0 deletions timm/models/vision_transformer.py
@@ -1416,6 +1416,8 @@ def checkpoint_filter_fn(
# remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
out_dict['head.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
elif 'module.visual.trunk.pos_embed' in state_dict:
prefix = 'module.visual.trunk.'
elif 'preprocessor.patchifier.proj.weight' in state_dict:
state_dict = _convert_aimv2(state_dict, model)

@@ -2007,6 +2009,13 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),

'vit_large_patch14_clip_224.apple_mclip2_dfndr2b': _cfg(
hf_hub_id='timm/',
num_classes=768,
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
license='apple-amlr'
),

# experimental (may be removed)
'vit_base_patch32_plus_256.untrained': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
'vit_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
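
Note: the new 'module.visual.trunk.pos_embed' branch only sets prefix; the prefix handling already present further down checkpoint_filter_fn then keeps the image-trunk keys and strips the prefix, the same way other open_clip-style checkpoints are handled. A toy sketch of that remap, with made-up keys:

# hypothetical keys from a DDP-saved open_clip checkpoint
prefix = 'module.visual.trunk.'
ckpt = {
    'module.visual.trunk.pos_embed': 'kept',  # image tower key -> becomes 'pos_embed'
    'module.logit_scale': 'dropped',          # anything outside the trunk is discarded
}
remapped = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
print(remapped)  # {'pos_embed': 'kept'}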
6 changes: 6 additions & 0 deletions timm/models/vision_transformer_hybrid.py
@@ -230,6 +230,12 @@ def _cfg(url='', **kwargs):
num_classes=512,
mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
),
'vit_base_mci_224.apple_mclip2_dfndr2b': _cfg(
hf_hub_id='timm/',
num_classes=512,
mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
license='apple-amlr'
),
})


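Note: like the MobileCLIP entry directly above it, the new hybrid tag keeps mean=(0., 0., 0.) and std=(1., 1., 1.), so the tower expects plain [0, 1] inputs and the Normalize step is effectively a no-op. A quick way to confirm what preprocessing a tag resolves to (same timm-version assumption as the earlier sketch):

import timm

model = timm.create_model('vit_base_mci_224.apple_mclip2_dfndr2b', pretrained=False)  # tag cfg resolves even without weights
data_cfg = timm.data.resolve_model_data_config(model)
print(data_cfg['mean'], data_cfg['std'])  # (0.0, 0.0, 0.0) (1.0, 1.0, 1.0)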