Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Gemma-2/3 specialized attention bridge.

Gemma-2/3 use Rotary Position Embeddings (RoPE) which requires special handling
to fire hook_rot_q and hook_rot_k with the correct post-rotation Q/K values.

This is achieved by wrapping HuggingFace's eager_attention_forward function
to intercept the query and key tensors after rotary embeddings have been applied.
"""Position embeddings attention bridge with full hook support.

Reimplements attention for models using RoPE (Llama, Gemma, Qwen, OLMo, etc.)
so that all hook points fire at the correct computation stage:
- hook_q/hook_k/hook_v: after projection
- hook_rot_q/hook_rot_k: after RoPE rotation
- hook_attn_scores: PRE-softmax (matching HookedTransformer convention)
- hook_pattern: POST-softmax
"""
from __future__ import annotations

Expand Down Expand Up @@ -79,12 +80,21 @@ def hooked_eager_attention_forward(
key = bridge.hook_rot_k(key)

# Call the original function
assert _ORIGINAL_EAGER_ATTENTION_FORWARD is not None
return _ORIGINAL_EAGER_ATTENTION_FORWARD(
module, query, key, value, attention_mask, **kwargs
)

# Replace the module-level function
# Replace the module-level function for both Gemma 2 and Gemma 3
gemma2_module.eager_attention_forward = hooked_eager_attention_forward # type: ignore[assignment]

try:
import transformers.models.gemma3.modeling_gemma3 as gemma3_module

gemma3_module.eager_attention_forward = hooked_eager_attention_forward # type: ignore[assignment]
except ImportError:
pass # Gemma 3 not available in this transformers version

_EAGER_ATTENTION_WRAPPED = True


Expand Down Expand Up @@ -165,6 +175,171 @@ def _apply_position_embedding_hooks(self, position_embeddings):
return (hooked_cos, hooked_sin)
return position_embeddings

def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Reimplemented attention forward pass with hooks at correct computation stages.

    Instead of delegating to the HF attention module (which only exposes
    post-softmax weights), this recomputes attention step-by-step so that:
    - hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)
    - hook_pattern fires on POST-softmax weights
    - hook_rot_q/hook_rot_k fire after RoPE application

    Handles RoPE (including partial rotary dims), GQA expansion, optional
    Q/K norms, and Gemma-2 style logit softcapping. Sliding-window masking
    is assumed to already be encoded in the `attention_mask` the HF model
    passes in — this method does not build windows itself.

    Args:
        *args: Positional arguments; the first tensor is treated as
            hidden_states when not provided via kwargs.
        **kwargs: May contain ``hidden_states``, ``position_embeddings``
            (a (cos, sin) tuple), and ``attention_mask``.

    Returns:
        Tuple of (attn_output, attn_weights), matching the HF attention
        module's return convention.

    Raises:
        RuntimeError: If set_original_component() has not been called.
        ValueError: If hidden_states cannot be found in args or kwargs.
    """
    if self.original_component is None:
        raise RuntimeError(
            f"Original component not set for {self.name}. "
            "Call set_original_component() first."
        )

    # Type as Any — the HF attention module's interface (q_proj, k_proj, etc.)
    # varies by architecture and isn't captured by nn.Module's type signature.
    hf_attn: Any = self.original_component

    # Extract hidden_states and kwargs
    if "hidden_states" in kwargs:
        hidden_states = kwargs.pop("hidden_states")
    elif len(args) > 0 and isinstance(args[0], torch.Tensor):
        hidden_states = args[0]
        args = args[1:]
    else:
        raise ValueError("Could not find hidden_states in args or kwargs")

    position_embeddings = kwargs.pop("position_embeddings", None)
    attention_mask = kwargs.pop("attention_mask", None)

    # Apply input hook
    hidden_states = self.hook_in(hidden_states)

    # Match dtype of HF module (e.g. when the bridge runs in a different
    # precision than the wrapped weights).
    target_dtype = None
    try:
        target_dtype = next(hf_attn.parameters()).dtype
    except StopIteration:
        pass
    if target_dtype is not None and hidden_states.is_floating_point():
        hidden_states = hidden_states.to(dtype=target_dtype)

    # --- Q/K/V Projection + Optional Q/K Norms ---
    # Some models (OLMo 2) apply Q/K norms BEFORE multi-head reshape on [batch, seq, hidden].
    # Others (Gemma 3) apply AFTER reshape on [batch, heads, seq, head_dim].
    # We match the HF model's order by probing: try applying the norms to the
    # flat projected tensors. If that fails (shape mismatch), the model uses
    # post-reshape norms.
    input_shape = hidden_states.shape[:-1]
    head_dim = hf_attn.head_dim
    hidden_shape = (*input_shape, -1, head_dim)

    query_states = hf_attn.q_proj(hidden_states)
    key_states = hf_attn.k_proj(hidden_states)
    value_states = hf_attn.v_proj(hidden_states)

    has_q_norm = hasattr(hf_attn, "q_norm") and hf_attn.q_norm is not None
    has_k_norm = hasattr(hf_attn, "k_norm") and hf_attn.k_norm is not None
    applied_pre_reshape_norm = False

    if has_q_norm:
        # Pre-reshape norm probe (OLMo 2 style: norm on flat [batch, seq, hidden]).
        # Compute into temporaries and commit only after BOTH norms succeed.
        # Otherwise a q_norm success followed by a k_norm shape mismatch
        # (possible under GQA, where Q and K projection widths differ) would
        # leave query_states normed here AND normed again post-reshape.
        try:
            normed_query = hf_attn.q_norm(query_states)
            normed_key = hf_attn.k_norm(key_states) if has_k_norm else key_states
        except RuntimeError:
            # Shape mismatch — this model uses post-reshape norms
            pass
        else:
            query_states = normed_query
            key_states = normed_key
            applied_pre_reshape_norm = True

    query_states = query_states.view(hidden_shape).transpose(1, 2)
    key_states = key_states.view(hidden_shape).transpose(1, 2)
    value_states = value_states.view(hidden_shape).transpose(1, 2)

    if has_q_norm and not applied_pre_reshape_norm:
        # Post-reshape norm (Gemma 3 style: norm on [batch, heads, seq, head_dim])
        query_states = hf_attn.q_norm(query_states)
    if has_k_norm and not applied_pre_reshape_norm:
        key_states = hf_attn.k_norm(key_states)

    # --- RoPE ---
    if position_embeddings is not None:
        position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
        cos, sin = position_embeddings
        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

        # Some models use partial rotary (e.g., GPT-OSS) where cos/sin cover only
        # a portion of head_dim. Split Q/K, rotate the partial dims, recombine.
        rotary_dim = cos.shape[-1]
        if rotary_dim < head_dim:
            q_rot, q_pass = query_states[..., :rotary_dim], query_states[..., rotary_dim:]
            k_rot, k_pass = key_states[..., :rotary_dim], key_states[..., rotary_dim:]
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
            query_states = torch.cat([q_rot, q_pass], dim=-1)
            key_states = torch.cat([k_rot, k_pass], dim=-1)
        else:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Fire hook_rot_q/hook_rot_k (post-rotation)
        if hasattr(self, "hook_rot_q"):
            query_states = self.hook_rot_q(query_states)
        if hasattr(self, "hook_rot_k"):
            key_states = self.hook_rot_k(key_states)

    # --- GQA: Expand K/V so each query head has a matching K/V head ---
    num_key_value_groups = getattr(hf_attn, "num_key_value_groups", 1)
    if num_key_value_groups > 1:
        from transformers.models.llama.modeling_llama import repeat_kv

        key_states_expanded = repeat_kv(key_states, num_key_value_groups)
        value_states_expanded = repeat_kv(value_states, num_key_value_groups)
    else:
        key_states_expanded = key_states
        value_states_expanded = value_states

    # --- Attention Scores ---
    scaling = getattr(hf_attn, "scaling", head_dim**-0.5)
    attn_scores = torch.matmul(query_states, key_states_expanded.transpose(-2, -1)) * scaling

    # --- Softcapping (Gemma 2): tanh-squash logits into (-softcap, softcap) ---
    softcap = getattr(hf_attn, "attn_logit_softcapping", None)
    if softcap is not None:
        attn_scores = attn_scores / softcap
        attn_scores = torch.tanh(attn_scores)
        attn_scores = attn_scores * softcap

    # --- Causal / Sliding Window Mask (additive, pre-built by the HF model) ---
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states_expanded.shape[-2]]
        attn_scores = attn_scores + causal_mask

    # --- hook_attn_scores: PRE-softmax (matching HookedTransformer) ---
    attn_scores = self.hook_attn_scores(attn_scores)

    # --- Softmax (in float32 for numerical stability) ---
    attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(
        query_states.dtype
    )

    # --- Dropout (training only) ---
    dropout_rate = getattr(hf_attn, "attention_dropout", 0.0)
    if self.training and dropout_rate > 0.0:
        attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_rate, training=True)

    # --- hook_pattern: POST-softmax ---
    attn_weights = self.hook_pattern(attn_weights)

    # --- Attention Output: weighted sum of values, then re-flatten heads ---
    attn_output = torch.matmul(attn_weights, value_states_expanded)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(*input_shape, -1)

    # --- Output Projection ---
    attn_output = hf_attn.o_proj(attn_output)

    # --- Output Hook ---
    attn_output = self.hook_out(attn_output)

    return attn_output, attn_weights

def get_random_inputs(
self,
batch_size: int = 2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,14 @@ def __init__(
"""Initialize the SigLIP vision encoder bridge.

Args:
name: The name of this component (e.g., "vision_tower")
name: The name of this component (e.g., "model.vision_tower")
config: Optional configuration object
submodules: Dictionary of submodules to register
"""
# All submodule names are resolved relative to the parent's
# original_component (a SiglipVisionModel) by setup_submodules().
# SiglipVisionModel wraps SiglipVisionTransformer as .vision_model,
# so all paths go through vision_model.*.
default_submodules = {
"embeddings": GeneralizedComponent(name="vision_model.embeddings"),
"encoder_layers": SiglipVisionEncoderLayerBridge(name="vision_model.encoder.layers"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def __init__(self, cfg: Any) -> None:
self.cfg.gated_mlp = True
self.cfg.uses_rms_norm = True
self.cfg.normalization_type = "RMS"
# Gemma models use (1.0 + weight) in RMSNorm instead of just weight.
# Without this, fold_ln sets identity to 1.0 instead of 0.0, causing 2x scaling.
self.cfg.rmsnorm_uses_offset = True
self.cfg.positional_embedding_type = "rotary"
self.cfg.attn_implementation = "eager"

Expand Down Expand Up @@ -184,34 +187,15 @@ def __init__(self, cfg: Any) -> None:
def setup_hook_compatibility(self, bridge: Any) -> None:
"""Setup hook compatibility for Gemma3 multimodal models.

Applies embedding scaling like text-only Gemma 3.
Like text-only Gemma 3, the multimodal model uses
Gemma3TextScaledWordEmbedding which scales embeddings by sqrt(d_model)
internally in its forward() method. No additional hook conversion is
needed — adding one would double-scale the embeddings.

Args:
bridge: The TransformerBridge instance
"""
from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
BaseTensorConversion,
)

class EmbeddingScaleConversion(BaseTensorConversion):
"""Scale embeddings by sqrt(d_model) for Gemma models."""

def __init__(self, scale: float):
super().__init__()
self.scale = scale

def handle_conversion(self, input_value: Any, *full_context: Any) -> Any:
"""Scale the embedding output."""
return input_value * self.scale

def revert(self, input_value: Any, *full_context: Any) -> Any:
"""Unscale the embedding output (for user modifications)."""
return input_value / self.scale

# Apply scaling to embed.hook_out
if hasattr(bridge, "embed") and hasattr(bridge.embed, "hook_out"):
scale_factor = self.cfg.d_model**0.5
bridge.embed.hook_out.hook_conversion = EmbeddingScaleConversion(scale_factor)
pass

def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
"""Set up rotary embedding references for Gemma-3 multimodal component testing.
Expand Down
17 changes: 16 additions & 1 deletion transformer_lens/utilities/bridge_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,22 @@ def collect_components_of_block_bridge(
# Retrieve the remote component list from the adapter (we need a ModuleList to iterate over)
if component.name is None:
raise ValueError("Block bridge component must have a name")
remote_module_list = model.adapter.get_remote_component(model.original_model, component.name)

# If the component already has its original_component set (from boot),
# use it directly. This handles nested list components (e.g., vision encoder
# layers) whose names are relative to their parent, not the model root.
if component.original_component is not None:
remote_module_list = component.original_component
else:
try:
remote_module_list = model.adapter.get_remote_component(
model.original_model, component.name
)
except AttributeError:
# Submodule name is relative to a parent component that isn't the model root
# (e.g., vision encoder layers inside a multimodal model). Skip gracefully
# since these components are already set up during boot.
return components

# Make sure the remote component is a ModuleList
if isinstance(remote_module_list, nn.ModuleList):
Expand Down
Loading