Skip to content

Commit 33f5b19

Browse files
authored
fix lora name and rearrange wqkv for internlm2 (#2912)
* fix lora name and rearrange lora_b for wqkv * update for internvl * fix torchvision mismatch with torch
1 parent 0ffac7f commit 33f5b19

File tree

5 files changed

+43
-2
lines changed

5 files changed

+43
-2
lines changed

lmdeploy/pytorch/models/internlm2.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,32 @@ def prepare_inputs_for_generation(
397397
inputs_embeds=inputs_embeds,
398398
)
399399

400+
def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
                      adapter_id: int):
    """Load LoRA weights, reordering ``wqkv`` lora_B rows first.

    The packed ``wqkv`` projection stores its output rows per KV group as
    ``[q x group_size, k, v]``; any ``lora_B`` tensor targeting it inherits
    that interleaved layout.  The rows are regrouped into contiguous
    ``[all q | all k | all v]`` order before handing everything to the
    generic adapter loader.

    Args:
        weights: iterable of ``(param_name, tensor)`` pairs for the adapter.
        adapter_id: numeric id of the adapter being loaded.
    """
    from lmdeploy.pytorch.adapter.adapter import load_lora_weights

    cfg = self.config
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    # number of query heads sharing one kv head in the packed layout
    group_size = cfg.num_attention_heads // cfg.num_key_value_heads

    def _rearranged(items):
        """Yield (name, tensor), splitting wqkv lora_B rows into q/k/v order."""
        for name, tensor in items:
            if 'wqkv.lora_B' not in name:
                yield name, tensor
                continue
            # view rows as (num_kv_heads, group_size + 2, head_dim)
            grouped = tensor.unflatten(0, (-1, 2 + group_size, head_dim))
            q_part = grouped[:, :-2].flatten(0, 2)
            k_part = grouped[:, -2].flatten(0, 1)
            v_part = grouped[:, -1].flatten(0, 1)
            yield name, torch.cat([q_part, k_part, v_part], dim=0)

    load_lora_weights(self, _rearranged(weights), adapter_id)
425+
400426
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
401427
"""load weights."""
402428
# modify from vllm

lmdeploy/pytorch/models/internvl.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,17 @@ def prepare_inputs_for_generation(
516516
inputs_embeds=inputs_embeds,
517517
)
518518

519+
def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
                      adapter_id: int):
    """Load LoRA weights for the given adapter.

    Delegates to the wrapped language model's own ``load_lora_weights``
    when it defines one (so model-specific handling such as InternLM2's
    wqkv rearrangement is applied); otherwise falls back to the generic
    adapter loader.

    Args:
        weights: iterable of ``(param_name, tensor)`` pairs for the adapter.
        adapter_id: numeric id of the adapter being loaded.
    """
    if hasattr(self.language_model, 'load_lora_weights'):
        return self.language_model.load_lora_weights(weights, adapter_id)

    from lmdeploy.pytorch.adapter.adapter import load_lora_weights
    # BUGFIX: the generic loader takes the model as its first positional
    # argument (see the internlm2 call site:
    # ``load_lora_weights(self, weights_iter, adapter_id)``); the original
    # fallback dropped ``self`` and would raise a TypeError.
    return load_lora_weights(self, weights, adapter_id)
529+
519530
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
520531
"""load weights."""
521532

lmdeploy/pytorch/models/patch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,10 @@ def add_adapters(model: torch.nn.Module,
251251
ranks, scalings = get_ranks_and_scalings(target_name,
252252
adapter_cfgs,
253253
device=device)
254+
# split in case target_name has '.' like 'attention.wo'
255+
# which cannot be used as name of a module
256+
# and it's not aligned with key in model.packed_modules_mapping
257+
target_name = target_name.split('.')[-1]
254258
found_mods, pack_idx = find_all_target(model, target_name)
255259
sum_rank = ranks.sum().item()
256260

requirements/runtime_ascend.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ shortuuid
1818
tiktoken
1919
torch<=2.4.0,>=2.3.1
2020
torch-npu==2.3.1
21-
torchvision<=0.19.0,>=0.15.0
21+
torchvision<=0.19.0,>=0.18.1
2222
transformers
2323
uvicorn

requirements/runtime_cuda.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ sentencepiece
1616
shortuuid
1717
tiktoken
1818
torch<=2.5.1,>=2.0.0
19-
torchvision<=0.19.0,>=0.15.0
19+
torchvision<=0.20.1,>=0.15.0
2020
transformers
2121
triton==3.0.0; sys_platform == "linux"
2222
uvicorn

0 commit comments

Comments
 (0)