Skip to content

Commit 9c08586

Browse files
committed
Add tests and docstrings.
1 parent 0c1dce1 commit 9c08586

File tree

2 files changed

+289
-21
lines changed

2 files changed

+289
-21
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import os
2+
3+
import keras
4+
import pytest
5+
from keras import ops
6+
7+
from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone
8+
from keras_hub.src.tests.test_case import TestCase
9+
10+
11+
class DINOV3BackboneTest(TestCase):
12+
def setUp(self):
13+
self.init_kwargs = {
14+
"patch_size": 14,
15+
"num_layers": 2,
16+
"hidden_dim": 16,
17+
"num_heads": 2,
18+
"intermediate_dim": 16 * 4,
19+
"layer_scale_init_value": 1.0,
20+
"num_register_tokens": 4,
21+
"use_gated_mlp": False,
22+
"image_shape": (70, 70, 3),
23+
"name": "dinov3_backbone",
24+
}
25+
self.input_data = {
26+
"images": ops.ones((2, 70, 70, 3)),
27+
}
28+
29+
def test_backbone_basics(self):
30+
patch_size = self.init_kwargs["patch_size"]
31+
image_size = self.init_kwargs["image_shape"][0]
32+
hidden_dim = self.init_kwargs["hidden_dim"]
33+
num_register_tokens = self.init_kwargs["num_register_tokens"]
34+
sequence_length = (
35+
(image_size // patch_size) ** 2 + 1 + num_register_tokens
36+
)
37+
self.run_vision_backbone_test(
38+
cls=DINOV3Backbone,
39+
init_kwargs=self.init_kwargs,
40+
input_data=self.input_data,
41+
expected_output_shape=(2, sequence_length, hidden_dim),
42+
expected_pyramid_output_keys=["stem", "stage1", "stage2"],
43+
expected_pyramid_image_sizes=[(sequence_length, hidden_dim)] * 3,
44+
run_data_format_check=False,
45+
)
46+
47+
@pytest.mark.large
48+
def test_saved_model(self):
49+
self.run_model_saving_test(
50+
cls=DINOV3Backbone,
51+
init_kwargs=self.init_kwargs,
52+
input_data=self.input_data,
53+
)
54+
55+
@pytest.mark.large
56+
def test_position_embedding_interpolation(self):
57+
model = DINOV3Backbone(**self.init_kwargs)
58+
model_output = model(self.input_data)
59+
60+
# Test not using interpolation in `save` and `load_model`.
61+
path = os.path.join(self.get_temp_dir(), "model.keras")
62+
model.save(path)
63+
restored_model = keras.models.load_model(path)
64+
restored_output = restored_model(self.input_data)
65+
self.assertAllClose(model_output, restored_output, atol=1e-5, rtol=1e-5)
66+
67+
# Test using interpolation in `save_to_preset` and `from_preset` if
68+
# image_shape is different.
69+
path = os.path.join(self.get_temp_dir(), "model")
70+
model.save_to_preset(path)
71+
restored_model = DINOV3Backbone.from_preset(
72+
path,
73+
image_shape=(128, 128, 3), # From 70 to 128.
74+
)
75+
input_data = {
76+
"images": ops.ones((2, 128, 128, 3)),
77+
}
78+
restored_output = restored_model(input_data)
79+
self.assertNotEqual(model_output.shape, restored_output.shape)
80+
81+
@pytest.mark.kaggle_key_required
82+
@pytest.mark.extra_large
83+
def test_smallest_preset(self):
84+
self.skipTest("Presets are not uploaded yet.")
85+
self.run_preset_test(
86+
cls=DINOV3Backbone,
87+
preset="dinov3_vit_small_lvd1689m",
88+
input_data=self.input_data,
89+
expected_output_shape=(2, 1374, 768),
90+
)
91+
92+
@pytest.mark.kaggle_key_required
93+
@pytest.mark.extra_large
94+
def test_all_presets(self):
95+
self.skipTest("Presets are not uploaded yet.")
96+
for preset in DINOV3Backbone.presets:
97+
self.run_preset_test(
98+
cls=DINOV3Backbone,
99+
preset=preset,
100+
input_data=self.input_data,
101+
)

0 commit comments

Comments
 (0)