Skip to content

Commit d9377b8

Browse files
authored
Support generation search for transformers examples (#2029)
Signed-off-by: Kaihui-intel <[email protected]>
1 parent 61f1e39 commit d9377b8

File tree

5 files changed

+912
-1
lines changed

5 files changed

+912
-1
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ repos:
7676
)$
7777
7878
- repo: https://github.com/PyCQA/docformatter
79-
rev: v1.7.5
79+
rev: 06907d0
8080
hooks:
8181
- id: docformatter
8282
args: [

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import intel_extension_for_pytorch as ipex
99
from neural_compressor.transformers import AutoModelForCausalLM, AutoRoundConfig, RtnConfig, GPTQConfig
1010
from neural_compressor.transformers.quantization.utils import convert_dtype_str2torch
11+
from neural_compressor.transformers.generation import _greedy_search, _beam_search
1112
from transformers.utils import check_min_version
1213
import contextlib
1314

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2024 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
from .beam_search import _beam_search
19+
from .greedy_search import _greedy_search

0 commit comments

Comments
 (0)