Skip to content

Commit 282de22

Browse files
committed
API change: Add option to choose ouput names when exporting using ONNX
1 parent bd0cf9f commit 282de22

File tree

4 files changed

+518
-14
lines changed

4 files changed

+518
-14
lines changed

model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from typing import Callable
15+
from typing import Callable, Optional, List
1616
from io import BytesIO
1717

1818
import torch.nn
@@ -65,11 +65,14 @@ def __init__(self,
6565
self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops
6666
self._onnx_opset_version = onnx_opset_version
6767

68-
def export(self, output_names=None) -> None:
68+
def export(self, output_names: Optional[List[str]] = None) -> None:
6969
"""
7070
Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
7171
(namely, weights that are in fake-quant format) and fake-quant layers for the activations.
7272
73+
Args:
74+
output_names (Optional[List[str]]): Optional list of output node names for export compatibility.
75+
7376
Returns:
7477
Fake-quant PyTorch model.
7578
"""
@@ -131,6 +134,8 @@ def export(self, output_names=None) -> None:
131134
output_names = ['output']
132135
dynamic_axes.update({'output': {0: 'batch_size'}})
133136
else:
137+
assert isinstance(output_names, list), \
138+
f"`output_names` must be a list, but got {type(output_names).__name__}"
134139
if isinstance(model_output, (list, tuple)):
135140
num_of_outputs = len(model_output)
136141
else:

model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self,
4949
save_model_path,
5050
repr_dataset)
5151

52-
def export(self) -> None:
52+
def export(self, output_names=None) -> None:
5353
"""
5454
Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
5555
(namely, weights that are in fake-quant format) and fake-quant layers for the activations.

model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from typing import Callable
15+
from typing import Callable, Optional, List
1616
from packaging import version
1717

1818
from model_compression_toolkit.verify_packages import FOUND_TORCH
@@ -47,7 +47,8 @@ def pytorch_export_model(model: torch.nn.Module,
4747
is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
4848
serialization_format: PytorchExportSerializationFormat = PytorchExportSerializationFormat.ONNX,
4949
quantization_format: QuantizationFormat = QuantizationFormat.MCTQ,
50-
onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION) -> None:
50+
onnx_opset_version: int = DEFAULT_ONNX_OPSET_VERSION,
51+
output_names: Optional[List[str]] = None) -> None:
5152
"""
5253
Export a PyTorch quantized model to a torchscript or onnx model.
5354
The model will be saved to the path in save_model_path.
@@ -57,19 +58,25 @@ def pytorch_export_model(model: torch.nn.Module,
5758
(where the model will be saved to ONNX model).
5859
5960
Args:
60-
model: Model to export.
61-
save_model_path: Path to save the model.
62-
repr_dataset: Representative dataset for tracing the pytorch model (mandatory for exporting it).
63-
is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
64-
serialization_format: Format to export the model according to (by default
65-
PytorchExportSerializationFormat.ONNX).
66-
quantization_format: Format of how quantizers are exported (fakely-quant, int8, MCTQ quantizers).
67-
onnx_opset_version: ONNX opset version to use for exported ONNX model.
61+
model (Module): Model to export.
62+
save_model_path (str): Path to save the model.
63+
repr_dataset (Callable): Representative dataset for tracing the pytorch model (mandatory for exporting it).
64+
is_layer_exportable_fn (Callable): Callable to check whether a layer can be exported or not.
65+
serialization_format (PytorchExportSerializationFormat): Format to export the model according to (by default PytorchExportSerializationFormat.ONNX).
66+
quantization_format (QuantizationFormat): Format of how quantizers are exported (fakely-quant, int8, MCTQ quantizers).
67+
onnx_opset_version (int): ONNX opset version to use for exported ONNX model.
68+
output_names (Optional[List[str]]): Optional list of output node names for export compatibility. This argument is relevant only when using PytorchExportSerializationFormat.ONNX.
6869
6970
"""
7071
# Ensure 'metadata' is available directly on the model, if present in submodules
7172
find_and_assign_metadata_attr(model)
7273

74+
if output_names is not None and serialization_format != PytorchExportSerializationFormat.ONNX:
75+
Logger.warning(
76+
f'`output_names` is only applicable when exporting to ONNX. '
77+
f'Current serialization format is {serialization_format}, so `output_names` will be ignored.'
78+
) # pragma: no cover
79+
7380
if serialization_format == PytorchExportSerializationFormat.TORCHSCRIPT:
7481
if quantization_format in supported_serialization_quantization_export_dict[serialization_format]:
7582
exporter = FakelyQuantTorchScriptPyTorchExporter(model,
@@ -107,7 +114,7 @@ def pytorch_export_model(model: torch.nn.Module,
107114
f'Unsupported serialization {serialization_format} was used to export Pytorch model.'
108115
f' Please see API for supported formats.') # pragma: no cover
109116

110-
exporter.export()
117+
exporter.export(output_names=output_names)
111118

112119
else:
113120
def pytorch_export_model(*args, **kwargs):

0 commit comments

Comments
 (0)