
Commit 0ef6829

Peft Lora implementation (quic#85)
AutoPeftModelForCausalLM for loading LoRA models
Better export code that can be utilized for other auto classes
Hashing model cache location to avoid exporting again
Hashing model compile location to avoid compiling again

Signed-off-by: Ilango Rajagopal <[email protected]>
1 parent: 67922d7
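A condensed usage sketch of what this commit adds, lifted from examples/peft_models.py below; the caching comments restate the commit message rather than verified behaviour:

from QEfficient import QEffAutoPeftModelForCausalLM

# Load a LoRA adapter together with its base model: (adapter repo, adapter name)
m = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder", "magicoder")
m.export()                                   # per the commit message, re-runs hit the hashed export cache
m.compile(prefill_seq_len=32, ctx_len=1024)  # likewise, re-runs hit the hashed compile cache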

13 files changed: +1,112 -0 lines


QEfficient/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -9,6 +9,7 @@
 from QEfficient.compile.compile_helper import compile
 from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
 from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
+from QEfficient.peft import QEffAutoPeftModelForCausalLM
 from QEfficient.transformers.transform import transform

 # Users can use QEfficient.export for exporting models to ONNX
@@ -22,5 +23,6 @@
     "cloud_ai_100_exec_kv",
     "QEffAutoModel",
     "QEFFAutoModelForCausalLM",
+    "QEffAutoPeftModelForCausalLM",
     "QEFFCommonLoader",
 ]

QEfficient/peft/__init__.py

Lines changed: 14 additions & 0 deletions

# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM
from QEfficient.peft.peft_model import QEffPeftModelForCausalLM

__all__ = [
    "QEffAutoPeftModelForCausalLM",
    "QEffPeftModelForCausalLM",
]

QEfficient/peft/auto.py

Lines changed: 559 additions & 0 deletions
Large diffs are not rendered by default.
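Although auto.py itself is collapsed in this view, its adapter-management surface can be seen in examples/peft_models.py below: adapters are registered once and switched without re-exporting or re-compiling. A minimal sketch, continuing the `m` object from the sketch above:

# Register a second adapter on the same compiled model, then make it active.
m.load_adapter("predibase/gsm8k", "gsm8k")  # (HF adapter repo, local adapter name)
m.set_adapter("gsm8k")
m.generate(**tokenizer("James runs 3 sprints...", return_tensors="pt"), max_new_tokens=1024)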

QEfficient/peft/onnx_transforms.py

Lines changed: 56 additions & 0 deletions

# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from typing import Tuple

import onnx

from QEfficient.base.onnx_transforms import OnnxTransform


class AdapterWeightsToInputsTransform(OnnxTransform):
    @classmethod
    def apply(cls, model: onnx.ModelProto, *, adapter_name: str, **kwargs) -> Tuple[onnx.ModelProto, bool]:
        transformed = False
        removed_initializers = []

        # Find nodes with lora weights as inputs
        weight_suffix = f".{adapter_name}.weight"
        lora_weight_nodes = {
            inp: node for node in model.graph.node for inp in node.input if inp.endswith(weight_suffix)
        }

        for i, weight in enumerate(model.graph.initializer):
            if weight.name.endswith(weight_suffix):
                transformed = True

                # Create input/output for lora weights
                new_weight_name = weight.name[: -len(weight_suffix)] + ".weight"
                type_proto = onnx.helper.make_tensor_type_proto(weight.data_type, shape=list(weight.dims))
                inp = onnx.ValueInfoProto(name=new_weight_name, type=type_proto)
                out = onnx.ValueInfoProto(name=new_weight_name + "_RetainedState", type=type_proto)
                model.graph.input.append(inp)
                model.graph.output.append(out)

                # Create a node that connects input -> output
                node = onnx.helper.make_node("Identity", [inp.name], [out.name], new_weight_name + "_identity")
                model.graph.node.append(node)

                # Rename weight input
                lora_weight_node = lora_weight_nodes[weight.name]
                for j, inp in enumerate(lora_weight_node.input):
                    if inp == weight.name:
                        lora_weight_node.input[j] = new_weight_name

                # Remove weight initializers
                removed_initializers.append(i)

        if transformed:
            for i in sorted(removed_initializers, reverse=True):
                model.graph.initializer.pop(i)

        return model, transformed
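A toy self-check of the transform above (not part of the commit): a one-node graph whose MatMul weight is a LoRA initializer for an adapter named "default". After the transform, the initializer is gone, replaced by a runtime input and a matching "_RetainedState" output, which is what lets adapter weights be fed and swapped at run time instead of being baked into the exported graph.

import numpy as np
import onnx
from onnx import helper, numpy_helper

# One MatMul whose weight initializer follows the "<prefix>.<adapter>.weight" pattern.
w = numpy_helper.from_array(np.zeros((4, 4), dtype=np.float32), "proj.lora_A.default.weight")
x = helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1, 4])
y = helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1, 4])
matmul = helper.make_node("MatMul", ["x", "proj.lora_A.default.weight"], ["y"])
model = helper.make_model(helper.make_graph([matmul], "toy", [x], [y], [w]))

model, transformed = AdapterWeightsToInputsTransform.apply(model, adapter_name="default")
assert transformed
print([i.name for i in model.graph.input])   # ['x', 'proj.lora_A.weight']
print([o.name for o in model.graph.output])  # ['y', 'proj.lora_A.weight_RetainedState']
print(len(model.graph.initializer))          # 0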

QEfficient/peft/peft_model.py

Lines changed: 61 additions & 0 deletions

# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from peft import PeftModelForCausalLM, PeftType


class QEffPeftModelForCausalLM(PeftModelForCausalLM):
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        task_ids=None,
        **kwargs,
    ):
        peft_config = self.active_peft_config
        if not peft_config.is_prompt_learning:
            if self.base_model.config.model_type == "mpt":
                if inputs_embeds is not None:
                    raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds")
                return self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

            if peft_config.peft_type == PeftType.POLY:
                kwargs["task_ids"] = task_ids

            with self._enable_peft_forward_hooks(**kwargs):
                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                return self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

        raise NotImplementedError("Prompt learning methods are not supported from QEfficient")
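The apparent difference from peft's own PeftModelForCausalLM.forward is that position_ids and past_key_values are accepted as explicit arguments and threaded through to the base model, which KV-cache-style ONNX export needs as named inputs. A hedged sketch of the decode-step call shape (the model and state names are illustrative, not from this diff):

import torch

# One decode step: a single new token at an explicit position, reusing retained KV state.
out = qeff_peft_model(
    input_ids=torch.tensor([[42]]),    # next token id
    position_ids=torch.tensor([[7]]),  # explicit position for KV-cache decode
    past_key_values=past_key_values,   # retained state from the prefill step
)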
QEfficient/peft/pytorch_transforms.py

Lines changed: 15 additions & 0 deletions

# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from peft import PeftModelForCausalLM

from QEfficient.base.pytorch_transforms import ModuleMappingTransform
from QEfficient.peft.peft_model import QEffPeftModelForCausalLM


class PeftModelInputsTransform(ModuleMappingTransform):
    _module_mapping = {PeftModelForCausalLM: QEffPeftModelForCausalLM}
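A hedged sketch of applying this mapping, assuming ModuleMappingTransform exposes the same classmethod apply(model) -> (model, transformed) convention as the ONNX transforms above:

from peft import AutoPeftModelForCausalLM

peft_model = AutoPeftModelForCausalLM.from_pretrained("predibase/magicoder")
peft_model, transformed = PeftModelInputsTransform.apply(peft_model)
# If transformed, the top-level PeftModelForCausalLM now behaves as
# QEffPeftModelForCausalLM, whose forward() threads position_ids and
# past_key_values through to the base model.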

QEfficient/utils/cache.py

Lines changed: 41 additions & 0 deletions

# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

import json
import os
from pathlib import Path

QEFF_HOME: Path = None
if "QEFF_HOME" in os.environ:
    QEFF_HOME = Path(os.environ["QEFF_HOME"])
elif "XDG_CACHE_HOME" in os.environ:
    QEFF_HOME = Path(os.environ["XDG_CACHE_HOME"]) / "qeff_models"
else:
    QEFF_HOME = Path("~/.cache/qeff_models").expanduser()


def json_serializable(obj):
    if isinstance(obj, set):
        return sorted(obj)
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")


def to_hashable(obj) -> bytes:
    """
    Converts obj to bytes such that same object will result in same hash
    """
    return json.dumps(
        obj,
        skipkeys=False,
        ensure_ascii=True,
        check_circular=True,
        allow_nan=False,
        indent=None,
        separators=(",", ":"),
        default=json_serializable,
        sort_keys=True,
    ).encode()
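This is presumably what backs the "hashing model cache location" behaviour in the commit message; a sketch of how such a cache key could be derived (the sha256 choice and the key layout are illustrative assumptions, not taken from this diff):

import hashlib

export_args = {"model_name": "predibase/magicoder", "batch_size": 1, "seq_len": 32}
cache_dir = QEFF_HOME / hashlib.sha256(to_hashable(export_args)).hexdigest()
# Identical arguments always serialize to the same bytes (sort_keys=True,
# fixed separators), so a previous export/compile at cache_dir can be reused.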

QEfficient/utils/constants.py

Lines changed: 3 additions & 0 deletions

@@ -43,6 +43,9 @@ def get_models_dir():
 
 QEFF_MODELS_DIR = get_models_dir()
 
+ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1
+ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32
+
 
 class Constants:
     # Export Constants.

docs/source/hl_api.md

Lines changed: 7 additions & 0 deletions

@@ -9,6 +9,13 @@
    :undoc-members:
    :exclude-members: QEffAutoModel,QEFFTransformersBase, run_ort, run_pytorch, get_tokenizer, run_cloud_ai_100, execute
 ```
+
+## `QEffAutoPeftModelForCausalLM`
+```{eval-rst}
+.. autoclass:: QEfficient.peft.auto.QEffAutoPeftModelForCausalLM
+   :members:
+```
+
 ## `export`
 ```{eval-rst}
 .. automodule:: QEfficient.exporter.export_hf_to_cloud_ai_100

examples/peft_models.py

Lines changed: 68 additions & 0 deletions

# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from transformers import AutoTokenizer, TextStreamer

from QEfficient import QEffAutoPeftModelForCausalLM

base_model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
streamer = TextStreamer(tokenizer)

m = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder", "magicoder")
m.export()
m.compile(prefill_seq_len=32, ctx_len=1024)

# Magicoder adapter
m.set_adapter("magicoder")
inputs = tokenizer("def fibonacci", return_tensors="pt")
m.generate(**inputs, streamer=streamer, max_new_tokens=1024)

# TLDR, summary generator
m.load_adapter("predibase/tldr_headline_gen", "tldr_headline_gen")
m.set_adapter("tldr_headline_gen")
inputs = tokenizer(
    """Summarize this passage in one sentence or less: Jeffrey Berns, CEO of Blockchains LLC, wants the Nevada government to allow companies like \
his to form local governments on land they own, granting them power over everything from \
schools to law enforcement. Berns envisions a city based on digital currencies and \
blockchain storage. His company is proposing to build a 15,000 home town 12 miles east of \
Reno. Nevada Lawmakers have responded with intrigue and skepticism. The proposed \
legislation has yet to be formally filed or discussed in public hearings.

Summary: """,
    return_tensors="pt",
)
m.generate(**inputs, streamer=streamer, max_new_tokens=1024)

# Math problems
m.load_adapter("predibase/gsm8k", "gsm8k")
m.set_adapter("gsm8k")
inputs = tokenizer(
    "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. \
How many total meters does he run a week?",
    return_tensors="pt",
)
m.generate(**inputs, streamer=streamer, max_new_tokens=1024)

# News explanation
m.load_adapter("predibase/agnews_explained", "agnews_explained")
m.set_adapter("agnews_explained")
inputs = tokenizer(
    """Below is a news article. Please classify it under one of the following \
classes (World, Business, Sports, Sci/Tech) and provide a reasonable coherent explanation for \
why the article is classified as such. Please format your response as a JSON payload.

### Article: US poverty rate climbs, along with number lacking health coverage (AFP) AFP - The \
number of Americans living in poverty or without health insurance grew last year, a government \
survey showed, adding potential dynamite in the battle for the White House.

### JSON Response

""",
    return_tensors="pt",
)
m.generate(**inputs, streamer=streamer, max_new_tokens=1024)
