Commit 2b497fb

Merge branch 'main' into rama/const-prop

2 parents: 18862b9 + 12f9209
9 files changed: +98 / -99 lines

onnxscript/_framework_apis/torch_2_5.py

Lines changed: 26 additions & 46 deletions
@@ -17,17 +17,10 @@
 import pathlib
 from typing import Callable
 
-import onnx
-
 from onnxscript import ir, optimizer
 from onnxscript.function_libs.torch_lib import registration
 from onnxscript.ir import _external_data
 
-# Internal flag. Will go away.
-_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = (
-    os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") != "0"
-)
-
 
 @dataclasses.dataclass(frozen=True)
 class _OnnxFunctionMeta:
@@ -83,45 +76,32 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike
     """Save the model with external data. The model is unchanged after saving."""
 
     # TODO(#1835): Decide if we want to externalize large attributes as well
-    if _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR:
-        initializer_values = tuple(model.graph.initializers.values())
-        tensors = [v.const_value for v in initializer_values]
-        for tensor in tensors:
-            if tensor is None:
-                raise ValueError(
-                    "The model contains uninitialized initializer values. "
-                    "Please make sure all initializer values are initialized."
-                )
-        destination_path = pathlib.Path(model_path)
-        base_dir = destination_path.parent
-        data_path = f"{destination_path.name}.data"
-
-        external_tensors = _external_data.convert_tensors_to_external(
-            tensors,  # type: ignore[arg-type]
-            base_dir,
-            data_path,
-        )
-
-        # Replace the initializer values with external tensors and save the model
-        for initializer, external_tensor in zip(initializer_values, external_tensors):
-            initializer.const_value = external_tensor
-        ir.save(model, model_path)
-
-        # Restore the original initializer values so the model is unchanged
-        for initializer, tensor in zip(initializer_values, tensors):
-            initializer.const_value = tensor
-
-    else:
-        destination_path = pathlib.Path(model_path)
-        # Create the directory if it does not exist
-        data_path = f"{destination_path.name}.data"
-        proto = ir.serde.serialize_model(model)
-        onnx.save_model(
-            proto,
-            model_path,
-            save_as_external_data=True,
-            location=data_path,
-        )
+    initializer_values = tuple(model.graph.initializers.values())
+    tensors = [v.const_value for v in initializer_values]
+    for tensor in tensors:
+        if tensor is None:
+            raise ValueError(
+                "The model contains uninitialized initializer values. "
+                "Please make sure all initializer values are initialized."
+            )
+    destination_path = pathlib.Path(model_path)
+    base_dir = destination_path.parent
+    data_path = f"{destination_path.name}.data"
+
+    external_tensors = _external_data.convert_tensors_to_external(
+        tensors,  # type: ignore[arg-type]
+        base_dir,
+        data_path,
+    )
+
+    # Replace the initializer values with external tensors and save the model
+    for initializer, external_tensor in zip(initializer_values, external_tensors):
+        initializer.const_value = external_tensor
+    ir.save(model, model_path)
+
+    # Restore the original initializer values so the model is unchanged
+    for initializer, tensor in zip(initializer_values, tensors):
+        initializer.const_value = tensor
 
 
 def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
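The retained save path temporarily swaps each initializer's `const_value` for its external counterpart, calls `ir.save`, then swaps the originals back so the in-memory model is unchanged. A minimal sketch of that swap-save-restore pattern in plain Python (illustrative only, not the onnxscript API; the names are invented for the sketch):

```python
from contextlib import contextmanager
from typing import Any, Dict, Iterator

@contextmanager
def temporarily_replaced(store: Dict[str, Any], replacements: Dict[str, Any]) -> Iterator[None]:
    """Swap values in, let a side effect run, then restore the originals."""
    originals = {key: store[key] for key in replacements}
    store.update(replacements)
    try:
        yield
    finally:
        store.update(originals)  # runs even if the side effect raises

weights = {"w0": [1.0, 2.0]}
with temporarily_replaced(weights, {"w0": "<external ref>"}):
    assert weights["w0"] == "<external ref>"  # what the save step would observe
assert weights["w0"] == [1.0, 2.0]  # caller-visible state is unchanged
```

Unlike this sketch, the diff restores in straight-line code, so an exception inside `ir.save` would leave the model holding external tensors.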
onnxscript/_framework_apis/torch_2_6.py (new file)

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Stable APIs for PyTorch 2.6."""
+
+from __future__ import annotations
+
+__all__ = [
+    "check_model",
+    "convert_version",
+    "get_torchlib_ops",
+    "optimize",
+    "save_model_with_external_data",
+]
+from onnxscript import ir, optimizer
+from onnxscript._framework_apis.torch_2_5 import (
+    check_model,
+    convert_version,
+    get_torchlib_ops,
+    save_model_with_external_data,
+)
+
+
+def optimize(model: ir.Model) -> ir.Model:
+    """Optimize the model."""
+    optimizer.optimize_ir(model)
+    return model
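A hedged usage sketch of the new 2.6 surface. The empty `ir.Graph`/`ir.Model` construction below is an assumption standing in for a real exporter-produced model, and `check_model` is simply the re-export from `torch_2_5`:

```python
from onnxscript import ir
from onnxscript._framework_apis import torch_2_6

# Stand-in model; assumes this minimal Graph/Model construction is accepted.
graph = ir.Graph(inputs=[], outputs=[], nodes=[], opset_imports={"": 18})
model = ir.Model(graph, ir_version=9)

model = torch_2_6.optimize(model)  # optimizer.optimize_ir mutates and returns it
torch_2_6.check_model(model)       # re-exported unchanged from torch_2_5
```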

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 25 additions & 34 deletions
@@ -39,7 +39,6 @@
     RealType,
     TFloat,
     TFloatHighPrecision,
-    TFloatOrBFloat16,
     TInt,
     TReal,
     TRealOrUInt8,
@@ -2031,12 +2030,6 @@ def aten_convolution(
         stride = (stride, stride)
     strides = list(stride)
 
-    if bias is None:
-        weight_dim_0 = op.Shape(weight, start=0, end=1)
-        bias_shape = op.Expand(weight_dim_0, op.Constant(value_ints=[1]))
-        zero = op.CastLike(0.0, input)
-        bias = op.Expand(zero, bias_shape)
-
     result = _aten_convolution_onnx(
         input,
         weight,
@@ -3564,14 +3557,14 @@ def aten_flipud(self: TensorType) -> TensorType:
 
 
 @torch_op("aten::floor", traceable=True)
-def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_floor(self: TFloat) -> TFloat:
     """floor(Tensor self) -> Tensor"""
 
     return op.Floor(self)
 
 
 @torch_op("math::floor", traceable=True)
-def python_math_floor(self: TFloatOrBFloat16) -> TInt:
+def python_math_floor(self: TFloat) -> TInt:
     """floor(Tensor self) -> Tensor"""
     floor = op.Floor(self)
     return op.Cast(floor, to=INT64.dtype)
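For reference, a numpy mirror of `python_math_floor` (floor, then cast to int64); this is a sanity check of the arithmetic, not onnxscript code:

```python
import numpy as np

x = np.array([-1.5, 0.2, 2.9], dtype=np.float32)
floored = np.floor(x).astype(np.int64)  # Floor followed by Cast(to=INT64)
assert (floored == np.array([-2, 0, 2])).all()
```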
@@ -4533,7 +4526,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL:
 
 
 @torch_op("aten::isinf")
-def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isinf(self: TFloat) -> BOOL:
     """isinf(Tensor self) -> Tensor"""
 
     # Added Cast inside the function so it can support all real dtypes naturally
@@ -4542,14 +4535,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
 
 
 @torch_op("aten::isnan")
-def aten_isnan(self: TFloatOrBFloat16) -> BOOL:
+def aten_isnan(self: TFloat) -> BOOL:
     """isnan(Tensor self) -> Tensor"""
 
     return op.IsNaN(self)
 
 
 @torch_op("aten::isneginf")
-def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isneginf(self: TFloat) -> BOOL:
     """isneginf(Tensor self) -> Tensor"""
 
     # Added Cast inside the function so it can support all real dtypes naturally
@@ -4558,7 +4551,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
 
 
 @torch_op("aten::isposinf")
-def aten_isposinf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isposinf(self: TFloat) -> BOOL:
     """isposinf(Tensor self) -> Tensor"""
 
     # Added Cast inside the function so it can support all real dtypes naturally
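numpy equivalents of the four predicates touched above, as a quick behavioral reference (the onnxscript versions additionally Cast the input so all real dtypes are supported):

```python
import numpy as np

x = np.array([np.inf, -np.inf, np.nan, 1.0], dtype=np.float32)
assert (np.isinf(x) == [True, True, False, False]).all()      # aten_isinf
assert (np.isnan(x) == [False, False, True, False]).all()     # aten_isnan
assert (np.isneginf(x) == [False, True, False, False]).all()  # aten_isneginf
assert (np.isposinf(x) == [True, False, False, False]).all()  # aten_isposinf
```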
@@ -4778,42 +4771,42 @@ def aten_linspace(
 
 
 @torch_op("aten::log", traceable=True)
-def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log(self: TFloat) -> TFloat:
     """log(Tensor self) -> Tensor"""
 
     return op.Log(self)
 
 
 @torch_op("aten::log10", traceable=True)
-def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log10(self: TFloat) -> TFloat:
     """log10(Tensor self) -> Tensor"""
 
     return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self))
 
 
 @torch_op("aten::log1p")
-def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log1p(self: TFloat) -> TFloat:
     """log1p(Tensor self) -> Tensor"""
 
     return op.Log(op.Add(self, 1.0))
 
 
 @torch_op("aten::log2", traceable=True)
-def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log2(self: TFloat) -> TFloat:
     """log2(Tensor self) -> Tensor"""
 
     return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self))
 
 
 @torch_op("aten::logaddexp", traceable=True)
-def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat:
     """logaddexp(Tensor self, Tensor other) -> Tensor"""
 
     return op.Log(op.Add(op.Exp(self), op.Exp(other)))
 
 
 @torch_op("aten::logaddexp2", traceable=True)
-def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat:
     """logaddexp2(Tensor self, Tensor other) -> Tensor"""
     two = op.CastLike(2.0, self)
     summation = op.Add(op.Pow(two, self), op.Pow(two, other))
@@ -4822,7 +4815,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOr
 
 
 @torch_op("aten::logcumsumexp", traceable=True)
-def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
+def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat:
     """logcumsumexp(Tensor self, int dim) -> Tensor"""
 
     if IsScalar(self):
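The log-family decompositions above rely on standard identities; a quick numeric check (pure numpy, independent of the ONNX ops):

```python
import numpy as np

x = np.array([0.5, 1.0, 4.0], dtype=np.float64)
assert np.allclose(np.log(x) / np.log(10.0), np.log10(x))  # aten_log10
assert np.allclose(np.log(x) / np.log(2.0), np.log2(x))    # aten_log2
assert np.allclose(np.log(x + 1.0), np.log1p(x))           # aten_log1p

a, b = 1.5, 2.5
assert np.isclose(np.log(np.exp(a) + np.exp(b)), np.logaddexp(a, b))  # aten_logaddexp
assert np.isclose(np.log2(2.0**a + 2.0**b), np.logaddexp2(a, b))      # aten_logaddexp2
```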
@@ -4908,12 +4901,12 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
 
 
 @torch_op("aten::logit", private=True)
-def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def _aten_logit_onnx(self: TFloat) -> TFloat:
     return op.Log(op.Div(self, op.Sub(1.0, self)))
 
 
 @torch_op("aten::logit", private=True)
-def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:
+def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat:
     eps = op.CastLike(eps, self)
     one = op.CastLike(1.0, self)
     temporary_self = op.Where(self <= one - eps, self, one - eps)
@@ -4923,7 +4916,7 @@ def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat
 
 
 @torch_op("aten::logit", trace_only=True)
-def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16:
+def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat:
     """logit(Tensor self, float? eps=None) -> Tensor"""
     if eps is None:
         return _aten_logit_onnx(self)
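`aten_logit` dispatches between the two private helpers above; numerically, logit(x) = log(x / (1 - x)). A numpy check of the identity plus the upper clamp visible in `_aten_logit_clamp_onnx` (the matching lower clamp falls outside the lines shown):

```python
import numpy as np

x = np.array([0.1, 0.5, 0.9], dtype=np.float64)
assert np.allclose(np.log(x / (1.0 - x)), np.log(x) - np.log1p(-x))  # logit identity

eps = 0.2
upper_clamped = np.where(x <= 1.0 - eps, x, 1.0 - eps)  # the Where shown above
assert (upper_clamped <= 1.0 - eps).all()
```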
@@ -6041,9 +6034,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
 
 
 @torch_op("aten::native_dropout", trace_only=True)
-def aten_native_dropout(
-    input: TFloatOrBFloat16, p: float, train: bool = True
-) -> Tuple[TFloatOrBFloat16, BOOL]:
+def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]:
     """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""
 
     result, mask = op.Dropout(input, p, train)
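ONNX `Dropout` in training mode returns the inverted-dropout output together with a boolean keep-mask, which is why `aten_native_dropout` can return the pair directly. A numpy sketch of those semantics (illustrative; the real op's RNG behavior is runtime-defined):

```python
import numpy as np

def dropout_reference(x: np.ndarray, p: float, train: bool = True):
    """numpy sketch of ONNX Dropout's (output, mask) contract."""
    if not train or p == 0.0:
        return x, np.ones_like(x, dtype=bool)
    mask = np.random.default_rng(0).random(x.shape) >= p  # keep with prob 1 - p
    return x * mask / (1.0 - p), mask  # inverted-dropout scaling

out, mask = dropout_reference(np.ones((2, 3), dtype=np.float32), p=0.5)
assert out.shape == mask.shape == (2, 3)
```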
@@ -7055,7 +7046,7 @@ def aten_real(self: TensorType) -> TensorType:
 
 
 @torch_op("aten::reciprocal")
-def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_reciprocal(self: TFloat) -> TFloat:
     """reciprocal(Tensor self) -> Tensor"""
 
     return op.Reciprocal(self)
@@ -7074,7 +7065,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
 
 
 @torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
-def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
     """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
 
     # TODO(justinchuby): Improve fp16 precision by following the logic in
@@ -7355,7 +7346,7 @@ def aten_rrelu(
 
 
 @torch_op("aten::rsqrt", traceable=True)
-def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_rsqrt(self: TFloat) -> TFloat:
     """rsqrt(Tensor self) -> Tensor"""
 
     return op.Reciprocal(op.Sqrt(self))
@@ -7562,7 +7553,7 @@ def aten_sgn(self: TensorType) -> TensorType:
 
 
 @torch_op("aten::sigmoid", traceable=True)
-def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_sigmoid(self: TFloat) -> TFloat:
     """sigmoid(Tensor self) -> Tensor"""
 
     return op.Sigmoid(self)
@@ -7724,7 +7715,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
 
 
 @torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True)
-def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
+def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
     """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
 
     self_is_scalar = IsScalar(self)
@@ -7741,7 +7732,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB
 
 
 @torch_op(("aten::softmax.int", "aten::special_softmax"), traceable=True)
-def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
+def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat:
     """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
 
     self_is_scalar = IsScalar(self)
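Both softmax variants reduce to the same arithmetic once dtype and scalar handling are stripped away; a numpy reference (the max-subtraction is added here for numerical stability, as implementations typically do):

```python
import numpy as np

def softmax_reference(x: np.ndarray, dim: int) -> np.ndarray:
    shifted = x - x.max(axis=dim, keepdims=True)  # stability shift
    e = np.exp(shifted)
    return e / e.sum(axis=dim, keepdims=True)

probs = softmax_reference(np.array([[1.0, 2.0, 3.0]]), dim=-1)
assert np.allclose(probs.sum(axis=-1), 1.0)
```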
@@ -7812,7 +7803,7 @@ def aten_split_with_sizes_copy(
 
 
 @torch_op("aten::sqrt", traceable=True)
-def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_sqrt(self: TFloat) -> TFloat:
     """sqrt(Tensor self) -> Tensor"""
 
     return op.Sqrt(self)
@@ -8402,7 +8393,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
 
 
 @torch_op("aten::trunc")
-def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_trunc(self: TFloat) -> TFloat:
     """trunc(Tensor self) -> Tensor"""
 
     # Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 5 additions & 6 deletions
@@ -25,7 +25,6 @@
 from onnxscript.function_libs.torch_lib.tensor_typing import (
     IntType,
     TFloat,
-    TFloatOrBFloat16,
     TFloatOrUInt8,
     TInt,
     TReal,
@@ -364,13 +363,13 @@ def aten_conv_depthwise3d(
 
 @torch_op("aten::cross_entropy_loss", traceable=True)
 def aten_cross_entropy_loss(
-    self: TFloatOrBFloat16,
+    self: TFloat,
     target: IntType,
-    weight: Optional[TFloatOrBFloat16] = None,
+    weight: Optional[TFloat] = None,
     reduction: int = 1,  # default is 'mean'
     ignore_index: int = -100,
     label_smoothing: float = 0.0,  # this was ignored due to ONNX not support
-) -> TFloatOrBFloat16:
+) -> TFloat:
     """cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""
 
     if reduction == 0:  # "none"
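The integer `reduction` codes follow PyTorch's convention; the hunk shows 0 = "none" and 1 = "mean" (the default), and PyTorch additionally defines 2 = "sum". A one-line mapping, purely illustrative:

```python
REDUCTION_CODES = {0: "none", 1: "mean", 2: "sum"}  # PyTorch reduction convention
```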
@@ -812,7 +811,7 @@ def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> Te
 
 
 @torch_op("aten::leaky_relu")
-def aten_leaky_relu(self: TFloatOrBFloat16, negative_slope: float = 0.01) -> TFloatOrBFloat16:
+def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat:
     """leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"""
 
     return op.LeakyRelu(self, alpha=negative_slope)
@@ -850,7 +849,7 @@ def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat:
 
 
 @torch_op("aten::log_sigmoid")
-def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log_sigmoid(self: TFloat) -> TFloat:
     """log_sigmoid(Tensor self) -> Tensor"""
 
     return op.Log(op.Sigmoid(self))
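Numeric mirrors of the two activations touched above (pure numpy, for reference only):

```python
import numpy as np

x = np.array([-2.0, 0.0, 3.0], dtype=np.float64)
leaky = np.where(x >= 0, x, 0.01 * x)        # LeakyRelu with alpha=0.01
log_sig = np.log(1.0 / (1.0 + np.exp(-x)))   # Log(Sigmoid(x))
assert np.allclose(log_sig, -np.log1p(np.exp(-x)))  # equivalent stable form
```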
