Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions tests/unit/test_loading_from_pretrained_utilities.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from unittest import mock

import pytest

from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.loading_from_pretrained import fill_missing_keys
from transformer_lens.loading_from_pretrained import (
fill_missing_keys,
get_pretrained_model_config,
)


def get_default_config():
Expand Down Expand Up @@ -97,9 +102,12 @@ class TestArchitectureConfigs:
"""Verify that convert_hf_model_config produces correct configs for new architectures."""

def test_apertus_config(self):
from transformer_lens.loading_from_pretrained import get_pretrained_model_config

cfg = get_pretrained_model_config("apertus-8b")
try:
cfg = get_pretrained_model_config("apertus-8b")
except ValueError as e:
if "does not recognize this architecture" in str(e):
pytest.skip(f"transformers version too old: {e}")
raise
assert cfg.original_architecture == "ApertusForCausalLM"
assert cfg.normalization_type == "RMS"
assert cfg.positional_embedding_type == "rotary"
Expand All @@ -112,9 +120,12 @@ def test_apertus_config(self):
assert cfg.n_heads > 0

def test_gpt_oss_config(self):
from transformer_lens.loading_from_pretrained import get_pretrained_model_config

cfg = get_pretrained_model_config("gpt-oss-20b")
try:
cfg = get_pretrained_model_config("gpt-oss-20b")
except ValueError as e:
if "does not recognize this architecture" in str(e):
pytest.skip(f"transformers version too old: {e}")
raise
assert cfg.original_architecture == "GptOssForCausalLM"
assert cfg.normalization_type == "RMS"
assert cfg.positional_embedding_type == "rotary"
Expand All @@ -126,8 +137,11 @@ def test_gpt_oss_config(self):
assert cfg.n_key_value_heads is not None

def test_apertus_instruct_config(self):
    """Check the config produced for the "apertus-8b-instruct" model.

    The lookup is done exactly once, inside the ``try``, so that a
    ``ValueError`` whose message contains "does not recognize this
    architecture" results in a skip rather than a failure. Any other
    ``ValueError`` is re-raised unchanged.
    """
    try:
        cfg = get_pretrained_model_config("apertus-8b-instruct")
    except ValueError as e:
        # Only this specific message is treated as an environment
        # limitation; every other ValueError is a real failure.
        if "does not recognize this architecture" in str(e):
            pytest.skip(f"transformers version too old: {e}")
        raise
    assert cfg.original_architecture == "ApertusForCausalLM"
    assert cfg.act_fn == "xielu"
Loading