Commit ee4d204

[Fix] Fix calibrate bug when transformers>4.36 (#967)
* fix llama calibrate
* fix layer idx
1 parent 3914eac commit ee4d204
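
For context (background not spelled out in the commit itself, so treat the details as a sketch): with transformers newer than 4.36, LlamaDecoderLayer returns its key/value cache as a transformers.cache_utils.DynamicCache object rather than the older (key, value) tuple, which broke the calibration hook that unpacked the tuple directly. Below is a minimal illustration of reading the per-layer KV under either format; the helper name extract_layer_kv is hypothetical and not part of the patch.

import transformers
from mmengine import digit_version


def extract_layer_kv(layer_out):
    """Hypothetical helper: return (key, value) from a decoder layer's output,
    whether the cache is the legacy tuple or a DynamicCache."""
    out = list(layer_out)
    cache = out.pop(-1)  # with use_cache=True the KV cache is the last element
    new_api = digit_version(transformers.__version__) > digit_version('4.36.0')
    if new_api and hasattr(cache, 'key_cache'):
        # transformers > 4.36: DynamicCache keeps per-layer lists of tensors.
        key, value = cache.key_cache[-1], cache.value_cache[-1]
    else:
        # transformers <= 4.36: the last element is already a (key, value) tuple.
        key, value = cache
    return key, value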

File tree

1 file changed: +32 -6 lines changed

lmdeploy/lite/quantization/calibration.py

Lines changed: 32 additions & 6 deletions
@@ -3,6 +3,8 @@
 from typing import Union

 import torch
+import transformers
+from mmengine import digit_version
 from torch import nn
 from transformers import PreTrainedTokenizer

@@ -162,12 +164,36 @@ def _forward(mod, *args, **kwargs):

                 if k_obs and v_obs:
                     batch_kwargs[i]['use_cache'] = True
-                    out = self._ori_forwards[mod](*batch_args[i],
-                                                  **batch_kwargs[i])
-                    out = list(out)
-                    key, value = out.pop(-1)
-                    k_obs.observe(key)
-                    v_obs.observe(value)
+                    version = digit_version(transformers.__version__)
+                    use_new_cache = type(mod).__name__ == 'LlamaDecoderLayer'
+                    if version > digit_version('4.36.0') and use_new_cache:
+                        from transformers.cache_utils import DynamicCache
+                        batch_kwargs[i]['past_key_value'] = DynamicCache()
+
+                        ori_idx = mod.self_attn.layer_idx
+                        mod.self_attn.layer_idx = 0
+
+                        out = self._ori_forwards[mod](*batch_args[i],
+                                                      **batch_kwargs[i])
+                        mod.self_attn.layer_idx = ori_idx
+
+                        out = list(out)
+                        cache = out.pop(-1)
+
+                        key = cache.key_cache.pop(-1)
+                        value = cache.value_cache.pop(-1)
+
+                        k_obs.observe(key)
+                        v_obs.observe(value)
+
+                    else:
+                        out = self._ori_forwards[mod](*batch_args[i],
+                                                      **batch_kwargs[i])
+                        out = list(out)
+                        key, value = out.pop(-1)
+
+                        k_obs.observe(key)
+                        v_obs.observe(value)

                     del key, value
                     torch.cuda.empty_cache()
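
A note on the layer_idx juggling in the second hunk (my reading of the patch, not wording from the commit): a fresh DynamicCache is created for each calibrated decoder layer, so the cache is empty when that layer runs, and DynamicCache files entries under the layer's self_attn.layer_idx. Forcing layer_idx to 0 for the single forward pass makes the new entry land at slot 0 of the empty cache instead of referencing a slot that was never filled; the original index is restored immediately afterwards. A rough usage sketch with placeholder shapes:

import torch
from transformers.cache_utils import DynamicCache  # present in transformers >= 4.36

cache = DynamicCache()
# Placeholder KV shapes: (batch, num_heads, seq_len, head_dim).
key = torch.zeros(1, 8, 16, 64)
value = torch.zeros(1, 8, 16, 64)

# On an empty cache, an update at layer_idx=0 appends the tensors at list
# index 0, so key_cache[-1] / value_cache[-1] hold exactly this layer's KV.
cache.update(key, value, layer_idx=0)
assert len(cache.key_cache) == 1
assert cache.key_cache[-1].shape == key.shape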
