|
41 | 41 | resize_image, |
42 | 42 | resize_image_letterbox, |
43 | 43 | resize_image_with_aspect, |
| 44 | + setup_python_preprocessing_pipeline, |
44 | 45 | ) |
45 | 46 |
|
46 | 47 |
|
@@ -143,6 +144,8 @@ def __init__( |
143 | 144 | ) |
144 | 145 | self.is_onnx_file = False |
145 | 146 | self.onnx_metadata = {} |
| 147 | + self.preprocessor = lambda arg: arg |
| 148 | + self.use_python_preprocessing = False |
146 | 149 |
|
147 | 150 | if isinstance(self.model_path, (str, Path)): |
148 | 151 | if Path(self.model_path).suffix == ".onnx" and weights_path: |
@@ -175,7 +178,52 @@ def __init__( |
175 | 178 | msg = "Model must be bytes or a file" |
176 | 179 | raise RuntimeError(msg) |
177 | 180 |
|
| 181 | + def reshape_dynamic_inputs(self) -> None: |
| 182 | + """For NPU devices, set static shape if the model has dynamic shapes""" |
| 183 | + for input in self.model.inputs: |
| 184 | + if input.partial_shape.is_dynamic: |
| 185 | + input_name = input.get_any_name() |
| 186 | + shape = get_input_shape(input) |
| 187 | + static_shape = [] |
| 188 | + |
| 189 | + # Detect likely layout for 4D shapes |
| 190 | + is_nchw = False |
| 191 | + if len(shape) == 4 and not isinstance(shape[1], tuple) and shape[1] != -1 and shape[1] <= 4: |
| 192 | + is_nchw = True |
| 193 | + |
| 194 | + for i, dim in enumerate(shape): |
| 195 | + if isinstance(dim, tuple): |
| 196 | + static_shape.append((dim[0] + dim[1]) // 2) |
| 197 | + elif dim == -1: |
| 198 | + if i == 0: |
| 199 | + static_shape.append(1) |
| 200 | + elif len(shape) == 4: |
| 201 | + if is_nchw: |
| 202 | + if i == 1: |
| 203 | + static_shape.append(3) |
| 204 | + else: |
| 205 | + static_shape.append(224) |
| 206 | + else: |
| 207 | + if i == 3: |
| 208 | + static_shape.append(3) |
| 209 | + else: |
| 210 | + static_shape.append(224) |
| 211 | + else: |
| 212 | + static_shape.append(1) |
| 213 | + else: |
| 214 | + static_shape.append(dim) |
| 215 | + |
| 216 | + log.info( |
| 217 | + f"NPU: Reshaping input '{input_name}' from dynamic {shape} to static {static_shape}", |
| 218 | + ) |
| 219 | + self.reshape_model({input_name: static_shape}) |
| 220 | + |
178 | 221 | def load_model(self) -> None: |
| 222 | + """Loads the model to the device specified in the constructor""" |
| 223 | + devices = parse_devices(self.device) |
| 224 | + if any("NPU" in dev.upper() for dev in devices) and self.model.is_dynamic(): |
| 225 | + self.reshape_dynamic_inputs() |
| 226 | + |
179 | 227 | self.compiled_model = self.core.compile_model( |
180 | 228 | self.model, |
181 | 229 | self.device, |
@@ -280,11 +328,17 @@ def copy_raw_result(self, request): |
280 | 328 | return {key: request.get_tensor(key).data.copy() for key in self.get_output_layers()} |
281 | 329 |
|
282 | 330 | def infer_sync(self, dict_data: dict[str, ndarray]) -> dict[str, ndarray]: |
| 331 | + if self.use_python_preprocessing: |
| 332 | + for key in dict_data: |
| 333 | + dict_data[key] = self.preprocessor(dict_data[key]) |
283 | 334 | self.infer_request = self.async_queue[self.async_queue.get_idle_request_id()] |
284 | 335 | self.infer_request.infer(dict_data) |
285 | 336 | return self.get_raw_result(self.infer_request) |
286 | 337 |
|
287 | 338 | def infer_async(self, dict_data, callback_data) -> None: |
| 339 | + if self.use_python_preprocessing: |
| 340 | + for key in dict_data: |
| 341 | + dict_data[key] = self.preprocessor(dict_data[key]) |
288 | 342 | self.async_queue.start_async(dict_data, callback_data) |
289 | 343 |
|
290 | 344 | def set_callback(self, callback_fn: Callable): |
@@ -347,8 +401,32 @@ def embed_preprocessing( |
347 | 401 | input_idx: int = 0, |
348 | 402 | ) -> None: |
349 | 403 | """ |
350 | | - Embeds OpenVINO PrePostProcessor module into the model. |
| 404 | + Embeds preprocessing into the model, or sets up Python preprocessing for NPU devices. |
351 | 405 | """ |
| 406 | + # Check if we should use Python preprocessing for NPU devices |
| 407 | + devices = parse_devices(self.device) |
| 408 | + if any("NPU" in dev.upper() for dev in devices): |
| 409 | + self.preprocessor = setup_python_preprocessing_pipeline( |
| 410 | + layout=layout, |
| 411 | + resize_mode=resize_mode, |
| 412 | + interpolation_mode=interpolation_mode, |
| 413 | + target_shape=target_shape, |
| 414 | + pad_value=pad_value, |
| 415 | + dtype=dtype, |
| 416 | + brg2rgb=brg2rgb, |
| 417 | + mean=mean, |
| 418 | + scale=scale, |
| 419 | + input_idx=input_idx, |
| 420 | + ) |
| 421 | + self.use_python_preprocessing = True |
| 422 | + input_name = self.model.inputs[input_idx].get_any_name() |
| 423 | + if layout == "NCHW": |
| 424 | + static_shape = [1, 3, target_shape[1], target_shape[0]] |
| 425 | + else: |
| 426 | + static_shape = [1, target_shape[1], target_shape[0], 3] |
| 427 | + self.reshape_model({input_name: static_shape}) |
| 428 | + return |
| 429 | + |
352 | 430 | ppp = PrePostProcessor(self.model) |
353 | 431 |
|
354 | 432 | # Change the input type to the 8-bit image |
|
0 commit comments