Skip to content

Commit e8005e9

Browse files
authored
[torch api] Support down conversion of opsets (#2503)
Starting from PyTorch 2.9, down conversion is turned on and supported. --------- Signed-off-by: Justin Chu <[email protected]>
1 parent ae4c668 commit e8005e9

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

onnxscript/_framework_apis/torch_2_8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
"""Stable APIs for PyTorch 2.7."""
3+
"""Stable APIs for PyTorch 2.8."""
44

55
from __future__ import annotations
66

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Stable APIs for PyTorch 2.9."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"check_model",
9+
"convert_version",
10+
"get_torchlib_ops",
11+
"optimize",
12+
"save_model_with_external_data",
13+
]
14+
15+
from typing import TYPE_CHECKING
16+
17+
from onnxscript import version_converter
18+
from onnxscript._framework_apis.torch_2_8 import (
19+
check_model,
20+
get_torchlib_ops,
21+
optimize,
22+
save_model_with_external_data,
23+
)
24+
25+
if TYPE_CHECKING:
26+
import onnx_ir as ir
27+
28+
29+
def convert_version(model: ir.Model, target_version: int) -> ir.Model:
30+
"""Convert the model to the specified ONNX opset version.
31+
32+
Starting from PyTorch 2.9, down conversion is turned on and supported.
33+
"""
34+
version_converter.convert_version(model, target_version, fallback=True)
35+
return model

onnxscript/version_converter/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,13 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
107107
self.target_version,
108108
)
109109
return ir.passes.PassResult(model, False)
110+
else:
111+
logger.warning(
112+
"The model version conversion is not supported by the onnxscript version converter "
113+
"and fallback is enabled. The model will be converted using the onnx C API "
114+
"(target version: %d).",
115+
self.target_version,
116+
)
110117

111118
# If the onnxscript version converter does not support the conversion,
112119
# we can use the onnx C API to convert the model

0 commit comments

Comments
 (0)