@@ -232,6 +232,10 @@ void Graph::initialize(const Config& config) {
     // releasing it here to avoid unnecessary memory usage.
     _blobIsReleased = release_blob(config);
 
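+    // Deduce the batch size from the network metadata unless it has already been set.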
+    if (!_batchSize.has_value()) {
+        _batchSize = determine_batch_size();
+    }
+
     if (_zeroInitStruct->getCommandQueueDdiTable().version() < ZE_MAKE_VERSION(1, 1) &&
         config.get<RUN_INFERENCES_SEQUENTIALLY>()) {
         auto numberOfCommandLists = _batchSize.has_value() ? *_batchSize : 1;
@@ -288,6 +292,58 @@ uint32_t Graph::get_last_submitted_id() const {
     return _lastSubmittedId;
 }
 
+std::optional<size_t> Graph::determine_batch_size() {
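+    // The batch size can only be deduced when the compiler forwards the original IR shapes.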
+    if (!_metadata.outputs.at(0).shapeFromIRModel.has_value()) {
+        _logger.debug("Batching on the plugin is not used, batching is handled by the compiler");
+        return std::nullopt;
+    }
+
+    const ov::PartialShape& firstShape = *_metadata.outputs.at(0).shapeFromIRModel;
+    if (firstShape.is_dynamic() || firstShape.rank().get_length() == 0) {
+        return std::nullopt;
+    }
+
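+    // The candidate batch size is read from the batch axis of the first output.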
+    const size_t candidateBatchSize = firstShape[utils::BATCH_AXIS].get_max_length();
+    if (candidateBatchSize == 0 || candidateBatchSize == utils::DEFAULT_BATCH_SIZE) {
+        _logger.debug("Batching on the plugin is not used, batching is handled by the compiler");
+        return std::nullopt;
+    }
+
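+    // The plugin handles batching only when every compiled shape has utils::DEFAULT_BATCH_SIZE as
+    // its first dimension while the IR shapes of all regular (non-state, non-shape-tensor) I/O
+    // carry the candidate batch size.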
+    auto checkDescriptorsUseCandidateBatchSize = [candidateBatchSize](const std::vector<IODescriptor>& descriptors) {
+        for (const IODescriptor& descriptor : descriptors) {
+            OPENVINO_ASSERT(descriptor.shapeFromIRModel.has_value(),
+                            "Missing value for the \"shapeFromIRModel\" attribute, I/O descriptor");
+
+            const ov::PartialShape& shapeFromCompiler = descriptor.shapeFromCompiler;
+            const ov::PartialShape& shapeFromIRModel = *descriptor.shapeFromIRModel;
+
+            if (shapeFromCompiler.is_dynamic() || shapeFromCompiler.rank().get_length() == 0 ||
+                *shapeFromCompiler.begin() != utils::DEFAULT_BATCH_SIZE) {
+                return false;
+            }
+
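+            // State inputs/outputs and shape tensors are not batched; only regular I/O must match.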
+            if (!descriptor.isStateInput && !descriptor.isStateOutput && !descriptor.isShapeTensor) {
+                if (shapeFromIRModel.is_dynamic() || shapeFromIRModel.rank().get_length() == 0 ||
+                    *shapeFromIRModel.begin() != candidateBatchSize) {
+                    return false;
+                }
+            }
+        }
+
+        return true;
+    };
+
+    if (!checkDescriptorsUseCandidateBatchSize(_metadata.inputs) ||
+        !checkDescriptorsUseCandidateBatchSize(_metadata.outputs)) {
+        _logger.debug("Batching on the plugin is not used, batching is handled by the compiler");
+        return std::nullopt;
+    }
+
+    _logger.debug("Batching is handled by the plugin");
+
+    return candidateBatchSize;
+}
+
 const std::optional<std::size_t> Graph::get_batch_size() const {
     return _batchSize;
 }