Skip to content

Commit 3d11a35

Browse files
Investigate refactoring opportunities for batch management in Plugin and Compiler - review
1 parent 60d5898 commit 3d11a35

File tree

18 files changed

+129
-117
lines changed

18 files changed

+129
-117
lines changed

src/plugins/intel_npu/src/backend/include/zero_infer_request.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ class ZeroInferRequest final : public SyncInferRequest {
2323
explicit ZeroInferRequest(const std::shared_ptr<ZeroInitStructsHolder>& initStructs,
2424
const std::shared_ptr<const ICompiledModel>& compiledModel,
2525
const Config& config);
26+
std::optional<size_t> determine_dynamic_batch_size(const IODescriptor& desc,
27+
const size_t index,
28+
const bool isInput,
29+
const std::shared_ptr<ov::ITensor>& tensor,
30+
const std::optional<size_t> batchSize);
2631

2732
ov::SoPtr<ov::ITensor> get_tensor(const ov::Output<const ov::Node>& port) const override;
2833
void set_tensor(const ov::Output<const ov::Node>& port, const ov::SoPtr<ov::ITensor>& tensor) override;

src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,21 @@ void check_level_zero_attributes_match(const IODescriptor& ioDescriptor, const A
7373
}
7474
}
7575

76-
std::optional<size_t> determine_dynamic_batch_size(const IODescriptor& desc,
77-
const std::shared_ptr<ov::ITensor>& tensor,
78-
const std::optional<size_t> batchSize) {
76+
} // namespace
77+
78+
std::optional<size_t> ZeroInferRequest::determine_dynamic_batch_size(const IODescriptor& desc,
79+
const size_t index,
80+
const bool isInput,
81+
const std::shared_ptr<ov::ITensor>& tensor,
82+
const std::optional<size_t> batchSize) {
7983
if (tensor == nullptr && !batchSize.has_value()) {
8084
return std::nullopt;
8185
}
8286

83-
auto dynamicBatchFromIR = desc.shapeFromIRModel.has_value() && (*desc.shapeFromIRModel).size() &&
84-
(*desc.shapeFromIRModel)[intel_npu::utils::BATCH_AXIS].is_dynamic();
85-
if (!dynamicBatchFromIR) {
87+
auto batchFromModel = isInput ? _compiledModel->inputs()[index].get_partial_shape()[intel_npu::utils::BATCH_AXIS]
88+
: _compiledModel->outputs()[index].get_partial_shape()[intel_npu::utils::BATCH_AXIS];
89+
90+
if (!batchFromModel.is_dynamic()) {
8691
return std::nullopt;
8792
}
8893

@@ -97,8 +102,6 @@ std::optional<size_t> determine_dynamic_batch_size(const IODescriptor& desc,
97102
return tensor->get_shape()[intel_npu::utils::BATCH_AXIS];
98103
}
99104

100-
} // namespace
101-
102105
//------------------------------------------------------------------------------
103106
ZeroInferRequest::ZeroInferRequest(const std::shared_ptr<ZeroInitStructsHolder>& initStructs,
104107
const std::shared_ptr<const ICompiledModel>& compiledModel,
@@ -310,8 +313,11 @@ void ZeroInferRequest::set_tensor(const ov::Output<const ov::Node>& port, const
310313
return;
311314
}
312315

313-
auto batchSizeCandidate =
314-
determine_dynamic_batch_size(_metadata.inputs.at(foundPort.idx), tensor._ptr, std::nullopt);
316+
auto batchSizeCandidate = determine_dynamic_batch_size(_metadata.inputs.at(foundPort.idx),
317+
foundPort.idx,
318+
true,
319+
tensor._ptr,
320+
std::nullopt);
315321

316322
if (batchSizeCandidate.has_value()) {
317323
if (!_dynamicBatchValueChanged) {
@@ -351,8 +357,11 @@ void ZeroInferRequest::set_tensor(const ov::Output<const ov::Node>& port, const
351357
return;
352358
}
353359

354-
auto batchSizeCandidate =
355-
determine_dynamic_batch_size(_metadata.outputs.at(foundPort.idx), tensor._ptr, std::nullopt);
360+
auto batchSizeCandidate = determine_dynamic_batch_size(_metadata.outputs.at(foundPort.idx),
361+
foundPort.idx,
362+
false,
363+
tensor._ptr,
364+
std::nullopt);
356365

357366
if (batchSizeCandidate.has_value()) {
358367
if (!_dynamicBatchValueChanged) {
@@ -439,7 +448,8 @@ void ZeroInferRequest::set_tensors(const ov::Output<const ov::Node>& port,
439448

440449
_logger.debug("ZeroInferRequest::set_tensors: %zu", tensors.size());
441450

442-
auto batchSizeCandidate = determine_dynamic_batch_size(_metadata.inputs.at(foundPort.idx), nullptr, tensors.size());
451+
auto batchSizeCandidate =
452+
determine_dynamic_batch_size(_metadata.inputs.at(foundPort.idx), foundPort.idx, true, nullptr, tensors.size());
443453

444454
// Check if batch has been changed
445455
if (batchSizeCandidate.has_value()) {

src/plugins/intel_npu/src/common/include/intel_npu/common/icompiler_adapter.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ class ICompilerAdapter {
4242
* @return A wrapper over the corresponding L0 graph handles (multiple only if "initBlobs" has been provided). This
4343
* wrapper further details the compiled model and brings it into a state closer to execution.
4444
*/
45-
virtual std::shared_ptr<IGraph> parse(ov::Tensor mainBlob,
46-
const Config& config,
47-
std::optional<std::vector<ov::Tensor>> initBlobs = std::nullopt,
48-
const std::optional<std::shared_ptr<const ov::Model>>& model = std::nullopt,
49-
std::optional<int64_t> batchSize = std::nullopt) const = 0;
45+
virtual std::shared_ptr<IGraph> parse(
46+
ov::Tensor mainBlob,
47+
const Config& config,
48+
std::optional<std::vector<ov::Tensor>> initBlobs = std::nullopt,
49+
const std::optional<std::shared_ptr<const ov::Model>>& model = std::nullopt) const = 0;
5050

5151
virtual ov::SupportedOpsMap query(const std::shared_ptr<const ov::Model>& model, const Config& config) const = 0;
5252
virtual uint32_t get_version() const = 0;

src/plugins/intel_npu/src/common/include/intel_npu/common/igraph.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class IGraph : public std::enable_shared_from_this<IGraph> {
3636

3737
virtual void set_argument_value(uint32_t argi, const void* argv) const = 0;
3838

39-
virtual void set_metadata(NetworkMetadata metadata) = 0;
40-
4139
virtual void initialize(const Config& config) = 0;
4240

4341
virtual ~IGraph() = default;

src/plugins/intel_npu/src/compiler_adapter/include/driver_compiler_adapter.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ class DriverCompilerAdapter final : public ICompilerAdapter {
2222

2323
std::shared_ptr<IGraph> compileWS(const std::shared_ptr<ov::Model>& model, const Config& config) const override;
2424

25-
std::shared_ptr<IGraph> parse(ov::Tensor mainBlob,
26-
const Config& config,
27-
std::optional<std::vector<ov::Tensor>> initBlobs = std::nullopt,
28-
const std::optional<std::shared_ptr<const ov::Model>>& model = std::nullopt,
29-
std::optional<int64_t> batchSize = std::nullopt) const override;
25+
std::shared_ptr<IGraph> parse(
26+
ov::Tensor mainBlob,
27+
const Config& config,
28+
std::optional<std::vector<ov::Tensor>> initBlobs = std::nullopt,
29+
const std::optional<std::shared_ptr<const ov::Model>>& model = std::nullopt) const override;
3030

3131
ov::SupportedOpsMap query(const std::shared_ptr<const ov::Model>& model, const Config& config) const override;
3232

src/plugins/intel_npu/src/compiler_adapter/include/graph.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ class Graph : public IGraph {
3535

3636
void set_argument_value(uint32_t argi, const void* argv) const override;
3737

38-
void set_metadata(NetworkMetadata metadata) override;
39-
4038
void initialize(const Config& config) override;
4139

4240
const NetworkMetadata& get_metadata() const override;

src/plugins/intel_npu/src/compiler_adapter/include/plugin_compiler_adapter.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ class PluginCompilerAdapter final : public ICompilerAdapter {
2323

2424
std::shared_ptr<IGraph> compileWS(const std::shared_ptr<ov::Model>& model, const Config& config) const override;
2525

26-
std::shared_ptr<IGraph> parse(ov::Tensor mainBlob,
27-
const Config& config,
28-
std::optional<std::vector<ov::Tensor>> initBlobs = std::nullopt,
29-
const std::optional<std::shared_ptr<const ov::Model>>& model = std::nullopt,
30-
std::optional<int64_t> batchSize = std::nullopt) const override;
26+
std::shared_ptr<IGraph> parse(
27+
ov::Tensor mainBlob,
28+
const Config& config,
29+
std::optional<std::vector<ov::Tensor>> initBlobs = std::nullopt,
30+
const std::optional<std::shared_ptr<const ov::Model>>& model = std::nullopt) const override;
3131

3232
ov::SupportedOpsMap query(const std::shared_ptr<const ov::Model>& model, const Config& config) const override;
3333

src/plugins/intel_npu/src/compiler_adapter/include/ze_graph_ext_wrappers.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ class ZeGraphExtWrappers {
3939

4040
GraphDescriptor getGraphDescriptor(void* data, size_t size) const;
4141

42-
NetworkMetadata getNetworkMeta(GraphDescriptor& graphDescriptor,
43-
std::optional<int64_t> batchSize = std::nullopt) const;
42+
NetworkMetadata getNetworkMeta(GraphDescriptor& graphDescriptor) const;
4443

4544
void destroyGraph(GraphDescriptor& graphDescriptor);
4645

@@ -62,8 +61,7 @@ class ZeGraphExtWrappers {
6261
void getMetadata(ze_graph_handle_t graphHandle,
6362
uint32_t index,
6463
std::vector<IODescriptor>& inputs,
65-
std::vector<IODescriptor>& outputs,
66-
std::optional<int64_t> batchSize) const;
64+
std::vector<IODescriptor>& outputs) const;
6765

6866
void initializeGraphThroughCommandList(ze_graph_handle_t graphHandle, uint32_t commandQueueGroupOrdinal) const;
6967

src/plugins/intel_npu/src/compiler_adapter/src/driver_compiler_adapter.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,19 +231,19 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compileWS(const std::shared_ptr<o
231231
config);
232232
}
233233

234-
std::shared_ptr<IGraph> DriverCompilerAdapter::parse(ov::Tensor mainBlob,
235-
const Config& config,
236-
std::optional<std::vector<ov::Tensor>> initBlobs,
237-
const std::optional<std::shared_ptr<const ov::Model>>& model,
238-
std::optional<int64_t> batchSize) const {
234+
std::shared_ptr<IGraph> DriverCompilerAdapter::parse(
235+
ov::Tensor mainBlob,
236+
const Config& config,
237+
std::optional<std::vector<ov::Tensor>> initBlobs,
238+
const std::optional<std::shared_ptr<const ov::Model>>& model) const {
239239
OV_ITT_TASK_CHAIN(PARSE_BLOB, itt::domains::NPUPlugin, "DriverCompilerAdapter", "parse");
240240

241241
_logger.debug("parse start");
242242
auto mainGraphDesc = _zeGraphExt->getGraphDescriptor(mainBlob.data(), mainBlob.get_byte_size());
243243
_logger.debug("parse end");
244244

245245
OV_ITT_TASK_NEXT(PARSE_BLOB, "getNetworkMeta");
246-
auto networkMeta = _zeGraphExt->getNetworkMeta(mainGraphDesc, batchSize);
246+
auto networkMeta = _zeGraphExt->getNetworkMeta(mainGraphDesc);
247247

248248
// exporting the blob when we get it from cache or ov::hint::compiled_blob property
249249
// shall be available

src/plugins/intel_npu/src/compiler_adapter/src/graph.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ Graph::Graph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
4242
}
4343
}
4444

45-
void Graph::set_metadata(NetworkMetadata metadata) {
46-
_metadata = metadata;
47-
}
48-
4945
const NetworkMetadata& Graph::get_metadata() const {
5046
return _metadata;
5147
}

0 commit comments

Comments
 (0)