Skip to content

Commit 37b11fc

Browse files
[API] Create stable APIs for PyTorch 2.6 (#1896)
- optimize is turned on. It will be controlled by an option in PyTorch - Remove the `_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR` flag Co-authored-by: Ti-Tai Wang <[email protected]>
1 parent 1426e9f commit 37b11fc

File tree

2 files changed

+52
-46
lines changed

2 files changed

+52
-46
lines changed

onnxscript/_framework_apis/torch_2_5.py

Lines changed: 26 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,10 @@
1717
import pathlib
1818
from typing import Callable
1919

20-
import onnx
21-
2220
from onnxscript import ir, optimizer
2321
from onnxscript.function_libs.torch_lib import registration
2422
from onnxscript.ir import _external_data
2523

26-
# Internal flag. Will go away.
27-
_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = (
28-
os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") != "0"
29-
)
30-
3124

3225
@dataclasses.dataclass(frozen=True)
3326
class _OnnxFunctionMeta:
@@ -83,45 +76,32 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike
8376
"""Save the model with external data. The model is unchanged after saving."""
8477

8578
# TODO(#1835): Decide if we want to externalize large attributes as well
86-
if _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR:
87-
initializer_values = tuple(model.graph.initializers.values())
88-
tensors = [v.const_value for v in initializer_values]
89-
for tensor in tensors:
90-
if tensor is None:
91-
raise ValueError(
92-
"The model contains uninitialized initializer values. "
93-
"Please make sure all initializer values are initialized."
94-
)
95-
destination_path = pathlib.Path(model_path)
96-
base_dir = destination_path.parent
97-
data_path = f"{destination_path.name}.data"
98-
99-
external_tensors = _external_data.convert_tensors_to_external(
100-
tensors, # type: ignore[arg-type]
101-
base_dir,
102-
data_path,
103-
)
104-
105-
# Replace the initializer values with external tensors and save the model
106-
for initializer, external_tensor in zip(initializer_values, external_tensors):
107-
initializer.const_value = external_tensor
108-
ir.save(model, model_path)
109-
110-
# Restore the original initializer values so the model is unchanged
111-
for initializer, tensor in zip(initializer_values, tensors):
112-
initializer.const_value = tensor
113-
114-
else:
115-
destination_path = pathlib.Path(model_path)
116-
# Create the directory if it does not exist
117-
data_path = f"{destination_path.name}.data"
118-
proto = ir.serde.serialize_model(model)
119-
onnx.save_model(
120-
proto,
121-
model_path,
122-
save_as_external_data=True,
123-
location=data_path,
124-
)
79+
initializer_values = tuple(model.graph.initializers.values())
80+
tensors = [v.const_value for v in initializer_values]
81+
for tensor in tensors:
82+
if tensor is None:
83+
raise ValueError(
84+
"The model contains uninitialized initializer values. "
85+
"Please make sure all initializer values are initialized."
86+
)
87+
destination_path = pathlib.Path(model_path)
88+
base_dir = destination_path.parent
89+
data_path = f"{destination_path.name}.data"
90+
91+
external_tensors = _external_data.convert_tensors_to_external(
92+
tensors, # type: ignore[arg-type]
93+
base_dir,
94+
data_path,
95+
)
96+
97+
# Replace the initializer values with external tensors and save the model
98+
for initializer, external_tensor in zip(initializer_values, external_tensors):
99+
initializer.const_value = external_tensor
100+
ir.save(model, model_path)
101+
102+
# Restore the original initializer values so the model is unchanged
103+
for initializer, tensor in zip(initializer_values, tensors):
104+
initializer.const_value = tensor
125105

126106

127107
def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Stable APIs for PyTorch 2.6."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"check_model",
9+
"convert_version",
10+
"get_torchlib_ops",
11+
"optimize",
12+
"save_model_with_external_data",
13+
]
14+
from onnxscript import ir, optimizer
15+
from onnxscript._framework_apis.torch_2_5 import (
16+
check_model,
17+
convert_version,
18+
get_torchlib_ops,
19+
save_model_with_external_data,
20+
)
21+
22+
23+
def optimize(model: ir.Model) -> ir.Model:
24+
"""Optimize the model."""
25+
optimizer.optimize_ir(model)
26+
return model

0 commit comments

Comments
 (0)