Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions transformer_lens/HookedTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def __init__(
"OlmoForCausalLM",
"OlmoeForCausalLM",
"Olmo2ForCausalLM",
"Qwen3ForCausalLM",
"PhiForCausalLM",
]
self.set_tokenizer(
AutoTokenizer.from_pretrained(
Expand Down Expand Up @@ -1407,7 +1409,10 @@ def from_pretrained(
)
center_writing_weights = False
# OLMo 2 uses post-norm (norm after attention/MLP, not before), which is
# incompatible with weight processing that assumes pre-norm structure.
# incompatible with fold_ln and center_writing_weights (these assume pre-norm).
# center_unembed and fold_value_biases are architecture-independent and remain valid:
# - center_unembed: softmax is always translation-invariant
# - fold_value_biases: attention patterns always sum to 1
if cfg.original_architecture == "Olmo2ForCausalLM":
if fold_ln:
logging.warning(
Expand All @@ -1421,19 +1426,6 @@ def from_pretrained(
"architecture. Setting center_writing_weights=False."
)
center_writing_weights = False
if center_unembed:
logging.warning(
"center_unembed=True is incompatible with OLMo 2's post-norm "
"architecture (uses RMSNorm which does not center). "
"Setting center_unembed=False."
)
center_unembed = False
if fold_value_biases:
logging.warning(
"fold_value_biases=True is incompatible with OLMo 2's post-norm "
"architecture. Setting fold_value_biases=False."
)
fold_value_biases = False
if center_unembed and cfg.output_logits_soft_cap > 0.0:
logging.warning(
"You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant "
Expand Down
5 changes: 5 additions & 0 deletions transformer_lens/components/abstract_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,11 @@ def _find_correction_dim(num_rotations: float) -> float:
freq = 1.0 / inv_freq
else:
freq = base ** (dim / (rotary_dim / 2))
# Apply linear RoPE scaling for global attention layers if configured
# (e.g., Gemma 3 4B uses factor=8.0 for global layers, but not local ones)
scaling_factor = getattr(self.cfg, "rotary_scaling_factor", 1.0)
if scaling_factor != 1.0 and self.attn_type != "local":
freq = freq * scaling_factor
if self.cfg.rotary_adjacent_pairs:
freq = einops.repeat(freq, "d -> (d 2)")
else:
Expand Down
3 changes: 3 additions & 0 deletions transformer_lens/config/HookedTransformerConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ class HookedTransformerConfig(TransformerLensConfig):
rotary_base_local: Optional[
int
] = None # For models with different RoPE bases per attention type (e.g., Gemma 3)
rotary_scaling_factor: float = (
1.0 # Linear RoPE scaling factor for global attention (e.g., 8.0 for Gemma 3 4B)
)
trust_remote_code: bool = False
rotary_adjacent_pairs: bool = False
load_in_4bit: bool = False
Expand Down
44 changes: 34 additions & 10 deletions transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
"positional_embedding_type": "rotary",
"rotary_adjacent_pairs": False,
"normalization_type": "LN",
"default_prepend_bos": False,
}
rotary_pct = get_rotary_pct_from_config(hf_config)
cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
Expand Down Expand Up @@ -836,6 +837,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
"rotary_base": _get_rope_theta(hf_config),
"use_attn_scale": True,
"parallel_attn_mlp": True,
"default_prepend_bos": False,
}
partial_rotary_factor = hf_config.partial_rotary_factor
cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
Expand Down Expand Up @@ -916,7 +918,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
"n_ctx": 8192,
"eps": 1e-06,
"d_vocab": 256000,
"act_fn": "gelu_new",
"act_fn": "gelu",
"initializer_range": 0.02,
"normalization_type": "RMS",
"rotary_base": 10000,
Expand All @@ -938,7 +940,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
"n_ctx": 8192,
"eps": 1e-06,
"d_vocab": 256000,
"act_fn": "gelu_new",
"act_fn": "gelu",
"initializer_range": 0.02,
"normalization_type": "RMS",
"rotary_base": 10000.0,
Expand Down Expand Up @@ -1149,6 +1151,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
"normalization_type": "RMS",
"rotary_base": 1000000, # Global attention layers
"rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
"rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
"positional_embedding_type": "rotary",
"use_attn_scale": True,
"n_key_value_heads": 4,
Expand Down Expand Up @@ -1211,6 +1214,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
"normalization_type": "RMS",
"rotary_base": 1000000, # Global attention layers
"rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
"rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
"positional_embedding_type": "rotary",
"use_attn_scale": True,
"n_key_value_heads": 8,
Expand Down Expand Up @@ -1292,6 +1296,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
"normalization_type": "RMS",
"rotary_base": 1000000, # Global attention layers
"rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
"rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
"positional_embedding_type": "rotary",
"use_attn_scale": True,
"n_key_value_heads": 16,
Expand Down Expand Up @@ -1377,7 +1382,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
"n_ctx": 8192,
"eps": 1e-06,
"d_vocab": 256000,
"act_fn": "gelu_new",
"act_fn": "gelu",
"initializer_range": 0.02,
"normalization_type": "RMS",
"rotary_base": 10000,
Expand All @@ -1399,7 +1404,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
"n_ctx": 8192,
"eps": 1e-06,
"d_vocab": 256000,
"act_fn": "gelu_new",
"act_fn": "gelu",
"initializer_range": 0.02,
"normalization_type": "RMS",
"rotary_base": 10000.0,
Expand Down Expand Up @@ -2051,12 +2056,31 @@ def get_pretrained_state_dict(
**kwargs,
)
else:
hf_model = AutoModelForCausalLM.from_pretrained(
official_model_name,
torch_dtype=dtype,
token=huggingface_token if len(huggingface_token) > 0 else None,
**kwargs,
)
# Some older model configs (e.g., microsoft/phi-1) lack pad_token_id,
# which newer transformers versions require during model initialization.
try:
hf_model = AutoModelForCausalLM.from_pretrained(
official_model_name,
torch_dtype=dtype,
token=huggingface_token if len(huggingface_token) > 0 else None,
**kwargs,
)
except AttributeError as e:
if "pad_token_id" in str(e):
hf_config = AutoConfig.from_pretrained(
official_model_name,
token=huggingface_token if len(huggingface_token) > 0 else None,
)
hf_config.pad_token_id = getattr(hf_config, "pad_token_id", None)
hf_model = AutoModelForCausalLM.from_pretrained(
official_model_name,
config=hf_config,
torch_dtype=dtype,
token=huggingface_token if len(huggingface_token) > 0 else None,
**kwargs,
)
else:
raise

# Load model weights, and fold in layer norm weights
if hf_model is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,11 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
attn_output = attn_output.reshape(*input_shape, -1)

# --- Output Projection ---
attn_output = hf_attn.o_proj(attn_output)
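# Different architectures name the output projection differently: o_proj (Llama, Gemma, Qwen)
# or dense (Phi). Other names (e.g. out_proj) are not looked up here; if neither attribute
# exists, the projection step is skipped and attn_output is passed on unprojected.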
o_proj = getattr(hf_attn, "o_proj", None) or getattr(hf_attn, "dense", None)
if o_proj is not None:
attn_output = o_proj(attn_output)

# --- Output Hook ---
attn_output = self.hook_out(attn_output)
Expand Down
52 changes: 30 additions & 22 deletions transformer_lens/tools/model_registry/data/supported_models.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
},
"total_architectures": 35,
"total_models": 5833,
"total_verified": 686,
"total_verified": 687,
"models": [
{
"architecture_id": "Qwen2ForCausalLM",
Expand All @@ -27,14 +27,15 @@
"architecture_id": "Qwen3ForCausalLM",
"model_id": "Qwen/Qwen3-0.6B",
"status": 1,
"verified_date": "2026-03-10",
"verified_date": "2026-03-27",
"metadata": null,
"note": "Full verification completed",
"phase1_score": 100.0,
"phase2_score": 100.0,
"phase3_score": 100.0,
"phase4_score": 91.9,
"phase7_score": null
"phase7_score": null,
"phase8_score": null
},
{
"architecture_id": "GPT2LMHeadModel",
Expand Down Expand Up @@ -118,14 +119,15 @@
"architecture_id": "Qwen3ForCausalLM",
"model_id": "Qwen/Qwen3-4B",
"status": 1,
"verified_date": "2026-03-10",
"verified_date": "2026-03-27",
"metadata": null,
"note": "Full verification completed",
"phase1_score": 100.0,
"phase2_score": 100.0,
"phase3_score": 100.0,
"phase4_score": 99.4,
"phase7_score": null
"phase7_score": null,
"phase8_score": null
},
{
"architecture_id": "Qwen3ForCausalLM",
Expand Down Expand Up @@ -430,14 +432,15 @@
"architecture_id": "PhiForCausalLM",
"model_id": "microsoft/phi-2",
"status": 1,
"verified_date": "2026-03-10",
"verified_date": "2026-03-27",
"metadata": null,
"note": "Full verification completed",
"note": "Full verification completed with issues: P2=92.9% (failed: backward_hooks)",
"phase1_score": 100.0,
"phase2_score": 100.0,
"phase2_score": 92.9,
"phase3_score": 100.0,
"phase4_score": 95.8,
"phase7_score": null
"phase7_score": null,
"phase8_score": null
},
{
"architecture_id": "Qwen2ForCausalLM",
Expand Down Expand Up @@ -508,14 +511,15 @@
"architecture_id": "GPTNeoXForCausalLM",
"model_id": "EleutherAI/pythia-160m",
"status": 1,
"verified_date": "2026-03-10",
"verified_date": "2026-03-27",
"metadata": null,
"note": "Full verification completed",
"phase1_score": 100.0,
"phase2_score": 100.0,
"phase3_score": 100.0,
"phase4_score": 92.6,
"phase7_score": null
"phase7_score": null,
"phase8_score": null
},
{
"architecture_id": "LlamaForCausalLM",
Expand Down Expand Up @@ -1340,14 +1344,15 @@
"architecture_id": "GPTNeoXForCausalLM",
"model_id": "EleutherAI/pythia-70m-deduped",
"status": 1,
"verified_date": "2026-03-10",
"verified_date": "2026-03-27",
"metadata": null,
"note": "Full verification completed",
"phase1_score": 100.0,
"phase2_score": 100.0,
"phase3_score": 100.0,
"phase4_score": 77.5,
"phase7_score": null
"phase7_score": null,
"phase8_score": null
},
{
"architecture_id": "GptOssForCausalLM",
Expand Down Expand Up @@ -1742,15 +1747,16 @@
{
"architecture_id": "Olmo2ForCausalLM",
"model_id": "allenai/OLMo-2-0425-1B",
"status": 3,
"verified_date": "2026-03-10",
"status": 1,
"verified_date": "2026-03-27",
"metadata": null,
"note": "Below threshold: P3=90.0% but required tests failed: logits_equivalence \u2014 Tensors differ: max_diff=28.810719, mean_rel=60.400272",
"note": "Full verification completed",
"phase1_score": 100.0,
"phase2_score": 100.0,
"phase3_score": 90.0,
"phase3_score": 100.0,
"phase4_score": 94.8,
"phase7_score": null
"phase7_score": null,
"phase8_score": null
},
{
"architecture_id": "MistralForCausalLM",
Expand Down Expand Up @@ -2653,14 +2659,15 @@
"architecture_id": "PhiForCausalLM",
"model_id": "microsoft/phi-1_5",
"status": 1,
"verified_date": "2026-03-10",
"verified_date": "2026-03-27",
"metadata": null,
"note": "Full verification completed",
"phase1_score": 100.0,
"phase2_score": 100.0,
"phase3_score": 100.0,
"phase4_score": 97.9,
"phase7_score": null
"phase7_score": null,
"phase8_score": null
},
{
"architecture_id": "Qwen2ForCausalLM",
Expand Down Expand Up @@ -5877,14 +5884,15 @@
"architecture_id": "Olmo2ForCausalLM",
"model_id": "allenai/OLMo-2-0425-1B-Instruct",
"status": 1,
"verified_date": "2026-03-10",
"verified_date": "2026-03-27",
"metadata": null,
"note": "Full verification completed",
"phase1_score": 100.0,
"phase2_score": 100.0,
"phase3_score": 100.0,
"phase4_score": 94.2,
"phase7_score": null
"phase7_score": null,
"phase8_score": null
},
{
"architecture_id": "Gemma3ForCausalLM",
Expand Down
Loading
Loading