177 changes: 177 additions & 0 deletions keras_hub/src/utils/transformers/export/gemma3.py
@@ -0,0 +1,177 @@
import keras.ops as ops


def get_gemma3_config(backbone):
"""Convert Keras Gemma3 config to Hugging Face config dictionary."""
token_embedding_layer = backbone.get_layer("token_embedding")
hf_config = {
"architectures": ["Gemma3ForCausalLM"],
"model_type": "gemma3_text",
"vocab_size": backbone.vocabulary_size,
"num_hidden_layers": backbone.num_layers,
"num_attention_heads": backbone.num_query_heads,
"num_key_value_heads": backbone.num_key_value_heads,
"hidden_size": backbone.hidden_dim,
"intermediate_size": backbone.intermediate_dim,
"head_dim": backbone.head_dim,
"max_position_embeddings": 32768,
"tie_word_embeddings": token_embedding_layer.tie_weights,
"rms_norm_eps": 1e-6,
"rope_theta": 10000.0,
"attention_bias": False,
"attention_dropout": 0.0,
"hidden_activation": "gelu_pytorch_tanh",
}
return hf_config
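
A minimal sketch of how this dictionary is typically consumed during export, assuming it is serialized as the config.json that Hugging Face loaders read (export_dir and backbone are illustrative placeholders, not names from this PR):

import json
import os

# Hedged sketch: write the converted config next to the exported weights so
# that transformers can rebuild the architecture from config.json.
hf_config = get_gemma3_config(backbone)
with open(os.path.join(export_dir, "config.json"), "w") as f:
    json.dump(hf_config, f, indent=2)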


def get_gemma3_weights_map(backbone, include_lm_head=False):
"""Convert a Keras Gemma3 model to Hugging Face format.

include_lm_head: If True, exports for CausalLM (with "model." prefix).
If False, exports for backbone only (without prefix).
"""

weights_dict = {}

# For CausalLM export, use "model." prefix
# For backbone export, use no prefix
prefix = "model." if include_lm_head else ""

# Token embeddings - use .weights[0] to get backend tensor
token_embedding_layer = backbone.get_layer("token_embedding")
token_embedding = token_embedding_layer.weights[0]
weights_dict[f"{prefix}embed_tokens.weight"] = token_embedding

for i in range(backbone.num_layers):
block = backbone.get_layer(f"decoder_block_{i}")

# Attention query projection
q_kernel = block.attention.query_dense.weights[0]
q_kernel = ops.transpose(q_kernel, axes=(1, 0, 2)) # permute(1, 0, 2)
q_kernel = ops.reshape(q_kernel, (backbone.hidden_dim, -1))
q_kernel = ops.transpose(q_kernel) # .T
weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel

# Attention key projection
k_kernel = block.attention.key_dense.weights[0]
k_kernel = ops.transpose(k_kernel, axes=(1, 0, 2)) # permute(1, 0, 2)
k_kernel = ops.reshape(k_kernel, (backbone.hidden_dim, -1))
k_kernel = ops.transpose(k_kernel) # .T
weights_dict[f"{prefix}layers.{i}.self_attn.k_proj.weight"] = k_kernel

# Attention value projection
v_kernel = block.attention.value_dense.weights[0]
v_kernel = ops.transpose(v_kernel, axes=(1, 0, 2)) # permute(1, 0, 2)
v_kernel = ops.reshape(v_kernel, (backbone.hidden_dim, -1))
v_kernel = ops.transpose(v_kernel) # .T
weights_dict[f"{prefix}layers.{i}.self_attn.v_proj.weight"] = v_kernel
Review comment on lines +50 to +68 (Contributor, severity: medium):

The logic for converting the query, key, and value projection kernels is identical across these blocks. This repetition can be refactored into a private helper function to improve code clarity and maintainability, adhering to the DRY (Don't Repeat Yourself) principle.

For example, you could define a helper like _convert_qkv_kernel(kernel, hidden_dim) and call it for each of the q_proj, k_proj, and v_proj weights.
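
A minimal sketch of that helper, assuming the q/k/v EinsumDense kernels have shape (num_heads, hidden_dim, head_dim) as implied by the permute and reshape calls above:

def _convert_qkv_kernel(kernel, hidden_dim):
    # (num_heads, hidden_dim, head_dim) -> (hidden_dim, num_heads, head_dim)
    kernel = ops.transpose(kernel, axes=(1, 0, 2))
    # Collapse the head axes: (hidden_dim, num_heads * head_dim)
    kernel = ops.reshape(kernel, (hidden_dim, -1))
    # HF Linear weights are stored as (out_features, in_features)
    return ops.transpose(kernel)

Each projection block above would then reduce to a single call, for example weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = _convert_qkv_kernel(block.attention.query_dense.weights[0], backbone.hidden_dim).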


# Attention output projection
o_kernel = block.attention.output_dense.weights[0]
o_kernel = ops.transpose(o_kernel, axes=(2, 0, 1)) # permute(2, 0, 1)
o_kernel = ops.reshape(o_kernel, (backbone.hidden_dim, -1))
weights_dict[f"{prefix}layers.{i}.self_attn.o_proj.weight"] = o_kernel

# Query and key normalization
q_norm = block.attention.query_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.self_attn.q_norm.weight"] = q_norm

k_norm = block.attention.key_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.self_attn.k_norm.weight"] = k_norm

# MLP gate projection
gate_kernel = block.gating_ffw.weights[0]
gate_kernel = ops.transpose(gate_kernel) # .T
weights_dict[f"{prefix}layers.{i}.mlp.gate_proj.weight"] = gate_kernel

# MLP up projection
up_kernel = block.gating_ffw_2.weights[0]
up_kernel = ops.transpose(up_kernel) # .T
weights_dict[f"{prefix}layers.{i}.mlp.up_proj.weight"] = up_kernel

# MLP down projection
down_kernel = block.ffw_linear.weights[0]
down_kernel = ops.transpose(down_kernel) # .T
weights_dict[f"{prefix}layers.{i}.mlp.down_proj.weight"] = down_kernel

# Pre-attention normalization
input_layer_norm = block.pre_attention_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.input_layernorm.weight"] = (
input_layer_norm
)

# Post-attention normalization
if hasattr(block, "post_attention_norm"):
post_attn_norm = block.post_attention_norm.weights[0]
else:
# Fallback to pre_ffw_norm if post_attention_norm doesn't exist
post_attn_norm = block.pre_ffw_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.post_attention_layernorm.weight"] = (
post_attn_norm
)

# Pre-feedforward normalization
pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
pre_feedforward_layernorm
)

# Post-feedforward normalization
if hasattr(block, "post_ffw_norm"):
post_feedforward_layernorm = block.post_ffw_norm.weights[0]
else:
# Fallback to pre_ffw_norm if post_ffw_norm doesn't exist
post_feedforward_layernorm = block.pre_ffw_norm.weights[0]
weights_dict[
f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
] = post_feedforward_layernorm
Review comment on lines +105 to +128 (Contributor, severity: high):

The fallback logic for post_attention_norm and post_ffw_norm appears to be incorrect. If these layers do not exist on the block (likely because the model was configured with use_post_attention_norm=False or use_post_ffw_norm=False), the Hugging Face model would not expect weights for the corresponding layernorms. Assigning weights from pre_ffw_norm in these cases could lead to a functionally incorrect model.

The weights should only be exported if the corresponding layers exist. Please remove the else blocks for both post_attention_layernorm and post_feedforward_layernorm.

Suggested change (replacing the two fallback blocks above so that each norm is exported only when the layer exists):

        if hasattr(block, "post_attention_norm"):
            post_attn_norm = block.post_attention_norm.weights[0]
            weights_dict[
                f"{prefix}layers.{i}.post_attention_layernorm.weight"
            ] = post_attn_norm

        # Pre-feedforward normalization
        pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
        weights_dict[
            f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"
        ] = pre_feedforward_layernorm

        # Post-feedforward normalization
        if hasattr(block, "post_ffw_norm"):
            post_feedforward_layernorm = block.post_ffw_norm.weights[0]
            weights_dict[
                f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
            ] = post_feedforward_layernorm


# Final normalization
final_norm = backbone.get_layer("final_normalization").weights[0]
weights_dict[f"{prefix}norm.weight"] = final_norm

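    # HF expects an untied output head as lm_head.weight of shape
    # (vocab_size, hidden_dim); reverse_embeddings is stored transposed
    # relative to that, hence the ops.transpose below.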
if include_lm_head and not token_embedding_layer.tie_weights:
weights_dict["lm_head.weight"] = ops.transpose(
token_embedding_layer.reverse_embeddings
)

return weights_dict


def get_gemma3_tokenizer_config(tokenizer):
tokenizer_config = {
"tokenizer_class": "GemmaTokenizer",
"clean_up_tokenization_spaces": False,
"bos_token": "<bos>",
"eos_token": "<eos>",
"pad_token": "<pad>",
"unk_token": "<unk>",
"add_bos_token": True,
"add_eos_token": False,
"model_max_length": 32768,
}
# Add added_tokens_decoder
added_tokens_decoder = {}
special_tokens = [
"<pad>",
"<bos>",
"<eos>",
"<unk>",
"<start_of_image>",
"<end_of_image>",
"<img>",
]
for token in special_tokens:
token_id = tokenizer.token_to_id(token)
if token_id is not None:
added_tokens_decoder[str(token_id)] = {
"content": token,
"special": True,
"single_word": False,
"lstrip": False,
"rstrip": False,
"normalized": False,
}
tokenizer_config["added_tokens_decoder"] = added_tokens_decoder
return tokenizer_config
161 changes: 161 additions & 0 deletions keras_hub/src/utils/transformers/export/gemma3_test.py
@@ -0,0 +1,161 @@
import os

import numpy as np
from transformers import AutoModel
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM
from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
Gemma3CausalLMPreprocessor,
)
from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
from keras_hub.src.tests.test_case import TestCase


class TestGemma3Export(TestCase):
def test_export_to_hf(self):
proto = os.path.join(self.get_test_data_dir(), "gemma3_test_vocab.spm")
tokenizer = Gemma3Tokenizer(proto=proto)

# Create a small backbone (text-only, no vision encoder)
backbone = Gemma3Backbone(
vocabulary_size=tokenizer.vocabulary_size(),
image_size=896, # Default value even for text-only
num_layers=2,
num_query_heads=4,
num_key_value_heads=1,
hidden_dim=512,
intermediate_dim=1028,
head_dim=128,
query_head_dim_normalize=True,
use_query_key_norm=True,
use_post_ffw_norm=True, # Real Gemma3 models have these
use_post_attention_norm=True, # Real Gemma3 models have these
attention_logit_soft_cap=None,
final_logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=4096,
vision_encoder=None, # Text-only model for testing
layer_norm_epsilon=1e-6,
dropout=0,
)

# Create preprocessor
preprocessor = Gemma3CausalLMPreprocessor(tokenizer=tokenizer)

# Create the causal LM model
keras_model = Gemma3CausalLM(
backbone=backbone, preprocessor=preprocessor
)

# Set all weights to random values
rng = np.random.default_rng(42)
weights = keras_model.get_weights()
for i in range(len(weights)):
weights[i] = rng.random(weights[i].shape).astype(weights[i].dtype)
keras_model.set_weights(weights)

# Export to Hugging Face format using the new methods
export_path_backbone = os.path.join(
self.get_temp_dir(), "export_backbone"
)
backbone.export_to_transformers(export_path_backbone)

export_path_tokenizer = os.path.join(
self.get_temp_dir(), "export_tokenizer"
)
preprocessor.tokenizer.export_to_transformers(export_path_tokenizer)

export_path_task = os.path.join(self.get_temp_dir(), "export_task")
keras_model.export_to_transformers(export_path_task)

# Load Hugging Face models and tokenizer
# Note: We only test the slow tokenizer because the test vocab file
# may not be compatible with fast tokenizer conversion
hf_backbone = AutoModel.from_pretrained(export_path_backbone)
hf_tokenizer_slow = AutoTokenizer.from_pretrained(
export_path_tokenizer, use_fast=False
)
hf_full_model = AutoModelForCausalLM.from_pretrained(export_path_task)

# Verify configuration
hf_config = hf_backbone.config
self.assertEqual(
hf_config.vocab_size,
backbone.vocabulary_size,
"Vocabulary sizes do not match",
)
self.assertEqual(
hf_config.num_hidden_layers,
backbone.num_layers,
"Number of layers do not match",
)
self.assertEqual(
hf_config.num_attention_heads,
backbone.num_query_heads,
"Number of query heads do not match",
)
self.assertEqual(
hf_config.num_key_value_heads,
backbone.num_key_value_heads,
"Number of key value heads do not match",
)
self.assertEqual(
hf_config.hidden_size,
backbone.hidden_dim,
"Hidden dimensions do not match",
)
self.assertEqual(
hf_config.intermediate_size,
backbone.intermediate_dim,
"Intermediate sizes do not match",
)
self.assertEqual(
hf_config.head_dim,
backbone.head_dim,
"Head dimensions do not match",
)
self.assertEqual(
hf_config.max_position_embeddings,
32768,
"Max position embeddings do not match",
)
self.assertEqual(
hf_config.tie_word_embeddings,
backbone.token_embedding.tie_weights,
"Tie word embeddings do not match",
)

# Verify tokenizer compatibility (using slow tokenizer)
self.assertEqual(
hf_tokenizer_slow.vocab_size,
tokenizer.vocabulary_size(),
"Tokenizer vocabulary sizes do not match",
)

# Compare generated outputs using full model
prompt = "the quick"

# Generate with Keras model
keras_output = keras_model.generate(prompt, max_length=20)

# Generate with HuggingFace model using slow tokenizer
input_ids_slow = hf_tokenizer_slow.encode(prompt, return_tensors="pt")
output_ids_slow = hf_full_model.generate(
input_ids_slow, max_length=20, do_sample=False
)
hf_slow_output = hf_tokenizer_slow.decode(
output_ids_slow[0], skip_special_tokens=True
)

# Debug print to see the actual outputs
print(f"Keras output: '{keras_output}'")
print(f"HF slow output: '{hf_slow_output}'")

self.assertEqual(
keras_output,
hf_slow_output,
"Generated outputs do not match",
)
10 changes: 10 additions & 0 deletions keras_hub/src/utils/transformers/export/hf_exporter.py
@@ -10,19 +10,29 @@
get_gemma_tokenizer_config,
)
from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map
from keras_hub.src.utils.transformers.export.gemma3 import get_gemma3_config
from keras_hub.src.utils.transformers.export.gemma3 import (
get_gemma3_tokenizer_config,
)
from keras_hub.src.utils.transformers.export.gemma3 import (
get_gemma3_weights_map,
)

MODEL_CONFIGS = {
"GemmaBackbone": get_gemma_config,
"Gemma3Backbone": get_gemma3_config,
# Add for future models, e.g., "MistralBackbone": get_mistral_config
}

MODEL_EXPORTERS = {
"GemmaBackbone": get_gemma_weights_map,
"Gemma3Backbone": get_gemma3_weights_map,
# Add for future models, e.g., "MistralBackbone": get_mistral_weights_map
}

MODEL_TOKENIZER_CONFIGS = {
"GemmaTokenizer": get_gemma_tokenizer_config,
"Gemma3Tokenizer": get_gemma3_tokenizer_config,
# Add for future models, e.g., "MistralTokenizer":
# get_mistral_tokenizer_config
}