Skip to content

Commit bfeba12

Browse files
authored
Add Video Swin Transformer (#2369)
* init video swin * add: 3d window size computation * add: mlp layer * add: patch embedding layer * add: patch merging layer * add: window attention layer * add: basic layer for video swin * update: basic layer for video swin * add: swin blocks for video swin * create and add: video swin backbone * rename: video swin layers to model specific * update module import * update module import * set class method to private usage * set init params for backbone * rm redundant imports * add video swin layer test cases * add: videoswin backbone aliases * add: video swin backbone presets * add: video swin backbone presets test * update: video swin backbone presets test * add: video classifier task * add: video swin classifier presets * run formatters * rename module name/id" * add hard-coded normalization for include rescaling=true * add docstring for videoswin backbone * update metadata: backbone presets no weights * update: backbone presets no weights test * update video swin aliases for no weights * add: video swin backbone presets with weights * update: video swin aliases with weights presets * update video swin layer test cases * added patch merging test * imported video swins presets to backbone presets list" * fix: typos" * run formatters" * fix: linting issue * fix: linting issue * fix: video swin layer test cases" * add: video swin backbone test * rm redundant code * disable preset test temporary * set include rescale to true * add video swin components to __init__ * update docstrings: video siwn layers scripts * update copywrite status: video siwn layers test scripts * update copywrite status: video siwn backbone scripts * bug fixes: video swin backbone layers * update get config of video swin backbone * enable: video swin backbone test cases * update: video swin backbone test cases * update: video swin backbone preset test cases * run formatters * fix typos: video swin backbone test cases * add: non implemented property for test reason * fix: typos * add: video classifier test * update: video classifier test * update: video classifier test input shape * bug fix: mlp layer build method * updated: swin back layer build method * bug fix: use tf.TensorShape in compute_output_shape method * update: video_classifier_test model.predict to model.call * update test cases and format the code * update docstrings and preset config * fix jax DynamicJaxprTrace issue for * update config of backbone aliases * add can run in mixed precision test * add can run on gray video * minor fix * specify axis in keras.ops.take to match with tf.gather * specify include rescaling to backbone class * remove shift size form get config of video basic layer * add support arbitrary input shape * minor updates to swin layers * test method update for swin layers * update test method to swin backbone * remove unsed code * bug fix in call method of patch embed layer * fix typo in patch merging layer * minor fix * fix keras.ops.cond issue with jax * no test for jit compile in torch * reduce tensor size for forward test * minor fix * remove kcv export decorator * update keras.Layer import * remove unused layer import * replace keras.layers instead of layers * update keras.Layer to keras.layers.Layer for keras2 * add window_size param to aliases * move vide swin layer to model specific directory * minor fix
1 parent c123d51 commit bfeba12

13 files changed

+2171
-0
lines changed

keras_cv/models/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,24 @@
179179
ResNetV2Backbone,
180180
)
181181
from keras_cv.models.backbones.vgg16.vgg16_backbone import VGG16Backbone
182+
from keras_cv.models.backbones.video_swin.video_swin_aliases import (
183+
VideoSwinBBackbone,
184+
)
185+
from keras_cv.models.backbones.video_swin.video_swin_aliases import (
186+
VideoSwinSBackbone,
187+
)
188+
from keras_cv.models.backbones.video_swin.video_swin_aliases import (
189+
VideoSwinTBackbone,
190+
)
191+
from keras_cv.models.backbones.video_swin.video_swin_backbone import (
192+
VideoSwinBackbone,
193+
)
182194
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetBBackbone
183195
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetHBackbone
184196
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetLBackbone
185197
from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone
186198
from keras_cv.models.classification.image_classifier import ImageClassifier
199+
from keras_cv.models.classification.video_classifier import VideoClassifier
187200
from keras_cv.models.feature_extractor.clip import CLIP
188201
from keras_cv.models.object_detection.retinanet.retinanet import RetinaNet
189202
from keras_cv.models.object_detection.yolo_v8.yolo_v8_backbone import (

keras_cv/models/backbones/backbone_presets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from keras_cv.models.backbones.mobilenet_v3 import mobilenet_v3_backbone_presets
2929
from keras_cv.models.backbones.resnet_v1 import resnet_v1_backbone_presets
3030
from keras_cv.models.backbones.resnet_v2 import resnet_v2_backbone_presets
31+
from keras_cv.models.backbones.video_swin import video_swin_backbone_presets
3132
from keras_cv.models.backbones.vit_det import vit_det_backbone_presets
3233
from keras_cv.models.object_detection.yolo_v8 import yolo_v8_backbone_presets
3334

@@ -42,6 +43,7 @@
4243
**efficientnet_lite_backbone_presets.backbone_presets_no_weights,
4344
**yolo_v8_backbone_presets.backbone_presets_no_weights,
4445
**vit_det_backbone_presets.backbone_presets_no_weights,
46+
**video_swin_backbone_presets.backbone_presets_no_weights,
4547
}
4648

4749
backbone_presets_with_weights = {
@@ -55,6 +57,7 @@
5557
**efficientnet_lite_backbone_presets.backbone_presets_with_weights,
5658
**yolo_v8_backbone_presets.backbone_presets_with_weights,
5759
**vit_det_backbone_presets.backbone_presets_with_weights,
60+
**video_swin_backbone_presets.backbone_presets_with_weights,
5861
}
5962

6063
backbone_presets = {
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright 2024 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import copy
16+
17+
from keras_cv.models.backbones.video_swin.video_swin_backbone import (
18+
VideoSwinBackbone,
19+
)
20+
from keras_cv.models.backbones.video_swin.video_swin_backbone_presets import (
21+
backbone_presets,
22+
)
23+
from keras_cv.utils.python_utils import classproperty
24+
25+
ALIAS_DOCSTRING = """VideoSwin{size}Backbone model.
26+
27+
Reference:
28+
- [Video Swin Transformer](https://arxiv.org/abs/2106.13230)
29+
- [Video Swin Transformer GitHub](https://github.com/SwinTransformer/Video-Swin-Transformer)
30+
31+
For transfer learning use cases, make sure to read the
32+
[guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/).
33+
34+
Examples:
35+
```python
36+
input_data = np.ones(shape=(1, 32, 224, 224, 3))
37+
38+
# Randomly initialized backbone
39+
model = VideoSwin{size}Backbone()
40+
output = model(input_data)
41+
```
42+
""" # noqa: E501
43+
44+
45+
class VideoSwinTBackbone(VideoSwinBackbone):
46+
def __new__(
47+
cls,
48+
embed_dim=96,
49+
depths=[2, 2, 6, 2],
50+
num_heads=[3, 6, 12, 24],
51+
window_size=[8, 7, 7],
52+
include_rescaling=True,
53+
**kwargs,
54+
):
55+
kwargs.update(
56+
{
57+
"embed_dim": embed_dim,
58+
"depths": depths,
59+
"num_heads": num_heads,
60+
"window_size": window_size,
61+
"include_rescaling": include_rescaling,
62+
}
63+
)
64+
return VideoSwinBackbone.from_preset("videoswin_tiny", **kwargs)
65+
66+
@classproperty
67+
def presets(cls):
68+
"""Dictionary of preset names and configurations."""
69+
return {
70+
"videoswin_tiny_kinetics400": copy.deepcopy(
71+
backbone_presets["videoswin_tiny_kinetics400"]
72+
),
73+
}
74+
75+
@classproperty
76+
def presets_with_weights(cls):
77+
"""Dictionary of preset names and configurations that include
78+
weights."""
79+
return cls.presets
80+
81+
82+
class VideoSwinSBackbone(VideoSwinBackbone):
83+
def __new__(
84+
cls,
85+
embed_dim=96,
86+
depths=[2, 2, 18, 2],
87+
num_heads=[3, 6, 12, 24],
88+
window_size=[8, 7, 7],
89+
include_rescaling=True,
90+
**kwargs,
91+
):
92+
kwargs.update(
93+
{
94+
"embed_dim": embed_dim,
95+
"depths": depths,
96+
"num_heads": num_heads,
97+
"window_size": window_size,
98+
"include_rescaling": include_rescaling,
99+
}
100+
)
101+
return VideoSwinBackbone.from_preset("videoswin_small", **kwargs)
102+
103+
@classproperty
104+
def presets(cls):
105+
"""Dictionary of preset names and configurations."""
106+
return {
107+
"videoswin_small_kinetics400": copy.deepcopy(
108+
backbone_presets["videoswin_small_kinetics400"]
109+
),
110+
}
111+
112+
@classproperty
113+
def presets_with_weights(cls):
114+
"""Dictionary of preset names and configurations that include
115+
weights."""
116+
return cls.presets
117+
118+
119+
class VideoSwinBBackbone(VideoSwinBackbone):
120+
def __new__(
121+
cls,
122+
embed_dim=128,
123+
depths=[2, 2, 18, 2],
124+
num_heads=[4, 8, 16, 32],
125+
window_size=[8, 7, 7],
126+
include_rescaling=True,
127+
**kwargs,
128+
):
129+
kwargs.update(
130+
{
131+
"embed_dim": embed_dim,
132+
"depths": depths,
133+
"num_heads": num_heads,
134+
"window_size": window_size,
135+
"include_rescaling": include_rescaling,
136+
}
137+
)
138+
return VideoSwinBackbone.from_preset("videoswin_base", **kwargs)
139+
140+
@classproperty
141+
def presets(cls):
142+
"""Dictionary of preset names and configurations."""
143+
return {
144+
"videoswin_base_kinetics400": copy.deepcopy(
145+
backbone_presets["videoswin_base_kinetics400"]
146+
),
147+
}
148+
149+
@classproperty
150+
def presets_with_weights(cls):
151+
"""Dictionary of preset names and configurations that include
152+
weights."""
153+
return cls.presets
154+
155+
156+
setattr(VideoSwinTBackbone, "__doc__", ALIAS_DOCSTRING.format(size="T"))
157+
setattr(VideoSwinSBackbone, "__doc__", ALIAS_DOCSTRING.format(size="S"))
158+
setattr(VideoSwinBBackbone, "__doc__", ALIAS_DOCSTRING.format(size="B"))

0 commit comments

Comments
 (0)