diff --git a/onnx_tensorrt/backend.py b/onnx_tensorrt/backend.py index 511dd42..3288cb4 100644 --- a/onnx_tensorrt/backend.py +++ b/onnx_tensorrt/backend.py @@ -34,13 +34,13 @@ def count_trailing_ones(vals): class TensorRTBackendRep(BackendRep): def __init__(self, model, device, - max_workspace_size=None, serialize_engine=False, verbose=False, **kwargs): + max_workspace_size=None, serialize_engine=False, verbose=False, explicit_batchsize=False,**kwargs): if not isinstance(device, Device): device = Device(device) self._set_device(device) self._logger = TRT_LOGGER self.builder = trt.Builder(self._logger) - self.network = self.builder.create_network(flags=0) + self.network = self.builder.create_network(flags=int(explicit_batchsize)) self.parser = trt.OnnxParser(self.network, self._logger) self.config = self.builder.create_builder_config() self.serialize_engine = serialize_engine