Refactor loading_from_pretrained local cache for Llama and Gemma models#1220
Closed
brendanlong wants to merge 3 commits into TransformerLensOrg:dev from
Conversation
…cture branches

These tests cover every branch in the giant if/else chain:

- All 13 hardcoded Llama model configs
- All 5 hardcoded Gemma 1/2 model configs
- All architecture-based configs (GPT2, GPTNeo, OPT, GPTJ, GPTNeoX, Bloom, Mistral, Mixtral, Santacoder, generic Llama, Qwen, Qwen2, Qwen3, Phi, Phi3, Bert, HuBERT, Wav2Vec2, T5, Apertus, GptOss)
- Edge cases (unsupported architecture, tokenizer_name, trust_remote_code, TinyStories n_ctx override)

This establishes a baseline before refactoring the Llama config loading to use an HF config cache instead of hardcoded dicts.

Note: Fixed a missing key for the codellama path, adding an explicit `"rotary_adjacent_pairs": False` matching the other models and ensuring the tests won't need to change when we fix this to be consistent.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Instead of 13 hardcoded if/elif blocks (280 lines) that duplicate config values for each Llama model variant, we now:

1. Cache LlamaConfig objects in _LLAMA_HF_CONFIGS, keyed by model name prefix. These contain the same values as HF's actual configs, with max_position_embeddings capped to memory-safe values where needed.
2. Look up the cached config via _get_llama_hf_config() when a Llama model is detected by name. Falls back to AutoConfig.from_pretrained() for unknown Llama models (e.g. fine-tunes).
3. Enhanced the generic LlamaForCausalLM handler with:
   - rope_theta support (for Llama 3+ models using 500K base)
   - rope_scaling support (NTK-by-parts for Llama 3.1/3.2/3.3)

This means adding support for new Llama-based models (e.g. DeepSeek R1 distills) only requires adding a cache entry with the model's config values, rather than duplicating the entire config dict.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
brendanlong force-pushed from c5eacc7 to 9c1237a
brendanlong commented on Mar 29, 2026
    Matches by longest prefix: e.g. "meta-llama/Llama-3.1-8B-Instruct" matches
    the "meta-llama/Llama-3.1-8B" entry. Returns None if no cache entry matches.
    """
    for prefix in sorted(_HF_CONFIG_CACHE, key=len, reverse=True):
Author
I'm tempted to remove the prefix matching and just list every model explicitly, but I didn't want to change the current functionality.
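For reference, the longest-prefix-first ordering in the loop above is what keeps overlapping entries unambiguous: sorting keys by length, descending, means the most specific entry wins. A small illustration with hypothetical entries:

```python
# Hypothetical overlapping cache entries (not the real cache contents).
cache = {
    "meta-llama/Llama-3.1": "generic 3.1 entry",
    "meta-llama/Llama-3.1-8B": "specific 8B entry",
}


def lookup(name):
    # Longest prefix first, so "Llama-3.1-8B" is tried before "Llama-3.1".
    for prefix in sorted(cache, key=len, reverse=True):
        if name.startswith(prefix):
            return cache[prefix]
    return None
```

Without the `reverse=True` sort, whichever entry happened to be checked first would match, and the 8B-specific entry could be shadowed by the generic one.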
brendanlong force-pushed from 9c1237a to 8079bc0
brendanlong commented on Mar 29, 2026
Move all Gemma 1, 2, and 3 models to _HF_CONFIG_CACHE using their native HF config classes (GemmaConfig, Gemma2Config, Gemma3TextConfig, Gemma3Config).

Add generic architecture handlers for GemmaForCausalLM, Gemma2ForCausalLM, Gemma3ForCausalLM, and Gemma3ForConditionalGeneration that read from the HF config objects, eliminating ~460 lines of hardcoded dicts and the name-based architecture detection.

Fixes google/gemma-2b incorrectly getting Gemma2ForCausalLM architecture (now correctly gets GemmaForCausalLM from cache). Gemma 2 attn_types now exactly match n_layers instead of having extra unused entries.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
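The shape of such a generic handler can be sketched as follows. This is not the PR's actual code: a `SimpleNamespace` stands in for a `Gemma2Config`, the field names mirror common HF config attributes, and the values are made up. It also shows the attn_types fix mentioned above, building exactly one entry per layer:

```python
from types import SimpleNamespace

# Stand-in for an HF Gemma2Config object; values are illustrative only.
hf_config = SimpleNamespace(
    num_hidden_layers=6,
    hidden_size=2304,
    num_attention_heads=8,
    sliding_window=4096,
)


def convert_gemma2_config(cfg):
    """Read fields off the HF config object instead of a hardcoded dict."""
    # Gemma 2 alternates local (sliding-window) and global attention.
    # Trim the alternating pattern to exactly num_hidden_layers entries,
    # so there are no extra unused trailing entries.
    attn_types = (["local", "global"] * cfg.num_hidden_layers)[: cfg.num_hidden_layers]
    return {
        "n_layers": cfg.num_hidden_layers,
        "d_model": cfg.hidden_size,
        "n_heads": cfg.num_attention_heads,
        "window_size": cfg.sliding_window,
        "attn_types": attn_types,
    }
```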
brendanlong force-pushed from 8079bc0 to 4f3832d
Author
Actually this causes problems with different versions of LlamaConfig on transformers==0.42.0, which we seem to support on ancient Python versions. That's unfortunate, but I'll try to think about whether there's a clean way to do this that doesn't break between transformers versions :\
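One defensive pattern for config classes whose constructor signatures vary across library versions (not something this PR does; just a possible approach) is to filter the kwargs through `inspect.signature` before instantiating, so one cache definition works against whatever the installed version accepts. `OldLlamaConfig` and `make_config` below are hypothetical names:

```python
import inspect


class OldLlamaConfig:
    # Hypothetical stand-in for an older config class that predates
    # rope_scaling and would raise TypeError if passed that kwarg.
    def __init__(self, max_position_embeddings=2048, rope_theta=10000.0):
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta


def make_config(config_cls, **kwargs):
    """Instantiate config_cls, silently dropping kwargs its __init__
    does not accept in the installed library version."""
    accepted = inspect.signature(config_cls.__init__).parameters
    filtered = {k: v for k, v in kwargs.items() if k in accepted}
    return config_cls(**filtered)


# rope_scaling is dropped rather than crashing the old class:
cfg = make_config(
    OldLlamaConfig,
    max_position_embeddings=8192,
    rope_scaling={"rope_type": "llama3"},
)
```

The trade-off is that unsupported settings are silently ignored rather than erroring, which may or may not be acceptable here.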
Description
loading_from_pretrained.convert_hf_model_config() has special cases hard-coding values for Llama and Gemma, to avoid making calls to HuggingFace since these models are gated.
This split the logic between two sections of the code, leading to weird edge cases like google/gemma-2b having architecture Gemma2ForCausalLM even though it's a Gemma 1 model, and requiring all Llama models to be hard-coded since hf_config was never loaded for them, despite a LlamaForCausalLM path existing (causing problems for adding DeepSeek R1 distill support; see brendanlong#7 for the trivial commit on top of this PR).

This should also make the code much easier to follow and support going forward, and I moved it to a separate file to avoid dumping hundreds of lines of configs into the context whenever someone reads this.
I added regression tests in a commit before making the changes to ensure that we're not breaking existing configs.
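The regression-test approach amounts to golden-output comparison: snapshot what the converter returns before the refactor, then assert the refactored code reproduces it exactly. A sketch of the pattern, with a hypothetical stand-in for `loading_from_pretrained.convert_hf_model_config` and made-up golden values:

```python
# Hypothetical stand-in for loading_from_pretrained.convert_hf_model_config;
# the returned values are illustrative, not the library's real configs.
def convert_hf_model_config(model_name):
    if model_name.startswith("gpt2"):
        return {"n_layers": 12, "d_model": 768, "act_fn": "gelu_new"}
    raise NotImplementedError(model_name)


# Golden snapshots captured *before* refactoring; the refactored converter
# must reproduce them field-for-field.
GOLDEN = {
    "gpt2": {"n_layers": 12, "d_model": 768, "act_fn": "gelu_new"},
}


def test_configs_unchanged():
    for name, expected in GOLDEN.items():
        assert convert_hf_model_config(name) == expected
```

In a real suite this would be one parametrized test case per supported model name, so any behavioral drift in the refactor fails loudly.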
Type of change
Checklist: