- 
                Notifications
    You must be signed in to change notification settings 
- Fork 307
Gemma3 text keras hf checkpoint conversion #2433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
938b52b
              1f06acb
              24c9573
              71bb3af
              85f9498
              69a7137
              525da45
              06ed2ad
              ab1bde1
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,177 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import keras.ops as ops | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_gemma3_config(backbone): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Convert Keras Gemma3 config to Hugging Face config dictionary.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_embedding_layer = backbone.get_layer("token_embedding") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hf_config = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "architectures": ["Gemma3ForCausalLM"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "model_type": "gemma3_text", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "vocab_size": backbone.vocabulary_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "num_hidden_layers": backbone.num_layers, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "num_attention_heads": backbone.num_query_heads, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "num_key_value_heads": backbone.num_key_value_heads, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "hidden_size": backbone.hidden_dim, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "intermediate_size": backbone.intermediate_dim, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "head_dim": backbone.head_dim, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "max_position_embeddings": 32768, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "tie_word_embeddings": token_embedding_layer.tie_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "rms_norm_eps": 1e-6, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "rope_theta": 10000.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "attention_bias": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "attention_dropout": 0.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "hidden_activation": "gelu_pytorch_tanh", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return hf_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_gemma3_weights_map(backbone, include_lm_head=False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Convert a Keras Gemma3 model to Hugging Face format. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| include_lm_head: If True, exports for CausalLM (with "model." prefix). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| If False, exports for backbone only (without prefix). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # For CausalLM export, use "model." prefix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # For backbone export, use no prefix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prefix = "model." if include_lm_head else "" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Token embeddings - use .weights[0] to get backend tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_embedding_layer = backbone.get_layer("token_embedding") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_embedding = token_embedding_layer.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}embed_tokens.weight"] = token_embedding | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in range(backbone.num_layers): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block = backbone.get_layer(f"decoder_block_{i}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Attention query projection | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| q_kernel = block.attention.query_dense.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| q_kernel = ops.transpose(q_kernel, axes=(1, 0, 2)) # permute(1, 0, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| q_kernel = ops.reshape(q_kernel, (backbone.hidden_dim, -1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| q_kernel = ops.transpose(q_kernel) # .T | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Attention key projection | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k_kernel = block.attention.key_dense.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k_kernel = ops.transpose(k_kernel, axes=(1, 0, 2)) # permute(1, 0, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k_kernel = ops.reshape(k_kernel, (backbone.hidden_dim, -1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k_kernel = ops.transpose(k_kernel) # .T | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.self_attn.k_proj.weight"] = k_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Attention value projection | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| v_kernel = block.attention.value_dense.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| v_kernel = ops.transpose(v_kernel, axes=(1, 0, 2)) # permute(1, 0, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| v_kernel = ops.reshape(v_kernel, (backbone.hidden_dim, -1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| v_kernel = ops.transpose(v_kernel) # .T | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.self_attn.v_proj.weight"] = v_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Attention output projection | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o_kernel = block.attention.output_dense.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o_kernel = ops.transpose(o_kernel, axes=(2, 0, 1)) # permute(2, 0, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o_kernel = ops.reshape(o_kernel, (backbone.hidden_dim, -1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.self_attn.o_proj.weight"] = o_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Query and key normalization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| q_norm = block.attention.query_norm.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.self_attn.q_norm.weight"] = q_norm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k_norm = block.attention.key_norm.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.self_attn.k_norm.weight"] = k_norm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # MLP gate projection | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gate_kernel = block.gating_ffw.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gate_kernel = ops.transpose(gate_kernel) # .T | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.mlp.gate_proj.weight"] = gate_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # MLP up projection | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| up_kernel = block.gating_ffw_2.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| up_kernel = ops.transpose(up_kernel) # .T | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.mlp.up_proj.weight"] = up_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # MLP down projection | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| down_kernel = block.ffw_linear.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| down_kernel = ops.transpose(down_kernel) # .T | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.mlp.down_proj.weight"] = down_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Pre-attention normalization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_layer_norm = block.pre_attention_norm.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.input_layernorm.weight"] = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_layer_norm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Post-attention normalization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(block, "post_attention_norm"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| post_attn_norm = block.post_attention_norm.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Fallback to pre_ffw_norm if post_attention_norm doesn't exist | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| post_attn_norm = block.pre_ffw_norm.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.post_attention_layernorm.weight"] = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| post_attn_norm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Pre-feedforward normalization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pre_feedforward_layernorm = block.pre_ffw_norm.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pre_feedforward_layernorm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Post-feedforward normalization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(block, "post_ffw_norm"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| post_feedforward_layernorm = block.post_ffw_norm.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Fallback to pre_ffw_norm if post_ffw_norm doesn't exist | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| post_feedforward_layernorm = block.pre_ffw_norm.weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"{prefix}layers.{i}.post_feedforward_layernorm.weight" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] = post_feedforward_layernorm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +105
     to 
      +128
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fallback logic for  The weights should only be exported if the corresponding layers exist. Please remove the  
        Suggested change
       
 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Final normalization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| final_norm = backbone.get_layer("final_normalization").weights[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict[f"{prefix}norm.weight"] = final_norm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if include_lm_head and not token_embedding_layer.tie_weights: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights_dict["lm_head.weight"] = ops.transpose( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_embedding_layer.reverse_embeddings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return weights_dict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_gemma3_tokenizer_config(tokenizer): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer_config = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "tokenizer_class": "GemmaTokenizer", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "clean_up_tokenization_spaces": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "bos_token": "<bos>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "eos_token": "<eos>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "pad_token": "<pad>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "unk_token": "<unk>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "add_bos_token": True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "add_eos_token": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "model_max_length": 32768, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Add added_tokens_decoder | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| added_tokens_decoder = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| special_tokens = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "<pad>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "<bos>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "<eos>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "<unk>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "<start_of_image>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "<end_of_image>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "<img>", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for token in special_tokens: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_id = tokenizer.token_to_id(token) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if token_id is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| added_tokens_decoder[str(token_id)] = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "content": token, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "special": True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "single_word": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "lstrip": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "rstrip": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "normalized": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer_config["added_tokens_decoder"] = added_tokens_decoder | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return tokenizer_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| import os | ||
|  | ||
| import numpy as np | ||
| from transformers import AutoModel | ||
| from transformers import AutoModelForCausalLM | ||
| from transformers import AutoTokenizer | ||
|  | ||
| from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone | ||
| from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM | ||
| from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( | ||
| Gemma3CausalLMPreprocessor, | ||
| ) | ||
| from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer | ||
| from keras_hub.src.tests.test_case import TestCase | ||
|  | ||
|  | ||
| class TestGemma3Export(TestCase): | ||
| def test_export_to_hf(self): | ||
| proto = os.path.join(self.get_test_data_dir(), "gemma3_test_vocab.spm") | ||
| tokenizer = Gemma3Tokenizer(proto=proto) | ||
|  | ||
| # Create a small backbone (text-only, no vision encoder) | ||
| backbone = Gemma3Backbone( | ||
| vocabulary_size=tokenizer.vocabulary_size(), | ||
| image_size=896, # Default value even for text-only | ||
| num_layers=2, | ||
| num_query_heads=4, | ||
| num_key_value_heads=1, | ||
| hidden_dim=512, | ||
| intermediate_dim=1028, | ||
| head_dim=128, | ||
| query_head_dim_normalize=True, | ||
| use_query_key_norm=True, | ||
| use_post_ffw_norm=True, # Real Gemma3 models have these | ||
| use_post_attention_norm=True, # Real Gemma3 models have these | ||
| attention_logit_soft_cap=None, | ||
| final_logit_soft_cap=None, | ||
| use_sliding_window_attention=False, | ||
| sliding_window_size=4096, | ||
| vision_encoder=None, # Text-only model for testing | ||
| layer_norm_epsilon=1e-6, | ||
| dropout=0, | ||
| ) | ||
|  | ||
| # Create preprocessor | ||
| preprocessor = Gemma3CausalLMPreprocessor(tokenizer=tokenizer) | ||
|  | ||
| # Create the causal LM model | ||
| keras_model = Gemma3CausalLM( | ||
| backbone=backbone, preprocessor=preprocessor | ||
| ) | ||
|  | ||
| # Set all weights to random values | ||
| rng = np.random.default_rng(42) | ||
| weights = keras_model.get_weights() | ||
| for i in range(len(weights)): | ||
| weights[i] = rng.random(weights[i].shape).astype(weights[i].dtype) | ||
| keras_model.set_weights(weights) | ||
|  | ||
| # Export to Hugging Face format using the new methods | ||
| export_path_backbone = os.path.join( | ||
| self.get_temp_dir(), "export_backbone" | ||
| ) | ||
| backbone.export_to_transformers(export_path_backbone) | ||
|  | ||
| export_path_tokenizer = os.path.join( | ||
| self.get_temp_dir(), "export_tokenizer" | ||
| ) | ||
| preprocessor.tokenizer.export_to_transformers(export_path_tokenizer) | ||
|  | ||
| export_path_task = os.path.join(self.get_temp_dir(), "export_task") | ||
| keras_model.export_to_transformers(export_path_task) | ||
|  | ||
| # Load Hugging Face models and tokenizer | ||
| # Note: We only test the slow tokenizer because the test vocab file | ||
| # may not be compatible with fast tokenizer conversion | ||
| hf_backbone = AutoModel.from_pretrained(export_path_backbone) | ||
| hf_tokenizer_slow = AutoTokenizer.from_pretrained( | ||
| export_path_tokenizer, use_fast=False | ||
| ) | ||
| hf_full_model = AutoModelForCausalLM.from_pretrained(export_path_task) | ||
|  | ||
| # Verify configuration | ||
| hf_config = hf_backbone.config | ||
| self.assertEqual( | ||
| hf_config.vocab_size, | ||
| backbone.vocabulary_size, | ||
| "Vocabulary sizes do not match", | ||
| ) | ||
| self.assertEqual( | ||
| hf_config.num_hidden_layers, | ||
| backbone.num_layers, | ||
| "Number of layers do not match", | ||
| ) | ||
| self.assertEqual( | ||
| hf_config.num_attention_heads, | ||
| backbone.num_query_heads, | ||
| "Number of query heads do not match", | ||
| ) | ||
| self.assertEqual( | ||
| hf_config.num_key_value_heads, | ||
| backbone.num_key_value_heads, | ||
| "Number of key value heads do not match", | ||
| ) | ||
| self.assertEqual( | ||
| hf_config.hidden_size, | ||
| backbone.hidden_dim, | ||
| "Hidden dimensions do not match", | ||
| ) | ||
| self.assertEqual( | ||
| hf_config.intermediate_size, | ||
| backbone.intermediate_dim, | ||
| "Intermediate sizes do not match", | ||
| ) | ||
| self.assertEqual( | ||
| hf_config.head_dim, | ||
| backbone.head_dim, | ||
| "Head dimensions do not match", | ||
| ) | ||
| self.assertEqual( | ||
| hf_config.max_position_embeddings, | ||
| 32768, | ||
| "Max position embeddings do not match", | ||
| ) | ||
| self.assertEqual( | ||
| hf_config.tie_word_embeddings, | ||
| backbone.token_embedding.tie_weights, | ||
| "Tie word embeddings do not match", | ||
| ) | ||
|  | ||
| # Verify tokenizer compatibility (using slow tokenizer) | ||
| self.assertEqual( | ||
| hf_tokenizer_slow.vocab_size, | ||
| tokenizer.vocabulary_size(), | ||
| "Tokenizer vocabulary sizes do not match", | ||
| ) | ||
|  | ||
| # Compare generated outputs using full model | ||
| prompt = "the quick" | ||
|  | ||
| # Generate with Keras model | ||
| keras_output = keras_model.generate(prompt, max_length=20) | ||
|  | ||
| # Generate with HuggingFace model using slow tokenizer | ||
| input_ids_slow = hf_tokenizer_slow.encode(prompt, return_tensors="pt") | ||
| output_ids_slow = hf_full_model.generate( | ||
| input_ids_slow, max_length=20, do_sample=False | ||
| ) | ||
| hf_slow_output = hf_tokenizer_slow.decode( | ||
| output_ids_slow[0], skip_special_tokens=True | ||
| ) | ||
|  | ||
| # Debug print to see the actual outputs | ||
| print(f"Keras output: '{keras_output}'") | ||
| print(f"HF slow output: '{hf_slow_output}'") | ||
|  | ||
| self.assertEqual( | ||
| keras_output, | ||
| hf_slow_output, | ||
| "Generated outputs do not match", | ||
| ) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic for converting the query, key, and value projection kernels is identical across these blocks. This repetition can be refactored into a private helper function to improve code clarity and maintainability, adhering to the DRY (Don't Repeat Yourself) principle.
For example, you could define a helper like
_convert_qkv_kernel(kernel, hidden_dim)and call it for each of theq_proj,k_proj, andv_projweights.