import numpy as np

from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
    Gemma3VisionEncoder,
)
from keras_hub.src.utils.preset_utils import get_file

backbone_cls = Gemma3Backbone


def load_image_converter_config(transformers_config):
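    """Return image converter kwargs for vision-capable checkpoints.

    `scale` and `offset` rescale pixel values from [0, 255] to [-1, 1].
    Text-only checkpoints (no `vision_config`) return `None`.
    """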
    if "vision_config" in transformers_config:
        image_size = transformers_config["vision_config"].get("image_size", 224)
        return {
            "image_size": (image_size, image_size),
            "scale": 1 / 127.5,
            "offset": -1.0,
        }
    else:
        return None


def convert_backbone_config(transformers_config):
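    """Translate a `transformers` config dict into `Gemma3Backbone` kwargs.

    Text-only checkpoints (`model_type == "gemma3_text"`) keep the decoder
    settings at the top level and build no vision encoder. Multimodal
    checkpoints nest them under `text_config` and describe the vision tower
    in `vision_config`, which is used to construct a `Gemma3VisionEncoder`.
    The default vocabulary size also differs (262144 vs. 262208).
    """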
    if transformers_config["model_type"] == "gemma3_text":
        image_size = None
        vision_encoder = None
        transformer_config = transformers_config
    else:
        image_size = transformers_config["vision_config"].get("image_size", 224)
        vision_encoder_config = {
            "image_size": image_size,
            "patch_size": transformers_config["vision_config"].get(
                "patch_size", 16
            ),
            "num_heads": transformers_config["vision_config"].get(
                "num_attention_heads", 12
            ),
            "hidden_dim": transformers_config["vision_config"].get(
                "hidden_size", 768
            ),
            "num_layers": transformers_config["vision_config"].get(
                "num_hidden_layers", 12
            ),
            "intermediate_dim": transformers_config["vision_config"].get(
                "intermediate_size", 3072
            ),
            "output_dim": 2560,
            "pool_size": 4,
            "layer_norm_epsilon": transformers_config["vision_config"].get(
                "layer_norm_eps", 1e-6
            ),
        }
        vision_encoder = Gemma3VisionEncoder(**vision_encoder_config)
        transformer_config = transformers_config["text_config"]

    return {
        "vocabulary_size": transformer_config.get(
            "vocab_size", 262144 if vision_encoder is None else 262208
        ),
        "image_size": image_size,
        "num_layers": transformer_config.get("num_hidden_layers", 26),
        "num_query_heads": transformer_config.get("num_attention_heads", 8),
        "num_key_value_heads": transformer_config.get("num_key_value_heads", 4),
        "hidden_dim": transformer_config.get("hidden_size", 2304),
        "intermediate_dim": transformer_config.get("intermediate_size", 9216),
        "head_dim": transformer_config.get("head_dim", 256),
        "use_post_ffw_norm": True,
        "use_post_attention_norm": True,
        "attention_logit_softcap": transformer_config.get(
            "attn_logit_softcap", None
        ),
        "final_logit_softcap": transformer_config.get(
            "final_logit_softcap", None
        ),
        "use_sliding_window_attention": True,
        "query_head_dim_normalize": True,
        "sliding_window_size": transformer_config.get("sliding_window", 4096),
        "local_rope_scaling_factor": 1.0,
        "global_rope_scaling_factor": (
            transformer_config.get("rope_scaling") or {}
        ).get("factor", 1.0),
        "layer_norm_epsilon": transformer_config.get("rms_norm_eps", 1e-6),
        "use_bidirectional_attention": transformer_config.get(
            "use_bidirectional_attention", False
        ),
        "vision_encoder": vision_encoder,
    }


def convert_weights(backbone, loader, transformers_config):
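    """Port Hugging Face checkpoint tensors into the Keras `backbone`.

    Text-only checkpoints store the decoder under `model.*`, multimodal ones
    under `language_model.model.*`; the `hook_fn` callbacks reshape or
    transpose each tensor into the layout the Keras variable expects.
    """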
    if transformers_config["model_type"] == "gemma3_text":
        prefix = "model"
    else:
        prefix = "language_model.model"

    loader.port_weight(
        keras_variable=backbone.get_layer("token_embedding").embeddings,
        hf_weight_key=f"{prefix}.embed_tokens.weight",
    )

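    # The loader calls `hook_fn(hf_tensor, keras_shape)`; the unused `shape`
    # argument below only exists to satisfy that signature. PyTorch
    # `nn.Linear` stores weights as (out_features, in_features), so a plain
    # transpose yields the (in, out) kernel layout Keras dense layers use.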
    def transpose(x, shape):
        return np.transpose(x)

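    # Vision tower weights exist only in multimodal checkpoints, under the
    # "vision_tower.vision_model.*" and "multi_modal_projector.*" HF keys.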
    vision_encoder = backbone.vision_encoder
    if vision_encoder is not None:
        image_encoder = vision_encoder.get_layer("image_encoder")

        loader.port_weight(
            keras_variable=image_encoder.vision_embeddings.patch_embedding.kernel,
            hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.weight",
            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
        )
        loader.port_weight(
            keras_variable=image_encoder.vision_embeddings.patch_embedding.bias,
            hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.bias",
        )

        loader.port_weight(
            keras_variable=image_encoder.vision_embeddings.position_embedding.embeddings,
            hf_weight_key="vision_tower.vision_model.embeddings.position_embedding.weight",
        )

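        # Vision transformer blocks: two layer norms, q/k/v/out self-attention
        # projections, and a two-layer MLP per block.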
        for i in range(image_encoder.num_layers):
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].layer_norm_1.gamma,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].layer_norm_1.beta,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[
                    i
                ].attn.query_proj.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.query_proj.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.key_proj.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.key_proj.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[
                    i
                ].attn.value_proj.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.value_proj.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.out_proj.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].attn.out_proj.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias",
            )

            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].layer_norm_2.gamma,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].layer_norm_2.beta,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].mlp_dense_1.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].mlp_dense_1.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias",
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].mlp_dense_2.kernel,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight",
                hook_fn=transpose,
            )
            loader.port_weight(
                keras_variable=image_encoder.resblocks[i].mlp_dense_2.bias,
                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias",
            )

        loader.port_weight(
            keras_variable=image_encoder.encoder_layer_norm.gamma,
            hf_weight_key="vision_tower.vision_model.post_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=image_encoder.encoder_layer_norm.beta,
            hf_weight_key="vision_tower.vision_model.post_layernorm.bias",
        )

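        # Multimodal projector: an RMS norm over pooled vision features and a
        # projection into the language model's embedding space. On the HF
        # side the projection is stored as a plain (input_dim, output_dim)
        # parameter rather than an `nn.Linear` weight, so no transpose hook
        # is needed.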
        loader.port_weight(
            keras_variable=vision_encoder.get_layer(
                "vision_output_encoder"
            ).vision_soft_embedding_norm.scale,
            hf_weight_key="multi_modal_projector.mm_soft_emb_norm.weight",
        )

        loader.port_weight(
            keras_variable=vision_encoder.get_layer(
                "vision_output_encoder"
            ).vision_input_projection.kernel,
            hf_weight_key="multi_modal_projector.mm_input_projection_weight",
        )

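    # Language model decoder blocks: each has pre/post RMS norms around the
    # attention and feedforward sub-layers, ported below together with the
    # attention projections and gated MLP.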
    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"decoder_block_{i}")

        loader.port_weight(
            keras_variable=decoder_layer.pre_attention_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.input_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer.post_attention_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.post_attention_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer.pre_ffw_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.pre_feedforward_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer.post_ffw_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.post_feedforward_layernorm.weight",
        )

        # Attention layers

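        # HF fuses all heads into 2D projection matrices. The hooks below
        # reshape them into the per-head kernel layouts the Keras attention
        # layer expects: (num_heads, hidden_dim, head_dim) for q/k/v and
        # (num_heads, head_dim, hidden_dim) for the output projection.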
        ## Query
        loader.port_weight(
            keras_variable=decoder_layer.attention.query_dense.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.q_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer.attention.query_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.q_norm.weight",
        )
        ## Key
        loader.port_weight(
            keras_variable=decoder_layer.attention.key_dense.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.k_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer.attention.key_norm.scale,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.k_norm.weight",
        )
        ## Value
        loader.port_weight(
            keras_variable=decoder_layer.attention.value_dense.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.v_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        ## Output
        loader.port_weight(
            keras_variable=decoder_layer.attention.output_dense.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.self_attn.o_proj.weight",
            # rearrange_patterns="c (a b) -> a b c",
            # rearrange_dims={"a": backbone.num_query_heads},
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[2], keras_shape[0], keras_shape[1]),
                ),
                axes=(1, 2, 0),
            ),
        )

        # MLP layers
        loader.port_weight(
            keras_variable=decoder_layer.gating_ffw.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.mlp.gate_proj.weight",
            # rearrange_patterns="b a -> a b",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        loader.port_weight(
            keras_variable=decoder_layer.gating_ffw_2.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.mlp.up_proj.weight",
            # rearrange_patterns="b a -> a b",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        loader.port_weight(
            keras_variable=decoder_layer.ffw_linear.kernel,
            hf_weight_key=f"{prefix}.layers.{i}.mlp.down_proj.weight",
            # rearrange_patterns="b a -> a b",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

    # Final normalization layer
    loader.port_weight(
        keras_variable=backbone.get_layer("final_normalization").scale,
        hf_weight_key=f"{prefix}.norm.weight",
    )

    return backbone


def convert_tokenizer(cls, preset, **kwargs):
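    """Build a Gemma3 tokenizer from the preset's SentencePiece model file."""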
    return cls(get_file(preset, "tokenizer.model"), **kwargs)
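

# Note: these converters are not called directly; keras_hub's Hugging Face
# preset loading machinery invokes them, roughly like the sketch below
# (the checkpoint handle is illustrative; any Gemma3 safetensors repo works):
#
#     import keras_hub
#
#     backbone = keras_hub.models.Gemma3Backbone.from_preset(
#         "hf://google/gemma-3-4b-it"
#     )
#
# which reads the Hugging Face config, builds the model via
# convert_backbone_config(), and then streams the checkpoint tensors through
# convert_weights().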