diff --git a/src/speculators/convert/eagle/eagle3_converter.py b/src/speculators/convert/eagle/eagle3_converter.py
index ffe3985..52d070b 100644
--- a/src/speculators/convert/eagle/eagle3_converter.py
+++ b/src/speculators/convert/eagle/eagle3_converter.py
@@ -145,7 +145,7 @@ def _create_transformer_config_from_eagle(
             vocab_size=eagle_config.get("target_vocab_size", 128000),
             hidden_size=eagle_config.get("hidden_size", 4096),
             intermediate_size=eagle_config.get("intermediate_size", 11008),
-            num_hidden_layers=1,
+            num_hidden_layers=eagle_config.get("num_hidden_layers", 1),
             num_attention_heads=eagle_config.get("num_attention_heads", 32),
             num_key_value_heads=eagle_config.get("num_key_value_heads", 8),
             hidden_act=eagle_config.get("hidden_act", "silu"),
diff --git a/tests/unit/convert/test_eagle3_converter.py b/tests/unit/convert/test_eagle3_converter.py
index 75080e1..23e3a07 100644
--- a/tests/unit/convert/test_eagle3_converter.py
+++ b/tests/unit/convert/test_eagle3_converter.py
@@ -93,6 +93,44 @@ def test_config_max_position_embeddings_logic(
         # rope_theta comes from Eagle3 config, not verifier
         assert llama_config.rope_theta == 10000.0
 
+    @pytest.mark.sanity
+    @patch(
+        "speculators.convert.eagle.eagle3_converter.PretrainedConfig.get_config_dict"
+    )
+    def test_config_num_hidden_layers_from_config(
+        self, mock_get_config, sample_eagle3_config
+    ):
+        """Test that num_hidden_layers is taken from eagle_config when present."""
+        mock_get_config.return_value = ({}, None)
+        converter = Eagle3Converter()
+
+        # Add num_hidden_layers to the sample config
+        sample_eagle3_config["num_hidden_layers"] = 3
+
+        llama_config = converter._create_transformer_config_from_eagle(
+            sample_eagle3_config, "meta-llama/Llama-3.1-8B"
+        )
+        assert llama_config.num_hidden_layers == 3
+
+    @pytest.mark.sanity
+    @patch(
+        "speculators.convert.eagle.eagle3_converter.PretrainedConfig.get_config_dict"
+    )
+    def test_config_num_hidden_layers_default(
+        self, mock_get_config, sample_eagle3_config
+    ):
+        """Test that num_hidden_layers defaults to 1 when not in config."""
+        mock_get_config.return_value = ({}, None)
+        converter = Eagle3Converter()
+
+        # Remove num_hidden_layers if present
+        sample_eagle3_config.pop("num_hidden_layers", None)
+
+        llama_config = converter._create_transformer_config_from_eagle(
+            sample_eagle3_config, "meta-llama/Llama-3.1-8B"
+        )
+        assert llama_config.num_hidden_layers == 1
+
     @pytest.mark.sanity
     @patch(
         "speculators.convert.eagle.eagle3_converter.PretrainedConfig.get_config_dict"