
Allow passing in custom past_key_values #497

Description

@Nisimachluf

I'm trying to use KV caching with the unsloth Phi-3 model from the HF hub (unsloth/Phi-3-mini-4k-instruct). However, it seems that the FastLanguageModel class doesn't support KV caching.
Here is a toy example that asks the model a question and then follows its reply with a second question.

from unsloth import FastLanguageModel

max_seq_length = 4096  # Can be set arbitrarily, automatically supports RoPE scaling!
dtype = None  # Automatically detect if None. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = False  # Reduce memory usage using 4-bit quantization. Can be set to False.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/media/local/models/phi3_unsloth",  # Use "unsloth/mistral-7b" for 16-bit loading
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    attn_implementation="flash_attention_2",  # load the model with flash attention support

)

prompt = """<|user|>
My name is Jon. What is my name?<|end|>
<|assistant|>"""

model_inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
generated_output = model.generate(**model_inputs, max_new_tokens=500, return_dict_in_generate=True, temperature=0)
text_output = tokenizer.batch_decode(generated_output.sequences)[0]
print(text_output)

second_prompt = """
<|user|>
I'm 30 years old. How old am I?<|end|>
<|assistant|>"""

full_prompt = text_output + second_prompt
model_inputs = tokenizer(full_prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
generated_output = model.generate(**model_inputs, max_new_tokens=500, return_dict_in_generate=True, past_key_values=generated_output.past_key_values)
text_output = tokenizer.batch_decode(generated_output.sequences)[0]
print(text_output)

The second call to model.generate() fails with

Traceback (most recent call last):
  File "phi3_unsloth_toy.py", line 31, in <module>
    generated_output = model.generate(**model_inputs, max_new_tokens=500, return_dict_in_generate=True, past_key_values=generated_output.past_key_values)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1736, in generate
    result = self._sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2375, in _sample
    outputs = self(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/unsloth/models/mistral.py", line 205, in MistralForCausalLM_fast_forward
    outputs = LlamaModel_fast_forward_inference(
  File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 748, in LlamaModel_fast_forward_inference
    hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
  File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 154, in LlamaAttention_fast_forward_inference
    Qn = Qn.view(bsz, 1, n_heads,    head_dim).transpose(1, 2)
RuntimeError: shape '[1, 1, 32, 96]' is invalid for input of size 61440

It works fine if past_key_values is not passed. For what it's worth, 61440 = 20 × 32 × 96, so it looks like LlamaAttention_fast_forward_inference assumes exactly one new query token per step, while the second generate() call hands it the ~20 new tokens of the second prompt together with the cache.
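
For comparison, here is roughly the pattern I'd like to use, written against plain transformers. This is just a sketch: microsoft/Phi-3-mini-4k-instruct is a stand-in checkpoint, and it assumes a transformers version whose generate() returns past_key_values in its output when return_dict_in_generate=True (the version in the traceback above already does).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "microsoft/Phi-3-mini-4k-instruct"  # stand-in for the unsloth checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")

prompt = """<|user|>
My name is Jon. What is my name?<|end|>
<|assistant|>"""
model_inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
output = model.generate(**model_inputs, max_new_tokens=100, return_dict_in_generate=True)
text_output = tokenizer.batch_decode(output.sequences)[0]

# Feed the full conversation back in together with the cached keys/values;
# generate() then only has to process the tokens that are not yet in the cache.
second_prompt = """
<|user|>
I'm 30 years old. How old am I?<|end|>
<|assistant|>"""
model_inputs = tokenizer(text_output + second_prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
output = model.generate(
    **model_inputs,
    max_new_tokens=100,
    return_dict_in_generate=True,
    past_key_values=output.past_key_values,
)
print(tokenizer.batch_decode(output.sequences)[0])

It would be great if FastLanguageModel supported the same call, or at least fell back to the regular forward when it receives more than one new token together with a cache.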

Metadata

Assignees

No one assigned

Labels

    feature request (Feature request pending on roadmap)
    help wanted (Help from the OSS community wanted!)
