Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@
"-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning",
]
doctest_optionflags="NORMALIZE_WHITESPACE ELLIPSIS FLOAT_CMP"
markers=[
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
]
filterwarnings=[
"ignore:pkg_resources is deprecated as an API:DeprecationWarning",
# Ignore numpy.distutils deprecation warning caused by pandas
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch

from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.model_bridge import TransformerBridge


class TestActivationCacheCompatibility:
Expand All @@ -14,16 +13,15 @@ class TestActivationCacheCompatibility:
def cleanup_after_class(self):
"""Clean up memory after each test class."""
yield
# Clear GPU memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
for _ in range(3):
gc.collect()

@pytest.fixture(scope="class")
def bridge_model(self):
"""Create a TransformerBridge model for testing."""
return TransformerBridge.boot_transformers("gpt2", device="cpu")
def bridge_model(self, gpt2_bridge):
"""Use session-scoped gpt2 bridge."""
return gpt2_bridge

@pytest.fixture(scope="class")
def sample_cache(self, bridge_model):
Expand Down
23 changes: 5 additions & 18 deletions tests/acceptance/model_bridge/compatibility/test_backward_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,23 @@
import pytest
import torch

from transformer_lens import HookedTransformer
from transformer_lens.model_bridge import TransformerBridge


class TestBackwardHookCompatibility:
"""Test backward hook compatibility between TransformerBridge and HookedTransformer."""

@pytest.mark.skip(
reason="hook_mlp_out has known gradient differences due to architectural bridging (0.875 diff, but forward pass matches perfectly)"
)
def test_backward_hook_gradients_match_hooked_transformer(self):
def test_backward_hook_gradients_match_hooked_transformer(
self, gpt2_hooked_unprocessed, gpt2_bridge_compat_no_processing
):
"""Test that backward hook gradients match between TransformerBridge and HookedTransformer.

This test ensures that backward hooks see identical gradient values in both
TransformerBridge and HookedTransformer when using no_processing mode.
"""
hooked_model = HookedTransformer.from_pretrained_no_processing("gpt2", device_map="cpu")
bridge_model: TransformerBridge = TransformerBridge.boot_transformers(
"gpt2", device="cpu"
) # type: ignore
bridge_model.enable_compatibility_mode(no_processing=True)
hooked_model = gpt2_hooked_unprocessed
bridge_model = gpt2_bridge_compat_no_processing

test_input = torch.tensor([[1, 2, 3]])

Expand All @@ -51,16 +47,7 @@ def sum_bridge_grads(grad, hook=None):
out = bridge_model(test_input)
out.sum().backward()

print(f"HookedTransformer gradient sum: {hooked_grad_sum.item():.6f}")
print(f"TransformerBridge gradient sum: {bridge_grad_sum.item():.6f}")
print(f"Difference: {abs(hooked_grad_sum - bridge_grad_sum).item():.6f}")
assert torch.allclose(hooked_grad_sum, bridge_grad_sum, atol=1e-2, rtol=1e-2), (
f"Gradient sums should be identical but differ by "
f"{abs(hooked_grad_sum - bridge_grad_sum).item():.6f}"
)


if __name__ == "__main__":
test = TestBackwardHookCompatibility()
test.test_backward_hook_gradients_match_hooked_transformer()
print("✅ Backward hook compatibility test passed!")
226 changes: 0 additions & 226 deletions tests/acceptance/model_bridge/compatibility/test_bridge_hooks.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from transformer_lens.benchmarks import benchmark_forward_hooks, benchmark_hook_registry
from transformer_lens.model_bridge import TransformerBridge

pytestmark = pytest.mark.skip(reason="Temporarily skipping hook completeness tests pending fixes")
pytestmark = pytest.mark.slow

# Diverse architectures for hook completeness testing
MODELS_TO_TEST = [
Expand Down Expand Up @@ -71,7 +71,9 @@ def test_all_hooks_fire(self, model_name):
test_text = "The quick brown fox"

# Run benchmark - this will fail if hooks don't fire
result = benchmark_forward_hooks(bridge, test_text, reference_model=ht, tolerance=1e-3)
# tolerance=1e-2: some architectures (e.g., pythia) accumulate small floating-point
# differences across layers that exceed 1e-3 but are not meaningful divergences.
result = benchmark_forward_hooks(bridge, test_text, reference_model=ht, tolerance=1e-2)

# Must pass - all hooks must fire
assert result.passed, (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

import torch

from transformer_lens import HookedTransformer
from transformer_lens.model_bridge import TransformerBridge


def test_TransformerBridge_compatibility_mode_calls_hooks_once():
def test_TransformerBridge_compatibility_mode_calls_hooks_once(
gpt2_hooked_unprocessed, gpt2_bridge_compat_no_processing
):
"""Regression test: hooks fire exactly once even with aliased HookPoint names."""
hooked_model = HookedTransformer.from_pretrained_no_processing("gpt2", device_map="cpu")
bridge_model: TransformerBridge = TransformerBridge.boot_transformers("gpt2", device="cpu") # type: ignore
bridge_model.enable_compatibility_mode(no_processing=True)
hooked_model = gpt2_hooked_unprocessed
bridge_model = gpt2_bridge_compat_no_processing

test_input = torch.tensor([[1, 2, 3]])

Expand Down Expand Up @@ -47,10 +45,9 @@ def count_bridge_calls(acts, hook):
)


def test_hook_mlp_out_aliasing():
def test_hook_mlp_out_aliasing(gpt2_bridge_compat_no_processing):
"""Test that hook_mlp_out is properly aliased to mlp.hook_out in compatibility mode."""
bridge_model: TransformerBridge = TransformerBridge.boot_transformers("gpt2", device="cpu") # type: ignore
bridge_model.enable_compatibility_mode(no_processing=True)
bridge_model = gpt2_bridge_compat_no_processing

block0 = bridge_model.blocks[0]

Expand All @@ -61,10 +58,9 @@ def test_hook_mlp_out_aliasing():
), "hook_mlp_out should be aliased to mlp.hook_out (same object)"


def test_stateful_hook_pattern():
def test_stateful_hook_pattern(gpt2_bridge_compat_no_processing):
"""Test stateful closure pattern (circuit-tracer's cache-then-pop) with aliased hooks."""
bridge_model: TransformerBridge = TransformerBridge.boot_transformers("gpt2", device="cpu") # type: ignore
bridge_model.enable_compatibility_mode(no_processing=True)
bridge_model = gpt2_bridge_compat_no_processing

test_input = torch.tensor([[1, 2, 3]])
block = bridge_model.blocks[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ def model_name(self, request):
return request.param

@pytest.fixture(scope="class")
def bridge_model(self, model_name):
"""Create a TransformerBridge model for testing."""
def bridge_model(self, model_name, gpt2_bridge):
"""Use session-scoped fixture for gpt2, load fresh for other models."""
if model_name == "gpt2":
return gpt2_bridge
try:
return TransformerBridge.boot_transformers(model_name, device="cpu")
except Exception as e:
Expand Down
Loading
Loading