diff --git a/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py b/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py index 42fa1c94f..ed0fc9131 100644 --- a/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py +++ b/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py @@ -1,10 +1,11 @@ -"""Gemma-2/3 specialized attention bridge. - -Gemma-2/3 use Rotary Position Embeddings (RoPE) which requires special handling -to fire hook_rot_q and hook_rot_k with the correct post-rotation Q/K values. - -This is achieved by wrapping HuggingFace's eager_attention_forward function -to intercept the query and key tensors after rotary embeddings have been applied. +"""Position embeddings attention bridge with full hook support. + +Reimplements attention for models using RoPE (Llama, Gemma, Qwen, OLMo, etc.) +so that all hook points fire at the correct computation stage: +- hook_q/hook_k/hook_v: after projection +- hook_rot_q/hook_rot_k: after RoPE rotation +- hook_attn_scores: PRE-softmax (matching HookedTransformer convention) +- hook_pattern: POST-softmax """ from __future__ import annotations @@ -79,12 +80,21 @@ def hooked_eager_attention_forward( key = bridge.hook_rot_k(key) # Call the original function + assert _ORIGINAL_EAGER_ATTENTION_FORWARD is not None return _ORIGINAL_EAGER_ATTENTION_FORWARD( module, query, key, value, attention_mask, **kwargs ) - # Replace the module-level function + # Replace the module-level function for both Gemma 2 and Gemma 3 gemma2_module.eager_attention_forward = hooked_eager_attention_forward # type: ignore[assignment] + + try: + import transformers.models.gemma3.modeling_gemma3 as gemma3_module + + gemma3_module.eager_attention_forward = hooked_eager_attention_forward # type: ignore[assignment] + except ImportError: + pass # Gemma 3 not available in this transformers version + _EAGER_ATTENTION_WRAPPED = True @@ 
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Reimplemented attention forward pass with hooks at correct computation stages.

    Instead of delegating to the HF attention module (which returns post-softmax
    weights), this reimplements attention step-by-step so that:
    - hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)
    - hook_pattern fires on POST-softmax weights
    - hook_rot_q/hook_rot_k fire after RoPE application

    Handles RoPE (including partial rotary dims), GQA, Q/K norms (both
    pre-reshape and post-reshape placement), sliding window masks (via the
    caller-supplied attention_mask), and attention-logit softcapping.

    Args:
        *args: Positional arguments; args[0] may be the hidden states tensor.
        **kwargs: Keyword arguments; recognized keys are ``hidden_states``,
            ``position_embeddings`` (a (cos, sin) tuple), and ``attention_mask``
            (an additive mask already shaped [batch, heads, q_len, kv_len]).

    Returns:
        Tuple of (attn_output, attn_weights), matching the HF attention
        module's return convention.

    Raises:
        RuntimeError: If the original HF component has not been attached.
        ValueError: If hidden states cannot be located in args/kwargs.
    """
    if self.original_component is None:
        raise RuntimeError(
            f"Original component not set for {self.name}. "
            "Call set_original_component() first."
        )

    # Type as Any — the HF attention module's interface (q_proj, k_proj, etc.)
    # varies by architecture and isn't captured by nn.Module's type signature.
    hf_attn: Any = self.original_component

    # Extract hidden_states from kwargs (preferred) or the first positional arg.
    if "hidden_states" in kwargs:
        hidden_states = kwargs.pop("hidden_states")
    elif len(args) > 0 and isinstance(args[0], torch.Tensor):
        hidden_states = args[0]
        args = args[1:]  # remaining positional args are intentionally unused
    else:
        raise ValueError("Could not find hidden_states in args or kwargs")

    position_embeddings = kwargs.pop("position_embeddings", None)
    attention_mask = kwargs.pop("attention_mask", None)

    # Apply input hook before any computation.
    hidden_states = self.hook_in(hidden_states)

    # Match the HF module's parameter dtype (e.g. bf16 models fed fp32 inputs).
    target_dtype = None
    try:
        target_dtype = next(hf_attn.parameters()).dtype
    except StopIteration:
        pass  # parameterless module; leave dtype as-is
    if target_dtype is not None and hidden_states.is_floating_point():
        hidden_states = hidden_states.to(dtype=target_dtype)

    # --- Q/K/V Projection + Optional Q/K Norms ---
    # Some models (OLMo 2) apply Q/K norms BEFORE multi-head reshape on
    # [batch, seq, hidden]; others (Gemma 3) apply AFTER reshape on
    # [batch, heads, seq, head_dim]. Detection: try applying the norms to the
    # flat projected tensors. If that fails (shape mismatch -> RuntimeError),
    # the model uses post-reshape norms.
    input_shape = hidden_states.shape[:-1]
    head_dim = hf_attn.head_dim
    hidden_shape = (*input_shape, -1, head_dim)

    query_states = hf_attn.q_proj(hidden_states)
    key_states = hf_attn.k_proj(hidden_states)
    value_states = hf_attn.v_proj(hidden_states)

    has_q_norm = hasattr(hf_attn, "q_norm") and hf_attn.q_norm is not None
    has_k_norm = hasattr(hf_attn, "k_norm") and hf_attn.k_norm is not None
    applied_pre_reshape_norm = False

    if has_q_norm:
        # Compute into temporaries and commit only if BOTH norms succeed.
        # (Assigning eagerly would double-normalize queries if q_norm succeeds
        # on the flat tensor but k_norm raises: the post-reshape branch would
        # then run q_norm a second time.)
        try:
            q_flat = hf_attn.q_norm(query_states)
            k_flat = hf_attn.k_norm(key_states) if has_k_norm else key_states
        except RuntimeError:
            pass  # shape mismatch — this model uses post-reshape norms
        else:
            query_states, key_states = q_flat, k_flat
            applied_pre_reshape_norm = True

    # Reshape to [batch, heads, seq, head_dim].
    query_states = query_states.view(hidden_shape).transpose(1, 2)
    key_states = key_states.view(hidden_shape).transpose(1, 2)
    value_states = value_states.view(hidden_shape).transpose(1, 2)

    if has_q_norm and not applied_pre_reshape_norm:
        # Post-reshape norm (Gemma 3 style: norm on [batch, heads, seq, head_dim]).
        query_states = hf_attn.q_norm(query_states)
    if has_k_norm and not applied_pre_reshape_norm:
        key_states = hf_attn.k_norm(key_states)

    # --- RoPE ---
    if position_embeddings is not None:
        position_embeddings = self._apply_position_embedding_hooks(position_embeddings)
        cos, sin = position_embeddings
        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

        # Some models use partial rotary (e.g., GPT-OSS) where cos/sin cover
        # only a portion of head_dim. Split Q/K, rotate the partial dims,
        # recombine.
        rotary_dim = cos.shape[-1]
        if rotary_dim < head_dim:
            q_rot, q_pass = query_states[..., :rotary_dim], query_states[..., rotary_dim:]
            k_rot, k_pass = key_states[..., :rotary_dim], key_states[..., rotary_dim:]
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
            query_states = torch.cat([q_rot, q_pass], dim=-1)
            key_states = torch.cat([k_rot, k_pass], dim=-1)
        else:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Fire hook_rot_q/hook_rot_k (post-rotation). Only fired when RoPE was
        # actually applied; hasattr-guarded since not every bridge defines them.
        if hasattr(self, "hook_rot_q"):
            query_states = self.hook_rot_q(query_states)
        if hasattr(self, "hook_rot_k"):
            key_states = self.hook_rot_k(key_states)

    # --- GQA: Expand K/V heads to match the number of query heads ---
    num_key_value_groups = getattr(hf_attn, "num_key_value_groups", 1)
    if num_key_value_groups > 1:
        from transformers.models.llama.modeling_llama import repeat_kv

        key_states_expanded = repeat_kv(key_states, num_key_value_groups)
        value_states_expanded = repeat_kv(value_states, num_key_value_groups)
    else:
        key_states_expanded = key_states
        value_states_expanded = value_states

    # --- Attention Scores (scaled dot product, pre-softmax) ---
    scaling = getattr(hf_attn, "scaling", head_dim**-0.5)
    attn_scores = torch.matmul(query_states, key_states_expanded.transpose(-2, -1)) * scaling

    # --- Softcapping (Gemma 2): tanh-squash logits into [-softcap, softcap] ---
    softcap = getattr(hf_attn, "attn_logit_softcapping", None)
    if softcap is not None:
        attn_scores = attn_scores / softcap
        attn_scores = torch.tanh(attn_scores)
        attn_scores = attn_scores * softcap

    # --- Causal / Sliding Window Mask (additive, prepared by the caller) ---
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states_expanded.shape[-2]]
        attn_scores = attn_scores + causal_mask

    # --- hook_attn_scores: PRE-softmax (matching HookedTransformer) ---
    attn_scores = self.hook_attn_scores(attn_scores)

    # --- Softmax (in float32 for numerical stability, then cast back) ---
    attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(
        query_states.dtype
    )

    # --- Dropout (train-time only; gated on the bridge's training flag) ---
    dropout_rate = getattr(hf_attn, "attention_dropout", 0.0)
    if self.training and dropout_rate > 0.0:
        attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_rate, training=True)

    # --- hook_pattern: POST-softmax ---
    attn_weights = self.hook_pattern(attn_weights)

    # --- Attention Output: weighted sum over values, merge heads ---
    attn_output = torch.matmul(attn_weights, value_states_expanded)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(*input_shape, -1)

    # --- Output Projection ---
    attn_output = hf_attn.o_proj(attn_output)

    # --- Output Hook ---
    attn_output = self.hook_out(attn_output)

    return attn_output, attn_weights
def setup_hook_compatibility(self, bridge: Any) -> None:
    """Hook-compatibility setup for Gemma3 multimodal models — intentionally a no-op.

    Like text-only Gemma 3, the multimodal model uses
    Gemma3TextScaledWordEmbedding, which scales embeddings by sqrt(d_model)
    internally in its forward() method. No additional hook conversion is
    needed — adding one would double-scale the embeddings.

    Args:
        bridge: The TransformerBridge instance (unused).
    """
+ if component.original_component is not None: + remote_module_list = component.original_component + else: + try: + remote_module_list = model.adapter.get_remote_component( + model.original_model, component.name + ) + except AttributeError: + # Submodule name is relative to a parent component that isn't the model root + # (e.g., vision encoder layers inside a multimodal model). Skip gracefully + # since these components are already set up during boot. + return components # Make sure the remote component is a ModuleList if isinstance(remote_module_list, nn.ModuleList):