Commit c7509b5

Add Gemma3 Conversion script to port weights from HF directly
1 parent e74791e commit c7509b5

3 files changed: +503 -0 lines changed
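
For context, the converter added below is what lets KerasHub build a Gemma3Backbone directly from a Hugging Face checkpoint. A minimal usage sketch, where the model handle is illustrative and not part of this commit:

import keras_hub

# Weights are downloaded from the Hugging Face Hub and ported through the
# new convert_gemma3 module under the hood. The handle below is an example;
# any checkpoint whose config reports model_type "gemma3" or "gemma3_text"
# should follow the same path.
backbone = keras_hub.models.Gemma3Backbone.from_preset(
    "hf://google/gemma-3-1b-it"
)
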
keras_hub/src/utils/transformers/convert_gemma3.py

Lines changed: 335 additions & 0 deletions
@@ -0,0 +1,335 @@
import numpy as np

from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
    Gemma3VisionEncoder,
)
from keras_hub.src.utils.preset_utils import get_file

backbone_cls = Gemma3Backbone


def load_image_converter_config(transformers_config):
    if "vision_config" in transformers_config:
        image_size = transformers_config["vision_config"].get("image_size", 224)
        return {
            "image_size": (image_size, image_size),
            "scale": 1 / 127.5,
            "offset": -1.0,
        }
    else:
        return None

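
# convert_backbone_config translates the Hugging Face config dict into
# Gemma3Backbone constructor kwargs. Text-only checkpoints
# ("model_type": "gemma3_text") use the flat config and get no vision
# encoder; multimodal checkpoints read "vision_config" and "text_config".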
def convert_backbone_config(transformers_config):
    if transformers_config["model_type"] == "gemma3_text":
        image_size = None
        vision_encoder = None
        transformer_config = transformers_config
    else:
        image_size = transformers_config["vision_config"].get("image_size", 224)
        vision_encoder_config = {
            "image_size": image_size,
            "patch_size": transformers_config["vision_config"].get(
                "patch_size", 16
            ),
            "num_heads": transformers_config["vision_config"].get(
                "num_attention_heads", 12
            ),
            "hidden_dim": transformers_config["vision_config"].get(
                "hidden_size", 768
            ),
            "num_layers": transformers_config["vision_config"].get(
                "num_hidden_layers", 12
            ),
            "intermediate_dim": transformers_config["vision_config"].get(
                "intermediate_size", 3072
            ),
            "output_dim": 2560,
            "pool_size": 4,
            "layer_norm_epsilon": transformers_config["vision_config"].get(
                "layer_norm_eps", 1e-6
            ),
        }
        vision_encoder = Gemma3VisionEncoder(**vision_encoder_config)
        transformer_config = transformers_config["text_config"]

    return {
        "vocabulary_size": transformer_config.get(
            "vocab_size", 262144 if vision_encoder is None else 262208
        ),
        "image_size": image_size,
        "num_layers": transformer_config.get("num_hidden_layers", 26),
        "num_query_heads": transformer_config.get("num_attention_heads", 8),
        "num_key_value_heads": transformer_config.get("num_key_value_heads", 4),
        "hidden_dim": transformer_config.get("hidden_size", 2304),
        "intermediate_dim": transformer_config.get("intermediate_size", 9216),
        "head_dim": transformer_config.get("head_dim", 256),
        "use_post_ffw_norm": True,
        "use_post_attention_norm": True,
        "attention_logit_softcap": transformer_config.get(
            "attn_logit_softcap", None
        ),
        "final_logit_softcap": transformer_config.get(
            "final_logit_softcap", None
        ),
        "use_sliding_window_attention": True,
        "query_head_dim_normalize": True,
        "sliding_window_size": transformer_config.get("sliding_window", 4096),
        "local_rope_scaling_factor": 1.0,
        "global_rope_scaling_factor": (
            transformer_config.get("rope_scaling") or {}
        ).get("factor", 1.0),
        "layer_norm_epsilon": transformer_config.get("rms_norm_eps", 1e-6),
        "use_bidirectional_attention": transformer_config.get(
            "use_bidirectional_attention", False
        ),
        "vision_encoder": vision_encoder,
    }

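
# convert_weights copies each Hugging Face tensor into the matching Keras
# variable via loader.port_weight; the hook_fn callbacks transpose and
# reshape tensors where the two frameworks use different kernel layouts.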
def convert_weights(backbone, loader, transformers_config):
    if transformers_config["model_type"] == "gemma3_text":
        prefix = "model"
    else:
        prefix = "language_model.model"

    loader.port_weight(
        keras_variable=backbone.get_layer("token_embedding").embeddings,
        hf_weight_key=f"{prefix}.embed_tokens.weight",
    )

    def transpose(x, shape):
        return np.transpose(x)

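    # Vision encoder and multimodal projector weights. These are only
    # ported when the checkpoint is multimodal; text-only "gemma3_text"
    # checkpoints have no vision tower.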
    vision_encoder = backbone.vision_encoder
    if vision_encoder is not None:
        image_encoder = vision_encoder.get_layer("image_encoder")

        loader.port_weight(
            keras_variable=image_encoder.vision_embeddings.patch_embedding.kernel,
            hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.weight",
            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
        )
        loader.port_weight(
            keras_variable=image_encoder.vision_embeddings.patch_embedding.bias,
            hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.bias",
        )

        loader.port_weight(
            keras_variable=image_encoder.vision_embeddings.position_embedding.embeddings,
            hf_weight_key="vision_tower.vision_model.embeddings.position_embedding.weight",
        )

        for i in range(image_encoder.num_layers):
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].layer_norm_1.gamma,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].layer_norm_1.beta,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[
                    i
                ].attn.query_proj.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.query_proj.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.key_proj.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.key_proj.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[
                    i
                ].attn.value_proj.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.value_proj.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.out_proj.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.out_proj.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias",
            )

            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].layer_norm_2.gamma,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].layer_norm_2.beta,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].mlp_dense_1.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].mlp_dense_1.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].mlp_dense_2.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].mlp_dense_2.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias",
            )

        loader.port_weight(
            keras_variable=image_encoder.encoder_layer_norm.gamma,
            hf_weight_key="vision_tower.vision_model.post_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=image_encoder.encoder_layer_norm.beta,
            hf_weight_key="vision_tower.vision_model.post_layernorm.bias",
        )

        loader.port_weight(
            keras_variable=vision_encoder.get_layer(
                "vision_output_encoder"
            ).vision_soft_embedding_norm.scale,
            hf_weight_key="multi_modal_projector.mm_soft_emb_norm.weight",
        )

        loader.port_weight(
            keras_variable=vision_encoder.get_layer(
                "vision_output_encoder"
            ).vision_input_projection.kernel,
            hf_weight_key="multi_modal_projector.mm_input_projection_weight",
        )

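    # Transformer decoder blocks. HF stores the q/k/v/o projections as 2D
    # (out_features, in_features) matrices; the hook_fn callbacks reshape
    # and transpose them into the per-head 3D kernel layout used by the
    # Keras attention layers.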
    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"decoder_block_{i}")

        loader.port_weight(
            keras_variable=decoder_layer.pre_attention_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.input_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer.post_attention_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.post_attention_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer.pre_ffw_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.pre_feedforward_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer.post_ffw_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.post_feedforward_layernorm.weight",
        )

        # Attention layers

        ## Query
        loader.port_weight(
            keras_variable=decoder_layer.attention.query_dense.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.q_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer.attention.query_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.q_norm.weight",
        )
        ## Key
        loader.port_weight(
            keras_variable=decoder_layer.attention.key_dense.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.k_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer.attention.key_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.k_norm.weight",
        )
        ## Value
        loader.port_weight(
            keras_variable=decoder_layer.attention.value_dense.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.v_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        ## Output
        loader.port_weight(
            keras_variable=decoder_layer.attention.output_dense.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.o_proj.weight",
            # rearrange_patterns="c (a b) -> a b c",
            # rearrange_dims={"a": backbone.num_query_heads},
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[2], keras_shape[0], keras_shape[1]),
                ),
                axes=(1, 2, 0),
            ),
        )

        # MLP layers
        loader.port_weight(
            keras_variable=decoder_layer.gating_ffw.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.mlp.gate_proj.weight",
            # rearrange_patterns="b a -> a b",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        loader.port_weight(
            keras_variable=decoder_layer.gating_ffw_2.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.mlp.up_proj.weight",
            # rearrange_patterns="b a -> a b",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        loader.port_weight(
            keras_variable=decoder_layer.ffw_linear.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.mlp.down_proj.weight",
            # rearrange_patterns="b a -> a b",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

    # Final normalization layer
    loader.port_weight(
        keras_variable=backbone.get_layer("final_normalization").scale,
        hf_weight_key=f"{prefix}.norm.weight",
    )

    return backbone


def convert_tokenizer(cls, preset, **kwargs):
    return cls(get_file(preset, "tokenizer.model"), **kwargs)
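
As a sanity check on the config mapping above, a minimal sketch of calling convert_backbone_config on a hand-written text-only config; the values are illustrative, not read from a real checkpoint:

from keras_hub.src.utils.transformers.convert_gemma3 import (
    convert_backbone_config,
)

# A text-only config: no "vision_config"/"text_config" nesting, so the
# flat dict is used as-is and no vision encoder is constructed.
dummy_config = {
    "model_type": "gemma3_text",
    "num_hidden_layers": 26,
    "num_attention_heads": 8,
    "num_key_value_heads": 4,
    "hidden_size": 2304,
    "intermediate_size": 9216,
    "head_dim": 256,
    "rms_norm_eps": 1e-6,
}
kwargs = convert_backbone_config(dummy_config)
assert kwargs["vision_encoder"] is None
assert kwargs["vocabulary_size"] == 262144  # text-only fallback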

keras_hub/src/utils/transformers/preset_loader.py

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,7 @@
 from keras_hub.src.utils.transformers import convert_distilbert
 from keras_hub.src.utils.transformers import convert_esm
 from keras_hub.src.utils.transformers import convert_gemma
+from keras_hub.src.utils.transformers import convert_gemma3
 from keras_hub.src.utils.transformers import convert_gpt2
 from keras_hub.src.utils.transformers import convert_llama3
 from keras_hub.src.utils.transformers import convert_mistral
@@ -46,6 +47,8 @@ def __init__(self, preset, config):
             self.converter = convert_esm
         elif model_type in ("gemma", "gemma2"):
             self.converter = convert_gemma
+        elif model_type in ("gemma3", "gemma3_text"):
+            self.converter = convert_gemma3
         elif model_type == "gpt2":
             self.converter = convert_gpt2
         elif model_type == "llama":
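
With this registration in place, Hugging Face checkpoints whose config.json reports model_type "gemma3" or "gemma3_text" are dispatched to the new convert_gemma3 module when loaded through KerasHub's transformers preset loader.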
