8 changes: 0 additions & 8 deletions python/sglang/srt/lora/lora_manager.py
@@ -418,10 +418,6 @@ def set_lora_module(self, module_name, module):
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

-    def should_skip_lora_for_vision_model(self, module_name):
-        # TODO: support different vision models
-        return module_name.find("vision_model.model") != -1
-
    def init_lora_modules(self):
        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
        self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
@@ -439,10 +435,6 @@ def init_lora_modules(self):
            ) and not self.base_model.should_apply_lora(module_name):
                continue

-            # Skip vision model
-            if self.should_skip_lora_for_vision_model(module_name):
-                continue
-
            # The module should be converted if it is included in target_names
            if module_name.split(".")[-1] in self.target_modules:
                layer_id = get_layer_id(module_name)
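To make the new dispatch concrete, here is a minimal sketch of the per-module filter; the helper name `model_rejects_lora` is invented for illustration, and the real logic lives inline in `init_lora_modules` as shown in the hunk above:

```python
# Hypothetical standalone version of the inline check above.
def model_rejects_lora(base_model, module_name: str) -> bool:
    """Return True when the base model opts this module out of LoRA."""
    # Only models that define should_apply_lora (the multimodal models touched
    # in this PR) get to veto modules; other models are unaffected.
    should_apply = getattr(base_model, "should_apply_lora", None)
    return should_apply is not None and not should_apply(module_name)
```

Compared with the removed should_skip_lora_for_vision_model, this moves the decision into each model, so a shared "vision_model.model" substring no longer has to cover every vision tower layout.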
16 changes: 16 additions & 0 deletions python/sglang/srt/models/gemma3_mm.py
@@ -16,6 +16,7 @@
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py

import logging
+import re
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True
+    # Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )

    def __init__(
        self,
@@ -165,6 +170,13 @@ def __init__(
        self.config = config
        self.quant_config = quant_config

+        # For LoRA compatibility: expose text_config attributes at top level
+        # This allows LoRA code to work without special multimodal handling
+        if not hasattr(config, "num_hidden_layers"):
+            config.num_hidden_layers = config.text_config.num_hidden_layers
+        if not hasattr(config, "hidden_size"):
+            config.hidden_size = config.text_config.hidden_size
A review comment from the PR author accompanies this change:
Why only Gemma3 needs config attribute exposure:

  • Gemma3ForConditionalGeneration uses a config structure where language model attributes (num_hidden_layers, hidden_size) are exclusively in config.text_config, with no top-level copies
  • Other VLMs like Phi4ForConditionalGeneration and Llama4ForConditionalGeneration either:
    • Already have these attributes at the top level, OR
    • Have different config inheritance that makes them accessible
  • LoRA's LoRAMemoryPool.__init__ directly accesses base_hf_config.num_hidden_layers (line 59 in mem_pool.py), which fails for Gemma3's nested-only structure
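As a rough illustration of the mismatch described above — a sketch with stand-in objects; the layer and hidden sizes are only indicative of a small Gemma 3 checkpoint, and the memory-pool line is paraphrased rather than copied from mem_pool.py:

```python
from types import SimpleNamespace

# Gemma3-style config: text attributes live only under text_config.
cfg = SimpleNamespace(
    text_config=SimpleNamespace(num_hidden_layers=34, hidden_size=2560),
    vision_config=SimpleNamespace(hidden_size=1152),
)

# Paraphrase of the LoRA memory pool's access pattern:
#     num_layers = base_hf_config.num_hidden_layers
# With the nested-only layout this raises AttributeError for Gemma3.

# The __init__ shim in this PR copies the nested values to the top level once:
if not hasattr(cfg, "num_hidden_layers"):
    cfg.num_hidden_layers = cfg.text_config.num_hidden_layers
if not hasattr(cfg, "hidden_size"):
    cfg.hidden_size = cfg.text_config.hidden_size

assert cfg.num_hidden_layers == 34  # now safe for LoRA code expecting a flat config
```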


        self.vision_tower = SiglipVisionModel(
            config=config.vision_config,
            quant_config=quant_config,
@@ -380,6 +392,10 @@ def forward(

        return hs

+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision tower and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))

    def tie_weights(self):
        return self.language_model.tie_weights()
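A quick, self-contained check of what the class-level pattern accepts; the module names below are representative examples rather than an exhaustive dump of the model tree:

```python
import re

lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

accepted = [
    "language_model.model.layers.0.self_attn.qkv_proj",
    "language_model.model.layers.17.mlp.gate_up_proj",
]
rejected = [
    "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj",  # vision tower
    "multi_modal_projector.mm_input_projection",  # projector (name illustrative)
    "language_model.model.embed_tokens",  # language model, but not a LoRA target
]

assert all(lora_pattern.match(name) for name in accepted)
assert not any(lora_pattern.match(name) for name in rejected)
```

The same expression is reused in mllama4.py below; only the docstring changes (vision_model instead of vision_tower).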

10 changes: 10 additions & 0 deletions python/sglang/srt/models/mllama4.py
@@ -2,6 +2,7 @@
import logging
import math
import os
+import re
from collections.abc import Iterable
from typing import List, Optional, Set, Tuple

@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

+    # Pattern to match language model layers only (skip vision_model and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )

    def __init__(
        self,
        config: Llama4Config,
@@ -544,6 +550,10 @@ def get_image_feature(

        return projected_vision_flat

+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision model and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))

    def forward(
        self,
        input_ids: torch.Tensor,