@@ -125,7 +125,6 @@ def generate(
125125
126126 max_length = len (input_ids ) + max_new_tokens
127127
128- params .input_ids = input_ids
129128 if self .config and "search" in self .config :
130129 search_config = self .config ["search" ]
131130 params .set_search_options (
@@ -159,10 +158,10 @@ def generate(
159158 params .try_graph_capture_with_max_batch_size (1 )
160159
161160 generator = og .Generator (self .model , params )
161+ generator .append_tokens (input_ids )
162162
163163 if streamer is None :
164164 prompt_start_time = time .perf_counter ()
165- generator .compute_logits ()
166165 generator .generate_next_token ()
167166 prompt_end_time = time .perf_counter ()
168167
@@ -173,7 +172,6 @@ def generate(
173172 token_gen_times = []
174173 while not generator .is_done ():
175174 token_gen_start_time = time .perf_counter ()
176- generator .compute_logits ()
177175 generator .generate_next_token ()
178176 token_gen_end_time = time .perf_counter ()
179177
@@ -194,7 +192,6 @@ def generate(
194192 stop_early = False
195193
196194 while not generator .is_done () and not stop_early :
197- generator .compute_logits ()
198195 generator .generate_next_token ()
199196
200197 new_token = generator .get_next_tokens ()[0 ]
@@ -253,6 +250,13 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
253250 add_help = add_help ,
254251 )
255252
253+ parser .add_argument (
254+ "-ip" ,
255+ "--input_path" ,
256+ default = "" ,
257+ help = "the local huggingface model in your disk" ,
258+ )
259+
256260 parser .add_argument (
257261 "-d" ,
258262 "--device" ,
@@ -304,6 +308,7 @@ def run(
304308 self ,
305309 state : State ,
306310 input : str ,
311+ input_path : str = "" ,
307312 device : str = "igpu" ,
308313 dtype : str = "int4" ,
309314 int4_block_size : int = None ,
@@ -449,7 +454,7 @@ def run(
449454 try :
450455 model_builder .create_model (
451456 checkpoint , # model_name
452- "" , # input_path
457+ input_path , # input_path
453458 full_model_path , # output_path
454459 dtype , # precision
455460 execution_providers [device ], # execution_provider
0 commit comments