
Commit 6ecebdc

Merge branch 'main' into justinchu/fix-slice-rewrite

2 parents 62de1f0 + 7227655 commit 6ecebdc

File tree: 7 files changed, +126 -81 lines changed

docs/api/optimizer.md

Lines changed: 0 additions & 1 deletion

@@ -15,5 +15,4 @@
 optimizer.inline
 optimizer.basic_constant_propagation
 optimizer.fold_constants
-optimizer.remove_unused_nodes
 ```

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 63 additions & 52 deletions

@@ -162,9 +162,15 @@ def aten_acosh(self: TFloat) -> TFloat:


 @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
-def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-    # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
+
+    if self.dtype == ir.DataType.BOOL:
+        # alpha can also be bool
+        if alpha == 0:
+            return op.Identity(self)
+        return op.Or(self, other)
+
     if alpha != 1.0:
         alpha = op.CastLike(alpha, other)
         other = op.Mul(other, alpha)
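The new BOOL branch above lowers boolean addition to `Or` (and to `Identity` when `alpha == 0`). In NumPy terms, boolean "addition" is a logical OR; a minimal sketch of that semantics (the arrays are illustrative):

```python
import numpy as np

# Adding boolean arrays behaves like logical OR, which is what op.Or expresses in ONNX.
a = np.array([True, False, False])
b = np.array([True, True, False])
print(a + b)                # [ True  True False]
print(np.logical_or(a, b))  # [ True  True False], identical result
```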
@@ -1237,11 +1243,16 @@ def aten_binomial(
     ),
     trace_only=True,
 )
-def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # logical_and implements the BOOL variant

-    return op.BitwiseAnd(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype.is_integer():
+        return op.BitwiseAnd(self, other)
+    if self.dtype == ir.DataType.BOOL:
+        return op.And(self, other)
+    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


 @torch_op(

@@ -1329,11 +1340,14 @@ def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:


 @torch_op("aten::bitwise_not", trace_only=True)
-def aten_bitwise_not(self: TInt) -> TInt:
+def aten_bitwise_not(self: TTensor) -> TTensor:
     """bitwise_not(Tensor self) -> Tensor"""
-    # logical_not implements the BOOL variant

-    return op.BitwiseNot(self)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Not(self)
+    if self.dtype.is_integer():
+        return op.BitwiseNot(self)
+    raise NotImplementedError(f"Not implemented for type {self.dtype}")


 @torch_op(

@@ -1345,11 +1359,16 @@ def aten_bitwise_not(self: TInt) -> TInt:
     ),
     trace_only=True,
 )
-def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # logical_or implements the BOOL variant

-    return op.BitwiseOr(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype.is_integer():
+        return op.BitwiseOr(self, other)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Or(self, other)
+    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


 @torch_op(

@@ -1487,11 +1506,15 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:
     ),
     trace_only=True,
 )
-def aten_bitwise_xor(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # logical_xor implements the BOOL variant
+    assert self.dtype == other.dtype

-    return op.BitwiseXor(self, other)
+    if self.dtype.is_integer():
+        return op.BitwiseXor(self, other)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Xor(self, other)
+    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


 @torch_op("aten::blackman_window", trace_only=True)

@@ -5010,58 +5033,46 @@ def aten_logdet(self: TFloat) -> TFloat:
     return op.Log(op.Det(self))


-@torch_op(
-    (
-        "aten::logical_and",
-        "aten::bitwise_and.Tensor",
-        "aten::bitwise_and.Scalar",
-        "aten::bitwise_and.Scalar_Tensor",
-    ),
-    trace_only=True,
-)
-def aten_logical_and(self: BOOL, other: BOOL) -> BOOL:
+@torch_op("aten::logical_and", trace_only=True)
+def aten_logical_and(self: TTensor, other: TTensor) -> BOOL:
     """logical_and(Tensor self, Tensor other) -> Tensor"""

-    return op.And(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype == ir.DataType.BOOL:
+        return op.And(self, other)
+    return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))


-@torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True)
-def aten_logical_not(self: BOOL) -> BOOL:
+@torch_op("aten::logical_not", trace_only=True)
+def aten_logical_not(self: TTensor) -> BOOL:
     """logical_not(Tensor self) -> Tensor"""

-    return op.Not(self)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Not(self)
+    return op.Not(op.Cast(self, to=BOOL.dtype))


-@torch_op(
-    (
-        "aten::logical_or",
-        "aten::bitwise_or.Tensor",
-        "aten::bitwise_or.Scalar",
-        "aten::bitwise_or.Scalar_Tensor",
-        "aten::add.Tensor",
-        "aten::add.Scalar",
-    ),
-    trace_only=True,
-)
-def aten_logical_or(self: BOOL, other: BOOL) -> BOOL:
+@torch_op("aten::logical_or", trace_only=True)
+def aten_logical_or(self: TTensor, other: TTensor) -> BOOL:
     """logical_or(Tensor self, Tensor other) -> Tensor"""

-    return op.Or(self, other)
+    assert self.dtype == other.dtype

+    if self.dtype == ir.DataType.BOOL:
+        return op.Or(self, other)
+    return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))

-@torch_op(
-    (
-        "aten::logical_xor",
-        "aten::bitwise_xor.Tensor",
-        "aten::bitwise_xor.Scalar",
-        "aten::bitwise_xor.Scalar_Tensor",
-    ),
-    trace_only=True,
-)
-def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
+
+@torch_op("aten::logical_xor", trace_only=True)
+def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL:
     """logical_xor(Tensor self, Tensor other) -> Tensor"""

-    return op.Xor(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype == ir.DataType.BOOL:
+        return op.Xor(self, other)
+    return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))


 @torch_op("aten::logit", private=True)

onnxscript/optimizer/__init__.py

Lines changed: 2 additions & 6 deletions

@@ -19,12 +19,8 @@

 import onnxscript.optimizer._constant_folding as constant_folding
 from onnxscript import ir
-from onnxscript.optimizer._constant_folding import (
-    basic_constant_propagation,
-)
-from onnxscript.optimizer._constant_folding import (
-    fold_constants as fold_constants_ir,
-)
+from onnxscript.optimizer._constant_folding import basic_constant_propagation
+from onnxscript.optimizer._constant_folding import fold_constants as fold_constants_ir
 from onnxscript.optimizer._optimizer import optimize_ir

 _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)

onnxscript/optimizer/_constant_folding.py

Lines changed: 38 additions & 4 deletions

@@ -5,6 +5,13 @@

 from __future__ import annotations

+__all__ = [
+    "basic_constant_propagation",
+    "fold_constants",
+    "FoldConstantsPass",
+    "FOLDED_FROM_KEY",
+]
+
 import dataclasses
 import logging
 import math

@@ -23,6 +30,9 @@

 DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512 * 512

+# Key used to store the metadata
+FOLDED_FROM_KEY = "pkg.onnxscript.optimizer.folded_from"
+

 _NON_DETERMINISTIC_OPS = frozenset(
     {

@@ -491,9 +501,7 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     # should handle this. Only the optimization to eliminate redundant Cast ops
     # should be needed here.

-    input_shape = input.shape
-    if input_shape is not None:
-        output.shape = input_shape.copy()
+    output.shape = _merge_shapes(output.shape, input.shape)

     input_dtype = _get_input_element_type(node, 0)
     output_dtype = _get_int_attribute(node, "to", None)

@@ -600,6 +608,9 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     input = node.inputs[0]
     output = node.outputs[0]
     if input is not None and output is not None:
+        input.shape = _merge_shapes(input.shape, output.shape)
+        if input.type is None:
+            input.type = output.type
         state.set_sym_value(output, input)
     return None

@@ -914,6 +925,24 @@ def merge_dims(dim1, dim2):
     return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])


+def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None:
+    """Record the set of original input values that contributed to the constant-folded outputs."""
+    folded_from: set[str] = set()
+    for input in original_node.inputs:
+        if input is None:
+            continue
+        folded_from.update(input.meta.get(FOLDED_FROM_KEY, set()))
+        assert input.name is not None
+        folded_from.add(input.name)
+
+    for new_output in replacement.new_outputs:
+        if new_output is None:
+            continue
+        new_output.meta[FOLDED_FROM_KEY] = folded_from
+        # Store the string representation of the set to metadata_props to persist it across serialization
+        new_output.metadata_props[FOLDED_FROM_KEY] = repr(sorted(folded_from))
+
+
 class FoldConstantsPass(ir.passes.InPlacePass):
     """A pass that folds constant expressions in the model.


@@ -1203,9 +1232,14 @@ def convert(av):
             )
         return None

-    def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None:
+    def replace_node(
+        self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function
+    ) -> None:
         logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)

+        # Record the names of the values that have contributed to the replacement
+        _record_contributing_values(node, replacement)
+
         ir.convenience.replace_nodes_and_values(
             root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
         )
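The new FOLDED_FROM_KEY metadata records, for every folded output, the names of the original values it was computed from, both in `meta` (in memory) and in `metadata_props` (persisted through serialization). A hedged sketch of reading it back after folding; the model path and iteration pattern are illustrative, and `optimizer.fold_constants` is assumed to accept an `ir.Model`, as the `_ModelProtoOrIr` type variable above suggests:

```python
import onnx
from onnxscript import ir, optimizer
from onnxscript.optimizer._constant_folding import FOLDED_FROM_KEY

model = ir.from_proto(onnx.load("model.onnx"))  # path is illustrative
optimizer.fold_constants(model)                 # runs the constant-folding pass in place

# Values produced by folding carry the names of the values they were folded from.
for node in model.graph:
    for value in node.outputs:
        if FOLDED_FROM_KEY in value.metadata_props:
            print(value.name, "folded from", value.metadata_props[FOLDED_FROM_KEY])
```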

onnxscript/rewriter/_ir_utils.py

Lines changed: 17 additions & 6 deletions

@@ -78,23 +78,34 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
     return None


-def get_singleton_value(val: ir.Value | None, rank: int | None = None):
+def get_singleton_value(val: ir.Value | None, rank: int | Sequence[int] | None = None):
     """Returns element of a single element tensor constant value, and None otherwise.

-    If rank is specified, it checks that the value has the given rank.
+    If an int rank is specified, it checks that the value has the given rank.
+    If the rank is a sequence of ints, it checks that the value has one of the given ranks.
+
+    Thus, `rank=0` checks for a scalar, `rank=1` checks for a 1D tensor, and
+    `rank=(0,1)` checks for either a scalar or a 1D tensor.
     """
     np_val = get_numpy_value(val)
     if np_val is not None and np_val.size == 1:
-        if rank is None or (np_val.ndim == rank):
-            return np_val.item()
+        value = np_val.item()
+        if (rank is None) or (isinstance(rank, int) and (np_val.ndim == rank)):
+            return value
+        if isinstance(rank, Sequence) and (np_val.ndim in rank):
+            return value
     return None


 def is_singleton_value(
-    val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
+    val: ir.Value | None,
+    expected: float | int | Callable,
+    *,
+    rtol: float | None = None,
+    rank: int | Sequence[int] | None = None,
 ) -> bool:
     """Returns True if the value is a single element tensor with given value, and False otherwise."""
-    scalar = get_singleton_value(val)
+    scalar = get_singleton_value(val, rank=rank)
     if scalar is None:
         return False
     if callable(expected):
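A hedged usage sketch for the extended `rank` parameter. The construction of the constant-backed `ir.Value` here is illustrative (any single-element constant value would do):

```python
import numpy as np
from onnxscript import ir
from onnxscript.rewriter import _ir_utils

# A 1-D constant tensor holding a single element, wrapped in an ir.Value.
val = ir.Value(name="axes", const_value=ir.tensor(np.array([1], dtype=np.int64)))

print(_ir_utils.get_singleton_value(val))                  # 1 (any rank accepted)
print(_ir_utils.get_singleton_value(val, rank=0))          # None (value is 1-D, not a scalar)
print(_ir_utils.get_singleton_value(val, rank=(0, 1)))     # 1 (scalar or 1-D accepted)
print(_ir_utils.is_singleton_value(val, 1, rank=(0, 1)))   # True
```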

onnxscript/rewriter/rules/fusion/_rotary_embedding.py

Lines changed: 2 additions & 12 deletions

@@ -43,19 +43,9 @@ def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2):
     def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult:  # type: ignore[name-defined]
         check_result = pattern.MatchResult()

-        def is_one(val):
-            """Check if val is a 0/1 dimensional tensor with a single element equal to 1."""
-            np_val = _ir_utils.get_numpy_value(val)
-            return (
-                np_val is not None
-                and np_val.size == 1
-                and np_val.ndim <= 1
-                and np_val.item() == 1
-            )
-
-        if not is_one(one1):
+        if not _ir_utils.is_singleton_value(one1, 1):
             return check_result.fail("Unsqueeze axes is not [1]", one1)
-        if not is_one(one2):
+        if not _ir_utils.is_singleton_value(one2, 1):
             return check_result.fail("Unsqueeze axes is not [1]", one2)

         # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)
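If the original rank restriction from the removed local helper (ndim <= 1) is needed, the shared utility's `rank` parameter can express it. A hedged sketch, with `one1` standing for the matched Unsqueeze axes value inside `check()`:

```python
# Equivalent to the removed is_one(one1), including the ndim <= 1 restriction:
strict = _ir_utils.is_singleton_value(one1, 1, rank=(0, 1))

# The form used in check() above: any single-element constant equal to 1 matches.
loose = _ir_utils.is_singleton_value(one1, 1)
```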

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 4 additions & 0 deletions

@@ -1631,6 +1631,10 @@ def _where_input_wrangler(
         dtypes=(torch.float32 if sys.platform != "linux" else torch.complex64,),
         reason="fixme: test is unstable on macosx, windows",
     ),
+    TorchLibOpInfo("logical_and", core_ops.aten_logical_and),
+    TorchLibOpInfo("logical_not", core_ops.aten_logical_not),
+    TorchLibOpInfo("logical_or", core_ops.aten_logical_or),
+    TorchLibOpInfo("logical_xor", core_ops.aten_logical_xor),
     TorchLibOpInfo("logit", core_ops.aten_logit, tolerance={torch.float16: (1e-1, 7e-4)}),
     TorchLibOpInfo("max_dim", core_ops.aten_max_dim)
     .xfail(
