@@ -87,8 +87,8 @@ def prepare_pytorch_inputs(self):
 
         if self.full_batch_size:
             inputs["input_ids"] = input_ids
-            inputs["position_ids"] = torch.arange(input_len).view(1, input_len)
-            inputs["batch_index"] = torch.arange(1).view(-1, 1)
+            inputs["position_ids"] = position_ids
+            inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1)
 
         past_key_values = []
         sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]]
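Note: the prompt-phase hunk replaces the hard-coded single-row index with one row per decode slot. A minimal standalone sketch of the resulting `batch_index`, assuming `full_batch_size = 4` (toy value, not the repo's default):

    import torch

    # One index row per slot, so each sequence's KV cache lands in its own
    # batch position during decode.
    full_batch_size = 4  # assumed toy value
    batch_index = torch.arange(full_batch_size).view(-1, 1)
    # -> tensor([[0], [1], [2], [3]]), shape (full_batch_size, 1)
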
@@ -117,18 +117,15 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
         """
         updated_inputs = {}
         if self.full_batch_size:
-            batch_index = torch.arange(1).view(-1, 1)
-
             input_ids = pt_outputs.logits.detach().argmax(2)
             updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id)
-            updated_inputs["input_ids"][batch_index.view(-1)] = input_ids
+            updated_inputs["input_ids"][inputs["batch_index"].view(-1)] = input_ids
 
             position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1
             updated_inputs["position_ids"] = torch.full((self.full_batch_size, 1), 0)
-            updated_inputs["position_ids"][batch_index.view(-1)] = position_ids
-
-            updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1)
+            updated_inputs["position_ids"][inputs["batch_index"].view(-1)] = position_ids
 
+            updated_inputs["batch_index"] = inputs["batch_index"]
         else:
             updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
             updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1
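Note: with the caller-supplied `batch_index`, the decode update scatters each newly sampled token into a pad-filled batch row instead of always writing slot 0. A standalone sketch of that scatter, assuming slots 1 and 3 are active in a batch of 4 (toy values):

    import torch

    full_batch_size, pad_token_id = 4, 0      # assumed toy values
    batch_index = torch.tensor([[1], [3]])    # active decode slots
    new_tokens = torch.tensor([[42], [7]])    # one sampled token per active slot

    next_input_ids = torch.full((full_batch_size, 1), pad_token_id)
    next_input_ids[batch_index.view(-1)] = new_tokens
    # -> tensor([[0], [42], [0], [7]]); inactive slots stay padded
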
@@ -172,9 +169,15 @@ def prepare_ort_inputs(self):
                 inputs["past_key." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
                 inputs["past_value." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
         else:
+            sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]]
             for i in range(self.n_layer):
-                inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
-                inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
+                pad_shape = (
+                    sliding_padding_shape if self.config.layer_types[i] == "sliding_attention" else self.padding_shape
+                )
+                inputs["past_key." + str(i)] = np.zeros((pad_shape), dtype=np.float32)
+                inputs["past_value." + str(i)] = np.zeros((pad_shape), dtype=np.float32)
+        if self.full_batch_size:
+            inputs["batch_index"] = np.arange(self.full_batch_size).reshape(-1, 1)
         return inputs
 
     def update_ort_inputs(self, inputs, ort_outputs):
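Note: the ORT input prep now mirrors the PyTorch path, giving layers flagged as sliding-window attention a KV buffer capped at `config.sliding_window` along the sequence axis. A standalone sketch of the shape selection, with a toy `padding_shape` of [batch, heads, ctx_len, head_dim] and an HF-style `layer_types` list (assumed values):

    import numpy as np

    padding_shape = [1, 8, 2048, 64]  # assumed [batch, heads, ctx_len, head_dim]
    sliding_window = 512              # assumed toy value
    layer_types = ["sliding_attention", "full_attention"]

    sliding_shape = padding_shape[:2] + [sliding_window] + [padding_shape[-1]]
    for i, kind in enumerate(layer_types):
        shape = sliding_shape if kind == "sliding_attention" else padding_shape
        print(i, np.zeros(shape, dtype=np.float32).shape)
    # 0 (1, 8, 512, 64)  -- sliding layer caches at most sliding_window tokens
    # 1 (1, 8, 2048, 64) -- full-attention layer keeps the whole context
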
@@ -195,7 +198,8 @@ def update_ort_inputs(self, inputs, ort_outputs):
         for i in range(self.n_layer):
             updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2]
             updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1]
-
+        if self.full_batch_size:
+            updated_inputs["batch_index"] = inputs["batch_index"]
         return updated_inputs
 
     def update_ort_outputs(self, ort_outputs):
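Note: the `i * 2` / `i * 2 + 1` indexing above relies on the session returning KV caches as a flat interleaved list, [key_0, value_0, key_1, value_1, ...]. A standalone sketch of that unpacking (toy shapes, hypothetical data):

    import numpy as np

    n_layer = 2  # assumed toy value
    flat_kv = [np.zeros((1, 8, 16, 64), dtype=np.float32) for _ in range(2 * n_layer)]

    updated = {}
    for i in range(n_layer):
        updated["past_key." + str(i)] = flat_kv[i * 2]        # even slots: keys
        updated["past_value." + str(i)] = flat_kv[i * 2 + 1]  # odd slots: values
    print(sorted(updated))
    # ['past_key.0', 'past_key.1', 'past_value.0', 'past_value.1']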