
Refactor loading_from_pretrained local cache for Llama and Gemma models#1220

Closed
brendanlong wants to merge 3 commits into TransformerLensOrg:dev from brendanlong:brendanlong/load-from-pretrained-cache

Conversation


@brendanlong brendanlong commented Mar 29, 2026

Description

loading_from_pretrained.convert_hf_model_config() has special cases hard-coding values for Llama and Gemma, to avoid making calls to HuggingFace since these models are gated.

This split the logic between two sections of the code, leading to odd edge cases: google/gemma-2b was assigned the architecture Gemma2ForCausalLM even though it is a Gemma 1 model, and every Llama model had to be hard-coded because hf_config was never loaded for them, despite a LlamaForCausalLM path existing (which caused problems when adding DeepSeek R1 distill support; see brendanlong#7 for the trivial commit on top of this PR).

This should also make the code much easier to follow and maintain going forward. I also moved the configs to a separate file, so reading this module no longer means wading through hundreds of lines of config dicts.

I added regression tests in a commit before making the changes to ensure that we're not breaking existing configs.

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

claude and others added 2 commits March 28, 2026 16:29
…cture branches

These tests cover every branch in the giant if/else chain:
- All 13 hardcoded Llama model configs
- All 5 hardcoded Gemma 1/2 model configs
- All architecture-based configs (GPT2, GPTNeo, OPT, GPTJ, GPTNeoX, Bloom,
  Mistral, Mixtral, Santacoder, generic Llama, Qwen, Qwen2, Qwen3, Phi,
  Phi3, Bert, HuBERT, Wav2Vec2, T5, Apertus, GptOss)
- Edge cases (unsupported architecture, tokenizer_name, trust_remote_code,
  TinyStories n_ctx override)

This establishes a baseline before refactoring the Llama config loading
to use a HF config cache instead of hardcoded dicts.

Note: Fixed a missing key for the codellama path, adding explicit
`"rotary_adjacent_pairs": False` matching the other models and ensuring
the tests won't need to change when we fix this to be consistent.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
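A golden-value regression test of the kind described above might look like this minimal sketch. The converter stub, model name, and expected values are illustrative stand-ins, not the project's actual test code; the real tests exercise loading_from_pretrained.convert_hf_model_config().

```python
# Minimal sketch of a golden-value regression test: snapshot the config dict
# a branch produces before refactoring, then assert the converter still
# returns the same values afterwards. The stub below stands in for the real
# convert_hf_model_config(); its values are illustrative only.
EXPECTED_CODELLAMA = {
    "d_model": 4096,
    "n_heads": 32,
    "n_layers": 32,
    "rotary_adjacent_pairs": False,  # the key the commit adds explicitly
}

def stub_convert_hf_model_config(model_name: str) -> dict:
    # Stand-in for the real converter; returns the hardcoded branch's dict.
    if model_name.startswith("codellama/"):
        return dict(EXPECTED_CODELLAMA)
    raise ValueError(f"unsupported model: {model_name}")

def assert_config_matches(convert_fn, model_name: str, expected: dict) -> None:
    # Compare only the snapshotted keys, reporting every mismatch at once.
    cfg = convert_fn(model_name)
    mismatched = {k: (cfg.get(k), v) for k, v in expected.items() if cfg.get(k) != v}
    assert not mismatched, f"{model_name}: mismatched keys {mismatched}"

assert_config_matches(stub_convert_hf_model_config, "codellama/CodeLlama-7b-hf", EXPECTED_CODELLAMA)
```

Snapshotting the expected dicts in a commit *before* the refactor means any behavioral drift shows up as a test failure rather than a silent config change.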
Instead of 13 hardcoded if/elif blocks (280 lines) that duplicate config
values for each Llama model variant, we now:

1. Cache LlamaConfig objects in _LLAMA_HF_CONFIGS, keyed by model name
   prefix. These contain the same values as HF's actual configs, with
   max_position_embeddings capped to memory-safe values where needed.

2. Look up the cached config via _get_llama_hf_config() when a Llama
   model is detected by name. Falls back to AutoConfig.from_pretrained()
   for unknown Llama models (e.g. fine-tunes).

3. Enhanced the generic LlamaForCausalLM handler with:
   - rope_theta support (for Llama 3+ models using 500K base)
   - rope_scaling support (NTK-by-parts for Llama 3.1/3.2/3.3)

This means adding support for new Llama-based models (e.g. DeepSeek R1
distills) only requires adding a cache entry with the model's config
values, rather than duplicating the entire config dict.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@brendanlong brendanlong force-pushed the brendanlong/load-from-pretrained-cache branch from c5eacc7 to 9c1237a Compare March 29, 2026 00:55
```python
    Matches by longest prefix: e.g. "meta-llama/Llama-3.1-8B-Instruct" matches
    the "meta-llama/Llama-3.1-8B" entry. Returns None if no cache entry matches.
    """
    for prefix in sorted(_HF_CONFIG_CACHE, key=len, reverse=True):
```
Author

I'm tempted to remove the prefix matching and just list every model explicitly, but I didn't want to change the current functionality.

@brendanlong brendanlong force-pushed the brendanlong/load-from-pretrained-cache branch from 9c1237a to 8079bc0 Compare March 29, 2026 15:53
Move all Gemma 1, 2, and 3 models to _HF_CONFIG_CACHE using their
native HF config classes (GemmaConfig, Gemma2Config, Gemma3TextConfig,
Gemma3Config). Add generic architecture handlers for GemmaForCausalLM,
Gemma2ForCausalLM, Gemma3ForCausalLM, and Gemma3ForConditionalGeneration
that read from the HF config objects, eliminating ~460 lines of
hardcoded dicts and the name-based architecture detection.

Fixes google/gemma-2b incorrectly getting Gemma2ForCausalLM architecture
(now correctly gets GemmaForCausalLM from cache). Gemma 2 attn_types now
exactly match n_layers instead of having extra unused entries.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
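A generic architecture handler of the kind described might read the HF config object roughly like this. The field names on the input side follow HF's GemmaConfig; the output keys and the alternating local/global attention pattern are assumptions based on the commit message, not the actual diff.

```python
from types import SimpleNamespace

# Hypothetical sketch of a generic GemmaForCausalLM handler that derives a
# TransformerLens-style config dict from an HF config object instead of a
# hardcoded dict. Output key names are illustrative.
def convert_gemma_config(hf_config) -> dict:
    n_layers = hf_config.num_hidden_layers
    return {
        "d_model": hf_config.hidden_size,
        "n_heads": hf_config.num_attention_heads,
        "n_layers": n_layers,
        "d_mlp": hf_config.intermediate_size,
        "n_ctx": hf_config.max_position_embeddings,
        # Building the attention-type list from n_layers guarantees it
        # matches the layer count exactly, with no unused entries.
        "attn_types": ["local" if i % 2 == 0 else "global" for i in range(n_layers)],
    }

# Demo with a stand-in config object (real code would use a GemmaConfig).
fake = SimpleNamespace(
    num_hidden_layers=4, hidden_size=2048, num_attention_heads=8,
    intermediate_size=16384, max_position_embeddings=8192,
)
cfg = convert_gemma_config(fake)
assert len(cfg["attn_types"]) == cfg["n_layers"]
```

Deriving everything from the config object is what removes the hardcoded dicts: a new Gemma variant only needs its HF config cached, not a new if/elif branch.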
@brendanlong brendanlong force-pushed the brendanlong/load-from-pretrained-cache branch from 8079bc0 to 4f3832d Compare March 29, 2026 16:12
@brendanlong
Copy link
Copy Markdown
Author

Actually this causes problems with different versions of LlamaConfig on transformers==0.42.0, which we seem to support on ancient Python versions.

That's unfortunate but I'll try to think about if there's a clean way to do this that doesn't break between transformers versions :\
