Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 75 additions & 9 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@
return False


class Exporter:
class _Exporter:
"""Class used for recursive traversal of Proto structures."""

def __init__(
self, rename: bool, use_operators: bool = False, inline_const: bool = False
self, *, rename: bool, use_operators: bool, inline_const: bool, skip_initializers: bool
) -> None:
self.use_operators = use_operators
if rename:
Expand All @@ -266,6 +266,8 @@
# _name_remappings: used to undo the SSA-renaming in ONNX control-flow ops.
# We map the multiple SSA-variants back to the same Python variable name.
self._name_remappings: list[dict[str, str]] = []
self.skip_initializers = skip_initializers
self.skipped_initializers: dict[str, onnx.TensorProto] = {}

def _handle_attrname_conflict(self, renamer):
"""Add ref-attr-name-conflict handling logic to renaming function."""
Expand Down Expand Up @@ -338,6 +340,14 @@
code = []
if hasattr(graph, "initializer"):
for init in graph.initializer:
if self.skip_initializers:
init_py_name = self._translate_onnx_var(init.name)

Check warning on line 344 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L344

Added line #L344 was not covered by tests
if init_py_name in self.skipped_initializers:
raise RuntimeError(

Check warning on line 346 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L346

Added line #L346 was not covered by tests
f"Initializer {init.name!r} is already present in skipped_initializers."
)
self.skipped_initializers[init_py_name] = init
continue

Check warning on line 350 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L349-L350

Added lines #L349 - L350 were not covered by tests
node = make_node(
"Constant",
[],
Expand Down Expand Up @@ -684,15 +694,61 @@
def add(line: str) -> None:
result.append(line)

add("@script()")
add(f"def {function_name}{_translate_signature(graph.input, graph.output)}")
if self.skip_initializers:
indent_level = 2
indent = _SINGLE_INDENT

Check warning on line 699 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L698-L699

Added lines #L698 - L699 were not covered by tests
else:
indent_level = 1
indent = ""
add(f"{indent}@script()")
add(f"{indent}def {function_name}{_translate_signature(graph.input, graph.output)}")
indent = indent + _SINGLE_INDENT
doc = graph.doc_string
if doc:
add(f' """{doc}"""')
add(self._translate_graph_body(graph, opsets, indent=1))
add(f'{indent}"""{doc}"""')

Check warning on line 708 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L708

Added line #L708 was not covered by tests
add(self._translate_graph_body(graph, opsets, indent=indent_level))
return_values = ", ".join(self._translate_onnx_var(x) for x in graph.output)
add(f" return {return_values}")
return "\n".join(result)
add(f"{indent}return {return_values}")
script = "\n".join(result)
if self.skipped_initializers:
return self._substitute_initializers(script, function_name)

Check warning on line 714 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L714

Added line #L714 was not covered by tests
return script

def _substitute_initializers(self, script: str, script_function_name: str) -> str:
init_names = self.skipped_initializers.keys()

Check warning on line 718 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L718

Added line #L718 was not covered by tests
# Formal parameters representing initializers (single level indentation)
__ = _SINGLE_INDENT
initializers_as_params = "\n".join(f"{__}{x}," for x in init_names)

Check warning on line 721 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L720-L721

Added lines #L720 - L721 were not covered by tests

def generate_rand(name: str, value: TensorProto) -> str:
shape = ",".join(str(d) for d in value.dims)

Check warning on line 724 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L723-L724

Added lines #L723 - L724 were not covered by tests
if value.data_type != TensorProto.FLOAT:
raise NotImplementedError(

Check warning on line 726 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L726

Added line #L726 was not covered by tests
f"Unable to generate random initializer for data type {value.data_type}."
)
return f"{__}{name} = numpy.random.rand({shape}).astype(numpy.float32)"

Check warning on line 729 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L729

Added line #L729 was not covered by tests

random_initializer_values = "\n".join(

Check warning on line 731 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L731

Added line #L731 was not covered by tests
generate_rand(key, value) for key, value in self.skipped_initializers.items()
)
# Actual parameter values for initializers (double level indentation)
indented_initializers_as_params = "\n".join(f"{__}{__}{x}," for x in init_names)
return f"""

Check warning on line 736 in onnxscript/backend/onnx_export.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/backend/onnx_export.py#L735-L736

Added lines #L735 - L736 were not covered by tests
def make_model(
{initializers_as_params}
):
{script}
{__}model = {script_function_name}.to_model_proto()
{__}return model
def make_model_with_random_weights():
{random_initializer_values}
{__}model = make_model(
{indented_initializers_as_params}
{__})
{__}return model
"""

def _import_onnx_types(
self, proto: onnx.ModelProto | onnx.GraphProto | onnx.FunctionProto
Expand Down Expand Up @@ -778,9 +834,11 @@
def export2python(
model_onnx,
function_name: Optional[str] = None,
*,
rename: bool = False,
use_operators: bool = False,
inline_const: bool = False,
skip_initializers: bool = False,
):
"""Exports an ONNX model to the *python* syntax.
Expand All @@ -790,6 +848,9 @@
function_name: main function name
use_operators: use Python operators.
inline_const: replace ONNX constants inline if compact
skip_initializers: generated script will not include initializers.
Instead, a function that generates the model, given initializer values, is generated,
along with one that generates random values for the initializers.
Returns:
python code
Expand All @@ -815,5 +876,10 @@
if not isinstance(model_onnx, (ModelProto, FunctionProto)):
raise TypeError(f"The function expects a ModelProto not {type(model_onnx)!r}.")

exporter = Exporter(rename, use_operators, inline_const)
exporter = _Exporter(
rename=rename,
use_operators=use_operators,
inline_const=inline_const,
skip_initializers=skip_initializers,
)
return exporter.export(model_onnx, function_name)
29 changes: 29 additions & 0 deletions tools/onnx2external.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import os

import onnx
import onnx.external_data_helper


def convert2external(input_file_name: str) -> None:
dir_name = os.path.dirname(input_file_name)
base_name, _suffix = os.path.splitext(os.path.basename(input_file_name))
model = onnx.load(input_file_name)
os.makedirs(os.path.join(dir_name, base_name), exist_ok=True)
onnx.external_data_helper.convert_model_to_external_data(
model, location="external_data.onnx", size_threshold=128
)
onnx.save(model, os.path.join(dir_name, base_name, "model.onnx"))


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert ONNX model file to external data format"
)
parser.add_argument("input", help="ONNX model file to convert")
args = parser.parse_args()

convert2external(args.input)
16 changes: 13 additions & 3 deletions tools/onnx2script.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@


def convert2script(
input_file_name: str, output_file_name: Optional[str], verbose: bool
input_file_name: str, output_file_name: Optional[str], verbose: bool, initializers: bool
) -> None:
model = onnx.load(input_file_name, load_external_data=False)
python_code = onnxscript.proto2python(
model, use_operators=not verbose, inline_const=not verbose
model,
use_operators=not verbose,
inline_const=not verbose,
skip_initializers=not initializers,
)

# If output file name is not provided, use the input file name with .py extension
Expand All @@ -55,6 +58,13 @@ def convert2script(
help="Verbose mode, suppresses use of overloaded operators and inline constants",
default=False,
)
parser.add_argument(
"-i",
"--initializers",
action="store_true",
help="Include initializers in the generated script",
default=False,
)

args = parser.parse_args()
convert2script(args.input, args.output, args.verbose)
convert2script(args.input, args.output, args.verbose, args.initializers)
Loading