From ebe03a816613fa5b4786669eacbe386da3980663 Mon Sep 17 00:00:00 2001 From: Stefan Dobrev Date: Wed, 16 Nov 2022 17:52:33 +0200 Subject: [PATCH] Add support for explicit batch for all models --- c/tensorNet.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/c/tensorNet.cpp b/c/tensorNet.cpp index 8d6cc76b4..f3548b0da 100644 --- a/c/tensorNet.cpp +++ b/c/tensorNet.cpp @@ -621,7 +621,7 @@ bool tensorNet::ProfileModel(const std::string& deployFile, // name for caf //parser->destroy(); } #if NV_TENSORRT_MAJOR >= 5 - else if( mModelType == MODEL_ONNX ) + else if( inputDims[0].nbDims == 4 ) { #if NV_TENSORRT_MAJOR >= 7 network->destroy(); @@ -714,7 +714,7 @@ bool tensorNet::ProfileModel(const std::string& deployFile, // name for caf nvinfer1::Dims dims = network->getInput(i)->getDimensions(); #if NV_TENSORRT_MAJOR >= 7 - if( mModelType == MODEL_ONNX ) + if( dims.nbDims == 4 ) dims = shiftDims(dims); // change NCHW to CHW for EXPLICIT_BATCH #endif @@ -1100,15 +1100,15 @@ bool tensorNet::LoadNetwork( const char* prototxt_path_, const char* model_path_ } else if( model_fmt == MODEL_ENGINE ) { + mModelType = model_fmt; + mModelPath = model_path; + if( !LoadEngine(model_path.c_str(), input_blobs, output_blobs, NULL, device, stream) ) { LogError(LOG_TRT "failed to load %s\n", model_path.c_str()); return false; } - mModelType = model_fmt; - mModelPath = model_path; - LogSuccess(LOG_TRT "device %s, initialized %s\n", deviceTypeToStr(device), mModelPath.c_str()); return true; } @@ -1385,7 +1385,7 @@ bool tensorNet::LoadEngine( nvinfer1::ICudaEngine* engine, nvinfer1::Dims inputDims = validateDims(engine->getBindingDimensions(inputIndex)); #if NV_TENSORRT_MAJOR >= 7 - if( mModelType == MODEL_ONNX ) + if( inputDims.nbDims == 4 ) inputDims = shiftDims(inputDims); // change NCHW to CHW if EXPLICIT_BATCH set #endif #else @@ -1449,7 +1449,7 @@ bool tensorNet::LoadEngine( nvinfer1::ICudaEngine* engine, nvinfer1::Dims outputDims = validateDims(engine->getBindingDimensions(outputIndex)); #if NV_TENSORRT_MAJOR >= 7 - if( mModelType == MODEL_ONNX ) + if( outputDims.nbDims == 4 ) outputDims = shiftDims(outputDims); // change NCHW to CHW if EXPLICIT_BATCH set #endif #else