Skip to content

Commit 8c32e8e

Browse files
ConnorLi96
authored and committed
Fix LoRA support for multimodal models (VLMs) by implementing a consistent pattern for skipping vision components (sgl-project#11261)
1 parent 7900896 commit 8c32e8e

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

python/sglang/srt/lora/lora_manager.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,6 @@ def set_lora_module(self, module_name, module):
418418
replace_submodule(self.base_model, module_name, lora_module)
419419
return lora_module
420420

421-
def should_skip_lora_for_vision_model(self, module_name: str) -> bool:
    """Return True if *module_name* belongs to the vision model and must not get LoRA.

    Args:
        module_name: Dotted module path within the base model
            (e.g. ``"vision_model.model.encoder.layers.0"``).

    Returns:
        True when the name contains the ``"vision_model.model"`` prefix used by
        the currently supported vision backbone, False otherwise.
    """
    # TODO: support different vision models (only one naming scheme handled here).
    # `in` is the idiomatic substring test, equivalent to `.find(...) != -1`.
    return "vision_model.model" in module_name
424-
425421
def init_lora_modules(self):
426422
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
427423
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
@@ -439,10 +435,6 @@ def init_lora_modules(self):
439435
) and not self.base_model.should_apply_lora(module_name):
440436
continue
441437

442-
# Skip vision model
443-
if self.should_skip_lora_for_vision_model(module_name):
444-
continue
445-
446438
# The module should be converted if it is included in target_names
447439
if module_name.split(".")[-1] in self.target_modules:
448440
layer_id = get_layer_id(module_name)

python/sglang/srt/models/gemma3_mm.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
1717

1818
import logging
19+
import re
1920
from functools import lru_cache
2021
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
2122

@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
154155
embedding_modules = {}
155156
embedding_padding_modules = []
156157
supports_lora = True
158+
# Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
159+
lora_pattern = re.compile(
160+
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
161+
)
157162

158163
def __init__(
159164
self,
@@ -165,6 +170,13 @@ def __init__(
165170
self.config = config
166171
self.quant_config = quant_config
167172

173+
# For LoRA compatibility: expose text_config attributes at top level
174+
# This allows LoRA code to work without special multimodal handling
175+
if not hasattr(config, "num_hidden_layers"):
176+
config.num_hidden_layers = config.text_config.num_hidden_layers
177+
if not hasattr(config, "hidden_size"):
178+
config.hidden_size = config.text_config.hidden_size
179+
168180
self.vision_tower = SiglipVisionModel(
169181
config=config.vision_config,
170182
quant_config=quant_config,
@@ -380,6 +392,10 @@ def forward(
380392

381393
return hs
382394

395+
def should_apply_lora(self, module_name: str) -> bool:
    """Return True when *module_name* is eligible for LoRA injection.

    Only language-model layers match ``self.lora_pattern``; vision tower and
    multi_modal_projector modules are skipped.
    """
    matched = self.lora_pattern.match(module_name)
    return matched is not None
398+
383399
def tie_weights(self):
384400
return self.language_model.tie_weights()
385401

python/sglang/srt/models/mllama4.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import math
44
import os
5+
import re
56
from collections.abc import Iterable
67
from typing import List, Optional, Set, Tuple
78

@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
422423
"gate_up_proj": ["gate_proj", "up_proj"],
423424
}
424425

426+
# Pattern to match language model layers only (skip vision_model and multi_modal_projector)
427+
lora_pattern = re.compile(
428+
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
429+
)
430+
425431
def __init__(
426432
self,
427433
config: Llama4Config,
@@ -555,6 +561,10 @@ def get_image_feature(
555561

556562
return projected_vision_flat
557563

564+
def should_apply_lora(self, module_name: str) -> bool:
    """Decide whether LoRA weights apply to the module named *module_name*.

    Vision model and multi_modal_projector modules never match
    ``self.lora_pattern`` and are therefore excluded.
    """
    if self.lora_pattern.match(module_name):
        return True
    return False
567+
558568
def forward(
559569
self,
560570
input_ids: torch.Tensor,

0 commit comments

Comments
 (0)