Commit 44532be

Investigate refactoring opportunities for batch management in Plugin and Compiler - review
1 parent 6210112 commit 44532be

5 files changed: +66 −5 lines changed

src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp

Lines changed: 3 additions & 1 deletion
@@ -84,7 +84,9 @@ std::optional<size_t> determine_dynamic_batch_size(const IODescriptor& desc,
     }

     auto batchFromModel = ioShape[intel_npu::utils::BATCH_AXIS];
-    if (!batchFromModel.is_dynamic()) {
+    auto batchModelFromIR =
+        desc.shapeFromIRModel.has_value() && desc.shapeFromIRModel.value()[intel_npu::utils::BATCH_AXIS].is_dynamic();
+    if (!batchFromModel.is_dynamic() && !batchModelFromIR) {
         return std::nullopt;
     }
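
In isolation, the new condition treats the batch as dynamic when either the shape reported by the compiler or the shape preserved from the IR model has a dynamic batch dimension. A minimal standalone sketch of that predicate follows; the helper name batch_is_dynamic is invented for illustration and BATCH_AXIS is assumed to be 0 here.

#include <openvino/core/partial_shape.hpp>
#include <optional>

// Hypothetical helper, not part of the commit: mirrors the combined check added
// to determine_dynamic_batch_size().
bool batch_is_dynamic(const ov::PartialShape& shapeFromCompiler,
                      const std::optional<ov::PartialShape>& shapeFromIRModel) {
    constexpr std::ptrdiff_t BATCH_AXIS = 0;  // assumption for this sketch
    const bool compilerBatchIsDynamic = shapeFromCompiler[BATCH_AXIS].is_dynamic();
    const bool irBatchIsDynamic =
        shapeFromIRModel.has_value() && (*shapeFromIRModel)[BATCH_AXIS].is_dynamic();
    // The early return in the patched function now fires only when neither view
    // of the model reports a dynamic batch dimension.
    return compilerBatchIsDynamic || irBatchIsDynamic;
}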

src/plugins/intel_npu/src/compiler_adapter/include/graph.hpp

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ class Graph : public IGraph {

protected:
    bool release_blob(const Config& config);
+    std::optional<size_t> determine_batch_size();

    std::shared_ptr<ZeGraphExtWrappers> _zeGraphExt;

src/plugins/intel_npu/src/compiler_adapter/src/graph.cpp

Lines changed: 56 additions & 0 deletions
@@ -232,6 +232,10 @@ void Graph::initialize(const Config& config) {
     // releasing it here to avoid unnecessary memory usage.
     _blobIsReleased = release_blob(config);

+    if (!_batchSize.has_value()) {
+        _batchSize = determine_batch_size();
+    }
+
     if (_zeroInitStruct->getCommandQueueDdiTable().version() < ZE_MAKE_VERSION(1, 1) &&
         config.get<RUN_INFERENCES_SEQUENTIALLY>()) {
         auto numberOfCommandLists = _batchSize.has_value() ? *_batchSize : 1;
@@ -288,6 +292,58 @@ uint32_t Graph::get_last_submitted_id() const {
     return _lastSubmittedId;
 }

+std::optional<size_t> Graph::determine_batch_size() {
+    if (!_metadata.outputs.at(0).shapeFromIRModel.has_value()) {
+        _logger.debug("Batching on the plugin is not used, batching is handled by the compiler");
+        return std::nullopt;
+    }
+
+    const ov::PartialShape& firstShape = *_metadata.outputs.at(0).shapeFromIRModel;
+    if (firstShape.is_dynamic() || firstShape.rank().get_length() == 0) {
+        return std::nullopt;
+    }
+
+    const size_t candidateBatchSize = firstShape[utils::BATCH_AXIS].get_max_length();
+    if (candidateBatchSize == 0 || candidateBatchSize == utils::DEFAULT_BATCH_SIZE) {
+        _logger.debug("Batching on the plugin is not used, batching is handled by the compiler");
+        return std::nullopt;
+    }
+
+    auto checkDescriptorsUseCandidateBatchSize = [candidateBatchSize](const std::vector<IODescriptor>& descriptors) {
+        for (const IODescriptor& descriptor : descriptors) {
+            OPENVINO_ASSERT(descriptor.shapeFromIRModel.has_value(),
+                            "Missing value for the \"shapeFromIRModel\" attribute, I/O descriptor");
+
+            const ov::PartialShape& shapeFromCompiler = descriptor.shapeFromCompiler;
+            const ov::PartialShape& shapeFromIRModel = *descriptor.shapeFromIRModel;
+
+            if (shapeFromCompiler.is_dynamic() || shapeFromCompiler.rank().get_length() == 0 ||
+                *shapeFromCompiler.begin() != utils::DEFAULT_BATCH_SIZE) {
+                return false;
+            }
+
+            if (!descriptor.isStateInput && !descriptor.isStateOutput && !descriptor.isShapeTensor) {
+                if (shapeFromIRModel.is_dynamic() || shapeFromIRModel.rank().get_length() == 0 ||
+                    *shapeFromIRModel.begin() != candidateBatchSize) {
+                    return false;
+                }
+            }
+        }
+
+        return true;
+    };
+
+    if (!checkDescriptorsUseCandidateBatchSize(_metadata.inputs) ||
+        !checkDescriptorsUseCandidateBatchSize(_metadata.outputs)) {
+        _logger.debug("Batching on the plugin is not used, batching is handled by the compiler");
+        return std::nullopt;
+    }
+
+    _logger.debug("Batching is handled by the plugin");
+
+    return candidateBatchSize;
+}
+
 const std::optional<std::size_t> Graph::get_batch_size() const {
     return _batchSize;
 }
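
For orientation, this is the shape pattern the new Graph::determine_batch_size() accepts, shown with invented example values; none of the concrete shapes below come from the commit.

// Invented example only: the original IR model carries the real batch, while the
// compiler was handed the debatched model.
const ov::PartialShape shapeFromIRModel{4, 3, 224, 224};   // batch dimension = 4
const ov::PartialShape shapeFromCompiler{1, 3, 224, 224};  // DEFAULT_BATCH_SIZE after debatching

// determine_batch_size() takes the candidate from the IR shape's batch axis ...
const size_t candidate = shapeFromIRModel[utils::BATCH_AXIS].get_max_length();  // == 4
// ... and returns it only if every non-state, non-shape-tensor descriptor shows the same
// pattern; any dynamic shape, zero rank, or candidate equal to DEFAULT_BATCH_SIZE yields
// std::nullopt, leaving batching to the compiler. In initialize(), a returned value then
// determines how many command lists are created when inferences run sequentially.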

src/plugins/intel_npu/src/plugin/src/plugin.cpp

Lines changed: 1 addition & 1 deletion
@@ -708,7 +708,7 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
     }

     std::optional<int64_t> batch = std::nullopt;
-    if (originalBatch.has_value()) {
+    if (originalBatch.has_value() && successfullyDebatched) {
         batch = originalBatch.value().is_static() ? originalBatch.value().get_length() : -1;
         if (batch > 0) {
             // Initial batch setup for static cases
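
Condensed, the guarded flow after this change looks as follows; only the names that appear in the diff are real, and the surrounding function is omitted.

std::optional<int64_t> batch = std::nullopt;
if (originalBatch.has_value() && successfullyDebatched) {
    // A positive value is the static batch size; -1 marks a dynamic original batch.
    batch = originalBatch.value().is_static() ? originalBatch.value().get_length() : -1;
}
// If plugin-side debatching failed, `batch` stays std::nullopt and the initial batch
// setup that follows in compile_model() is skipped, so batching remains with the compiler.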

src/plugins/intel_npu/src/plugin/src/transformations.cpp

Lines changed: 5 additions & 3 deletions
@@ -235,13 +235,15 @@ std::tuple<std::shared_ptr<ov::Model>, bool> handlePluginBatching(
             logger.info("The model has been debatched successfully");
             successfullyDebatched = true;
         }
+        if (batchModeIsAvailable) {
+            // If we have successfully debatched the model on the PLUGIN side, we should
+            // avoid repeating the same in the compiler by resetting the batch mode
+            updateBatchMode(ov::intel_npu::BatchMode::COMPILER);
+        }
     } catch (const std::exception& ex) {
         logger.info("Couldn't validate and reshape the model. Batching will be handled by compiler. Error: %s",
                     ex.what());
     }
-    if (batchModeIsAvailable) {
-        updateBatchMode(ov::intel_npu::BatchMode::COMPILER);
-    }
     return {reshapedModel, successfullyDebatched};
 }
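
The behavioural difference is easiest to see as a condensed control-flow sketch of handlePluginBatching(); this is a hypothetical, heavily abbreviated rendering, not the real function body.

try {
    // ... validate and reshape (debatch) the model; may throw ...
    successfullyDebatched = true;
    if (batchModeIsAvailable) {
        // New location: the compiler is told to skip its own batch handling only
        // when the plugin-side reshape completed without throwing.
        updateBatchMode(ov::intel_npu::BatchMode::COMPILER);
    }
} catch (const std::exception& ex) {
    // Previously the updateBatchMode() reset also ran after this catch block; now a
    // failed reshape leaves the requested batch mode untouched.
}
return {reshapedModel, successfullyDebatched};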

0 commit comments
