From f6c95d7106e8e6b270cee1f76f4794d7c4957fd6 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 6 Oct 2025 01:19:27 +0000 Subject: [PATCH 1/2] Unify way to identify tower module in lora cases --- python/sglang/srt/lora/lora_manager.py | 27 ++++++++++++++++++++---- python/sglang/srt/models/dots_vlm.py | 5 ++++- python/sglang/srt/models/gemma3_mm.py | 8 +++++-- python/sglang/srt/models/gemma3n_mm.py | 7 +++++-- python/sglang/srt/models/kimi_vl.py | 5 ++++- python/sglang/srt/models/llava.py | 17 +++++++++++---- python/sglang/srt/models/llavavid.py | 5 ++++- python/sglang/srt/models/mistral.py | 6 ++++++ python/sglang/srt/models/mllama4.py | 8 +++++-- python/sglang/srt/models/qwen2_audio.py | 8 +++++-- python/sglang/srt/models/utils.py | 19 +++++++++++++++++ python/sglang/srt/models/vila.py | 5 ++++- test/srt/models/test_llama4_towers.py | 22 +++++++++++++++++++ test/srt/models/test_mistral_models.py | 28 +++++++++++++++++++++++++ 14 files changed, 150 insertions(+), 20 deletions(-) create mode 100644 test/srt/models/test_llama4_towers.py create mode 100644 test/srt/models/test_mistral_models.py diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 2b90a8741bf..ed98394b00c 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -68,6 +68,20 @@ def __init__( self.tp_size: int = tp_size self.tp_rank: int = tp_rank + tower_names_fn = getattr(self.base_model, "get_mm_tower_names", None) + self.tower_module_prefixes: Set[str] = set() + if callable(tower_names_fn): + raw_prefixes = tower_names_fn() + if not isinstance(raw_prefixes, (list, tuple, set)): + raise TypeError( + "get_mm_tower_names() must return an iterable of strings." + ) + self.tower_module_prefixes = { + str(prefix).strip() + for prefix in raw_prefixes + if isinstance(prefix, str) and prefix.strip() + } + # LoRA backend for running sgemm kernels logger.info(f"Using {lora_backend} as backend of LoRA kernels.") backend_type = get_backend_from_name(lora_backend) @@ -418,9 +432,14 @@ def set_lora_module(self, module_name, module): replace_submodule(self.base_model, module_name, lora_module) return lora_module - def should_skip_lora_for_vision_model(self, module_name): - # TODO: support different vision models - return module_name.find("vision_model.model") != -1 + def should_skip_lora_for_tower(self, module_name: str) -> bool: + if any( + module_name == prefix or module_name.startswith(prefix + ".") + for prefix in self.tower_module_prefixes + ): + return True + # Maintain backward compatibility for historical vision model naming. + return "vision_model.model" in module_name def init_lora_modules(self): # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. @@ -440,7 +459,7 @@ def init_lora_modules(self): continue # Skip vision model - if self.should_skip_lora_for_vision_model(module_name): + if self.should_skip_lora_for_tower(module_name): continue # The module should be converted if it is included in target_names diff --git a/python/sglang/srt/models/dots_vlm.py b/python/sglang/srt/models/dots_vlm.py index 95475058f5e..d9b751fff16 100644 --- a/python/sglang/srt/models/dots_vlm.py +++ b/python/sglang/srt/models/dots_vlm.py @@ -33,13 +33,16 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM +from sglang.srt.models.utils import TowerAwareMixin from .dots_vlm_vit import DotsVisionTransformer -class DotsVLMForCausalLM(nn.Module): +class DotsVLMForCausalLM(TowerAwareMixin, nn.Module): """DotsVLM model for sglang inference""" + tower_names = ("vision_tower",) + def __init__( self, config: DotsVLMConfig, quant_config: Optional[QuantizationConfig] = None ) -> None: diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py index 8060fdee94b..1eefa7dc3e3 100644 --- a/python/sglang/srt/models/gemma3_mm.py +++ b/python/sglang/srt/models/gemma3_mm.py @@ -42,6 +42,7 @@ ) from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM from sglang.srt.models.siglip import SiglipVisionModel +from sglang.srt.models.utils import TowerAwareMixin from sglang.srt.utils import add_prefix from sglang.srt.utils.hf_transformers_utils import get_processor @@ -106,8 +107,10 @@ def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor: return projected_vision_outputs.type_as(vision_outputs) -class Gemma3ForConditionalGeneration(PreTrainedModel): +class Gemma3ForConditionalGeneration(TowerAwareMixin, PreTrainedModel): config_class = Gemma3Config + + tower_names = ("vision_tower",) """Gemma3 multimodal model for conditional generation.""" # BitandBytes specific attributes @@ -165,10 +168,11 @@ def __init__( self.config = config self.quant_config = quant_config + vision_tower_name = self.get_tower_name() self.vision_tower = SiglipVisionModel( config=config.vision_config, quant_config=quant_config, - prefix=add_prefix("vision_tower", prefix), + prefix=add_prefix(vision_tower_name, prefix), ) self.multi_modal_projector = Gemma3MultiModalProjector(config) diff --git a/python/sglang/srt/models/gemma3n_mm.py b/python/sglang/srt/models/gemma3n_mm.py index 3c52635dd9e..6e57e0062ec 100644 --- a/python/sglang/srt/models/gemma3n_mm.py +++ b/python/sglang/srt/models/gemma3n_mm.py @@ -36,6 +36,7 @@ ) from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel +from sglang.srt.models.utils import TowerAwareMixin from sglang.srt.utils import add_prefix from sglang.srt.utils.hf_transformers_utils import get_processor @@ -142,7 +143,7 @@ def forward( return self.embedding_post_projection_norm(emb_norm_proj) -class Gemma3nForConditionalGeneration(PreTrainedModel): +class Gemma3nForConditionalGeneration(TowerAwareMixin, PreTrainedModel): config_class = Gemma3nConfig """Gemma3n multimodal model for conditional generation.""" @@ -189,6 +190,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel): embedding_modules = {} embedding_padding_modules = [] supports_lora = True + tower_names = ("vision_tower", "audio_tower") def __init__( self, @@ -221,10 +223,11 @@ def __init__( prefix=add_prefix("embed_audio", prefix), ) + audio_tower_name = self.get_tower_name(1) self.audio_tower = Gemma3nAudioEncoder( config.audio_config, quant_config=quant_config, - prefix=add_prefix("audio_tower", prefix), + prefix=add_prefix(audio_tower_name, prefix), ) self.vocab_size = config.text_config.vocab_size diff --git a/python/sglang/srt/models/kimi_vl.py b/python/sglang/srt/models/kimi_vl.py index 68ed47b2ef0..1fd1994bd19 100644 --- a/python/sglang/srt/models/kimi_vl.py +++ b/python/sglang/srt/models/kimi_vl.py @@ -79,6 +79,7 @@ ) from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM from sglang.srt.models.kimi_vl_moonvit import MoonVitPretrainedModel +from sglang.srt.models.utils import TowerAwareMixin from sglang.srt.utils import add_prefix logger = logging.getLogger(__name__) @@ -118,7 +119,9 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -class KimiVLForConditionalGeneration(nn.Module): +class KimiVLForConditionalGeneration(TowerAwareMixin, nn.Module): + tower_names = ("vision_tower",) + def __init__( self, config: KimiVLConfig, diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 2fbbe559081..7d7cd15c85c 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -46,6 +46,7 @@ from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM +from sglang.srt.models.utils import TowerAwareMixin from sglang.srt.multimodal.mm_utils import ( get_anyres_image_grid_shape, unpad_image, @@ -54,7 +55,9 @@ from sglang.srt.utils import add_prefix, flatten_nested_list, logger -class LlavaBaseForCausalLM(nn.Module): +class LlavaBaseForCausalLM(TowerAwareMixin, nn.Module): + tower_names = ("vision_tower",) + def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): image_sizes = flatten_nested_list( [item.image_sizes for item in image_inputs.mm_items] @@ -475,16 +478,21 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): raise ValueError(f"Unexpected select feature: {self.select_feature}") # load mm_projector + vision_tower_name = self.get_tower_name() projector_weights = { "model.mm_projector.0": "multi_modal_projector.linear_1", "model.mm_projector.2": "multi_modal_projector.linear_2", - "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). "model.image_newline": "language_model.model.image_newline", } + projector_weights[f"model.{vision_tower_name}.{vision_tower_name}"] = vision_tower_name params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if "projector" in name or "vision_tower" in name or "image_newline" in name: + if ( + "projector" in name + or vision_tower_name in name + or "image_newline" in name + ): for weight_name, param_name in projector_weights.items(): if weight_name in name: name = name.replace(weight_name, param_name) @@ -733,10 +741,11 @@ def __init__( quant_config=quant_config, prefix=add_prefix("language_model", prefix), ) + vision_tower_name = self.get_tower_name() self.vision_tower = vision_model_cls( self.vision_config, quant_config=quant_config, - prefix=add_prefix("vision_tower", prefix), + prefix=add_prefix(vision_tower_name, prefix), ) if "unpad" in getattr(self.config, "mm_patch_merge_type", ""): diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index e5d6aa72ba9..850bd71c040 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -26,10 +26,13 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM +from sglang.srt.models.utils import TowerAwareMixin from sglang.srt.utils import add_prefix -class LlavaVidForCausalLM(nn.Module): +class LlavaVidForCausalLM(TowerAwareMixin, nn.Module): + tower_names = ("vision_tower",) + def __init__( self, config: LlavaConfig, diff --git a/python/sglang/srt/models/mistral.py b/python/sglang/srt/models/mistral.py index 632e857c280..2a71fdde660 100644 --- a/python/sglang/srt/models/mistral.py +++ b/python/sglang/srt/models/mistral.py @@ -89,5 +89,11 @@ def __hasattr__(self, name): def __call__(self, *args, **kwargs): return self.inner(*args, **kwargs) + def get_mm_tower_names(self): + get_names = getattr(self.inner, "get_mm_tower_names", None) + if callable(get_names): + return get_names() + return [] + EntryClass = [MistralForCausalLM, Mistral3ForConditionalGeneration] diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index 72077d96a27..f244015f44c 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -33,6 +33,7 @@ global_server_args_dict, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.utils import TowerAwareMixin from sglang.srt.utils import is_cpu _is_cpu = is_cpu() @@ -416,12 +417,14 @@ def forward( return hidden_state -class Llama4ForConditionalGeneration(nn.Module): +class Llama4ForConditionalGeneration(TowerAwareMixin, nn.Module): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], } + tower_names = ("vision_model",) + def __init__( self, config: Llama4Config, @@ -457,10 +460,11 @@ def __init__( vision_quant_config = None else: vision_quant_config = quant_config + vision_tower_name = self.get_tower_name() self.vision_model = Llama4VisionModel( config.vision_config, quant_config=vision_quant_config, - prefix=add_prefix("vision_model", prefix), + prefix=add_prefix(vision_tower_name, prefix), ) self.multi_modal_projector = Llama4MultiModalProjector(config) diff --git a/python/sglang/srt/models/qwen2_audio.py b/python/sglang/srt/models/qwen2_audio.py index 8609758a958..693b27248bc 100644 --- a/python/sglang/srt/models/qwen2_audio.py +++ b/python/sglang/srt/models/qwen2_audio.py @@ -59,13 +59,14 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2ForCausalLM +from sglang.srt.models.utils import TowerAwareMixin from sglang.srt.utils import add_prefix from sglang.srt.utils.hf_transformers_utils import get_processor logger = logging.getLogger(__name__) -class Qwen2AudioForConditionalGeneration(nn.Module): +class Qwen2AudioForConditionalGeneration(TowerAwareMixin, nn.Module): # BitandBytes specific attributes default_bitsandbytes_target_modules = [ ".gate_proj.", @@ -85,6 +86,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module): "up_proj": ("gate_up_proj", 1), } + tower_names = ("audio_tower",) + def __init__( self, config: Qwen2AudioConfig, @@ -151,6 +154,7 @@ def forward( return hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + audio_tower_name = self.get_tower_name() stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -173,7 +177,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name or "audio_tower" in name: + if weight_name not in name or audio_tower_name in name: continue name_tmp = name.replace(weight_name, param_name) diff --git a/python/sglang/srt/models/utils.py b/python/sglang/srt/models/utils.py index f4c2a0e3eee..b80d2faffd3 100644 --- a/python/sglang/srt/models/utils.py +++ b/python/sglang/srt/models/utils.py @@ -12,6 +12,8 @@ # limitations under the License. # ============================================================================== +from typing import List, Sequence + import torch from sglang.srt.layers.radix_attention import RadixAttention @@ -49,3 +51,20 @@ def create_fused_set_kv_buffer_arg( v_scale=layer.v_scale, cache_loc=forward_batch.out_cache_loc, ) + + +class TowerAwareMixin: + """Mixin for multimodal models to declare tower module prefixes.""" + + tower_names: Sequence[str] = () + + def get_mm_tower_names(self) -> List[str]: + return [name for name in self.tower_names if isinstance(name, str) and name] + + def get_tower_name(self, index: int = 0) -> str: + names = self.get_mm_tower_names() + if not names: + raise ValueError("No tower names defined for this model.") + if not (0 <= index < len(names)): + raise IndexError(f"Tower index {index} out of range for towers {names}.") + return names[index] diff --git a/python/sglang/srt/models/vila.py b/python/sglang/srt/models/vila.py index 2bb0b2d35d9..78619058fa2 100644 --- a/python/sglang/srt/models/vila.py +++ b/python/sglang/srt/models/vila.py @@ -24,6 +24,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.qwen2 import Qwen2ForCausalLM +from sglang.srt.models.utils import TowerAwareMixin logger = logging.getLogger(__name__) @@ -179,7 +180,7 @@ def forward(self, x: Tensor) -> Tensor: ##### END COPY modeling_vila.py ##### -class VILAForConditionalGeneration(nn.Module): +class VILAForConditionalGeneration(TowerAwareMixin, nn.Module): config: VILAConfig quant_config: Optional[QuantizationConfig] @@ -190,6 +191,8 @@ class VILAForConditionalGeneration(nn.Module): mm_projector: MultimodalProjector vision_tower: SiglipVisionModel + tower_names = ("vision_tower",) + def __init__( self, config: VILAConfig, diff --git a/test/srt/models/test_llama4_towers.py b/test/srt/models/test_llama4_towers.py new file mode 100644 index 00000000000..18167e0dffe --- /dev/null +++ b/test/srt/models/test_llama4_towers.py @@ -0,0 +1,22 @@ +import pytest + +from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration + + +def _build_stub() -> Llama4ForConditionalGeneration: + # Bypass heavy initialization while keeping class-level tower metadata accessible. + return object.__new__(Llama4ForConditionalGeneration) + + +def test_llama4_tower_names_exposed(): + model = _build_stub() + + assert model.get_mm_tower_names() == ["vision_model"] + assert model.get_tower_name() == "vision_model" + + +def test_llama4_get_tower_name_out_of_range(): + model = _build_stub() + + with pytest.raises(IndexError): + model.get_tower_name(1) diff --git a/test/srt/models/test_mistral_models.py b/test/srt/models/test_mistral_models.py new file mode 100644 index 00000000000..b1159fd5ba0 --- /dev/null +++ b/test/srt/models/test_mistral_models.py @@ -0,0 +1,28 @@ +import torch + +from sglang.srt.models.mistral import Mistral3ForConditionalGeneration +from sglang.srt.models.utils import TowerAwareMixin + + +class _DummyInnerModel(torch.nn.Module, TowerAwareMixin): + tower_names = ("vision_tower",) + + def __init__(self): + super().__init__() + self._dummy_param = torch.nn.Parameter(torch.zeros(1)) + + +def _build_wrapper() -> Mistral3ForConditionalGeneration: + wrapper = object.__new__(Mistral3ForConditionalGeneration) + wrapper.inner = _DummyInnerModel() + return wrapper + + +def test_mistral3_get_mm_tower_names_delegates_to_inner(): + wrapper = _build_wrapper() + + # The wrapper should expose tower names via the inner model. + assert wrapper.get_mm_tower_names() == ["vision_tower"] + + # Ensure attribute access falls through to the inner nn.Module instance. + assert list(wrapper.parameters())[0] is wrapper.inner._dummy_param From e8fb374b12e5842172506c2322909db73f94019b Mon Sep 17 00:00:00 2001 From: root Date: Mon, 6 Oct 2025 04:08:50 +0000 Subject: [PATCH 2/2] fix lint and suggestion --- python/sglang/srt/lora/lora_manager.py | 2 -- python/sglang/srt/models/llava.py | 4 +++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index ed98394b00c..1b4c241072d 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -438,8 +438,6 @@ def should_skip_lora_for_tower(self, module_name: str) -> bool: for prefix in self.tower_module_prefixes ): return True - # Maintain backward compatibility for historical vision model naming. - return "vision_model.model" in module_name def init_lora_modules(self): # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 7d7cd15c85c..d0e0106520d 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -485,7 +485,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). "model.image_newline": "language_model.model.image_newline", } - projector_weights[f"model.{vision_tower_name}.{vision_tower_name}"] = vision_tower_name + projector_weights[f"model.{vision_tower_name}.{vision_tower_name}"] = ( + vision_tower_name + ) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if (