diff --git a/noxfile.py b/noxfile.py index 23c2963998..60c2bb901b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -29,9 +29,9 @@ "ml-dtypes", ) ONNX = "onnx==1.17" -ONNX_RUNTIME = "onnxruntime==1.20.1" -PYTORCH = "torch==2.5.1" -TORCHVISON = "torchvision==0.20.1" +ONNX_RUNTIME = "onnxruntime==1.23.0" +PYTORCH = "torch==2.7.1" +TORCHVISON = "torchvision==0.22.1" TRANSFORMERS = "transformers==4.37.2" ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = ( "flatbuffers", diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index 49eb398750..1f913ed897 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -84,6 +84,7 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True): ), skip(r"^test_ai_onnx_ml_label_encoder", "ONNX Runtime does not support Opset 21 at 1.17"), skip(r"^test_ai_onnx_ml_tree_ensemble", "Opset 23 is not supported"), + skip(r"^test_attention", "ONNX Runtime 1.23 fails on these tests"), ) if sys.platform == "win32": diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 9e6aa69edc..e837bfadae 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -18,21 +18,16 @@ import torch from onnxscript import ( - BFLOAT16, BOOL, COMPLEX64, COMPLEX128, DOUBLE, FLOAT, - FLOAT16, INT8, INT16, INT32, INT64, UINT8, - UINT16, - UINT32, - UINT64, graph, ir, ) @@ -77,13 +72,11 @@ def aten__local_scalar_dense(self: TensorType) -> TensorType: @torch_op("aten::_log_softmax", trace_only=True) -def aten__log_softmax_half( - self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool -) -> FLOAT: +def aten__log_softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision: """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" self_is_scalar = len(self.shape) == 0 - if half_to_float: + if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: self = op.Cast(self, to=FLOAT.dtype) if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) @@ -93,44 +86,23 @@ def aten__log_softmax_half( return result -@torch_op("aten::_log_softmax", trace_only=True) -def aten__log_softmax( - self: TFloatHighPrecision, - dim: int, - half_to_float: bool, -) -> TFloatHighPrecision: - """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" +@torch_op("aten::_softmax", trace_only=True) +def aten__softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision: + """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" self_is_scalar = len(self.shape) == 0 + + if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: + self = op.Cast(self, to=FLOAT.dtype) + if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - result = op.LogSoftmax(self, axis=dim) + result = op.Softmax(self, axis=dim) if self_is_scalar: + # Convert to scalar when input is scalar result = op.Squeeze(result) - return result - -@torch_op("aten::_softmax", trace_only=True) -def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT: - """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - - # trace_only because we need to cast conditionally based on half_to_float - if half_to_float: - self = op.Cast(self, to=FLOAT.dtype) - - return aten_softmax_no_dtype(self, dim) - - -@torch_op("aten::_softmax", trace_only=True) -def aten__softmax( - self: TFloatHighPrecision, dim: 
int, half_to_float: bool -) -> TFloatHighPrecision: - """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - - # trace_only to reuse aten_softmax_no_dtype - - del half_to_float # Unused - return aten_softmax_no_dtype(self, dim) + return result @torch_op(("aten::abs", "_operator::abs"), trace_only=True) @@ -380,7 +352,6 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) return self -@torch_op("aten::all.dims", trace_only=True) def _aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor""" @@ -499,7 +470,6 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) return self -@torch_op("aten::any.dims", trace_only=True) def _aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: if len(self.shape) == 0: result = op.Cast(self, to=BOOL.dtype) @@ -739,7 +709,6 @@ def aten_argmax( return result -@torch_op("aten::argmax", private=True, trace_only=True) def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -752,7 +721,6 @@ def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: return result -@torch_op("aten::argmax", private=True, trace_only=True) def _aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -780,7 +748,6 @@ def aten_argmin( return result -@torch_op("aten::argmin", private=True, trace_only=True) def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -793,7 +760,6 @@ def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: return result -@torch_op("aten::argmin", private=True, trace_only=True) def _aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmin(Tensor self, int? 
dim=None, bool keepdim=False) -> Tensor""" @@ -1282,78 +1248,30 @@ def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor: ), trace_only=True, ) -def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT16.dtype) - other = op.Cast(other, to=UINT16.dtype) - - result = op.BitShift(self, other, direction="LEFT") - - return op.Cast(result, to=INT16.dtype) - - -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32: +def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt: """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" # assert other >= 0 - self = op.Cast(self, to=UINT32.dtype) - other = op.Cast(other, to=UINT32.dtype) - - result = op.BitShift(self, other, direction="LEFT") - - return op.Cast(result, to=INT32.dtype) - - -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT64.dtype) - other = op.Cast(other, to=UINT64.dtype) - - result = op.BitShift(self, other, direction="LEFT") - - return op.Cast(result, to=INT64.dtype) - + if self.dtype.bitwidth == 8: + unsigned_dtype = ir.DataType.UINT8 + signed_dtype = ir.DataType.INT8 + elif self.dtype.bitwidth == 16: + unsigned_dtype = ir.DataType.UINT16 + signed_dtype = ir.DataType.INT16 + elif self.dtype.bitwidth == 32: + unsigned_dtype = ir.DataType.UINT32 + signed_dtype = ir.DataType.INT32 + elif self.dtype.bitwidth == 64: + unsigned_dtype = ir.DataType.UINT64 + signed_dtype = ir.DataType.INT64 + else: + raise NotImplementedError(f"Not implemented for type {self.dtype}") -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT8.dtype) - other = op.Cast(other, to=UINT8.dtype) + self = op.Cast(self, to=unsigned_dtype) + other = op.Cast(other, to=unsigned_dtype) result = op.BitShift(self, other, direction="LEFT") - return op.Cast(result, to=INT8.dtype) + return op.Cast(result, to=signed_dtype) @torch_op("aten::bitwise_not", trace_only=True) @@ -1395,115 +1313,37 @@ def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor: "aten::bitwise_right_shift.Scalar_Tensor", "_operator::__rshift__", "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT16.dtype) - other = op.Cast(other, to=UINT16.dtype) - - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = 
op.BitShift( - op.Cast(op.Constant(value_int=0xFFFF), to=UINT16.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT16.dtype), op.Cast(shifted, to=INT16.dtype) - ) - - -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT32.dtype) - other = op.Cast(other, to=UINT32.dtype) - - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFFFFFFFF), to=UINT32.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT32.dtype), op.Cast(shifted, to=INT32.dtype) - ) - - -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) + ), + trace_only=True, ) -def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: +def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT64.dtype) - other = op.Cast(other, to=UINT64.dtype) - - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - # 0xFFFFFFFFFFFFFFFF - op.Cast(op.Constant(value_int=-1), to=UINT64.dtype), - other, - direction="RIGHT", - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT64.dtype), op.Cast(shifted, to=INT64.dtype) - ) - + if self.dtype.bitwidth == 8: + unsigned_dtype = ir.DataType.UINT8 + signed_dtype = ir.DataType.INT8 + mask = ir.tensor(0xFF, dtype=unsigned_dtype) + elif self.dtype.bitwidth == 16: + unsigned_dtype = ir.DataType.UINT16 + signed_dtype = ir.DataType.INT16 + mask = ir.tensor(0xFFFF, dtype=unsigned_dtype) + elif self.dtype.bitwidth == 32: + unsigned_dtype = ir.DataType.UINT32 + signed_dtype = ir.DataType.INT32 + mask = ir.tensor(0xFFFFFFFF, dtype=unsigned_dtype) + elif self.dtype.bitwidth == 64: + unsigned_dtype = ir.DataType.UINT64 + signed_dtype = ir.DataType.INT64 + mask = ir.tensor(0xFFFFFFFFFFFFFFFF, dtype=unsigned_dtype) # 0xFFFFFFFFFFFFFFFF + else: + raise NotImplementedError(f"Not implemented for type 
{self.dtype}") -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" negative = op.Less(self, 0) - self = op.Cast(self, to=UINT8.dtype) - other = op.Cast(other, to=UINT8.dtype) + self = op.Cast(self, to=unsigned_dtype) + other = op.Cast(other, to=unsigned_dtype) # Simulate arithmetic shift using logical shift # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFF), to=UINT8.dtype), other, direction="RIGHT" - ) + mask = op.BitShift(mask, other, direction="RIGHT") mask = op.BitwiseNot(mask) # Do logical shift shifted = op.BitShift(self, other, direction="RIGHT") @@ -1511,7 +1351,7 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: negative_shifted = op.BitwiseOr(shifted, mask) # Choose the shifted value based on the sign bit return op.Where( - negative, op.Cast(negative_shifted, to=INT8.dtype), op.Cast(shifted, to=INT8.dtype) + negative, op.Cast(negative_shifted, to=signed_dtype), op.Cast(shifted, to=signed_dtype) ) @@ -2173,7 +2013,6 @@ def aten_convolution( return result -@torch_op("aten::convolution", private=True, trace_only=True) def _aten_convolution_onnx( input: TFloat, weight: TFloat, @@ -2645,80 +2484,10 @@ def aten_diagflat(self: TensorType, offset: int = 0) -> TensorType: @torch_op(("aten::diagonal", "aten::diagonal_copy"), trace_only=True) -def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TReal: +def aten_diagonal(self: TTensor, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TTensor: """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" - # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims - # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 - # [0,1,2] -> [1,0,2] when dim1=0 and dim2=2 - # [0,1,2] -> [0,1,2] when dim1=1 and dim2=2 - if dim1 < 0: - dim1 = dim1 + len(self.shape) - if dim2 < 0: - dim2 = dim2 + len(self.shape) - - self_rank = len(self.shape) - perm = list(range(self_rank)) - perm.remove(dim1) - perm.remove(dim2) - perm.append(dim1) - perm.append(dim2) - - # If rank=2, then axes=[0]; if rank=3, then axes=[1] - # This is because computing diagonal sum is on dim2 after transpose by perm - axes = [self_rank - 2] - - neg_1 = op.Constant(value_ints=[-1]) - dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row - dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col - mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) - mask = op.CastLike(mask, self) - self_t = op.Transpose(self, perm=perm) - result = op.Mul(self_t, mask) - result = op.ReduceSum(result, keepdims=False, axes=axes) - # min(row, col) - min_dim_size = op.Min(dim1_size, dim2_size) - # take 2 tensors as example: - # one is 3x5 in size, min_dim_size = 3, dim1_size = 3 - # the other is 5x3 in size, min_dim_size = 3, dim1_size = 5 - # 3 rows x 5 cols 5 rows x 3 cols - # offset diagonal offset diagonal - # ---------------- ---------------- - # -4 0 -6 0 - # -3 0 -5 0 - # -2 1 -4 1 - # -1 2 -3 2 - # 0 3 -2 3 - # 1 3 -1 3 - # 2 3 0 3 - # 3 2 1 2 - # 4 1 2 1 - # 5 0 3 0 - # 6 0 4 0 - - # From above table, we can get the logic below - 
offset_val = op.Constant(value_ints=[offset]) - if offset < 0: - # row + offset - length = op.Add(dim1_size, offset_val) - start = op.Constant(value_ints=[0]) - else: # offset >= 0 - # col - offset - length = op.Sub(dim2_size, offset_val) - start = offset_val - - # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) - end = op.Add(start, length) - result = op.Slice(result, start, end, axes=axes) - - return result - - -@torch_op("aten::diagonal", trace_only=True) -def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1) -> BOOL: - """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" + is_bool = self.dtype == BOOL.dtype # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 @@ -2745,10 +2514,16 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) - self_int = op.Cast(self, to=INT64.dtype) - mask_int = op.Cast(mask, to=INT64.dtype) - self_int_t = op.Transpose(self_int, perm=perm) - result = op.Mul(self_int_t, mask_int) + + if is_bool: + self_int = op.Cast(self, to=INT64.dtype) + mask_int = op.Cast(mask, to=INT64.dtype) + self_int_t = op.Transpose(self_int, perm=perm) + result = op.Mul(self_int_t, mask_int) + else: + mask = op.CastLike(mask, self) + self_t = op.Transpose(self, perm=perm) + result = op.Mul(self_t, mask) result = op.ReduceSum(result, keepdims=False, axes=axes) # min(row, col) min_dim_size = op.Min(dim1_size, dim2_size) @@ -2785,7 +2560,9 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) - result = op.Cast(result, to=BOOL.dtype) + + if is_bool: + result = op.Cast(result, to=BOOL.dtype) return result @@ -2896,45 +2673,37 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat: @torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: Optional[str] = None) -> TFloat: +def aten_div_mode(self: TReal, other: TReal, rounding_mode: Optional[str] = None) -> TReal: """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor""" assert rounding_mode in {"trunc", "floor", None} - if rounding_mode == "trunc": - # Rounds the results of the division towards zero. - # Equivalent to C-style integer division - return aten_trunc(op.Div(self, other)) - if rounding_mode == "floor": - return op.Floor(op.Div(self, other)) - - return op.Div(self, other) - + if self.dtype.is_integer(): + quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) -@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode_int( - self: TInt, other: TInt, rounding_mode: Optional[str] = None -) -> TensorType: - """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + if rounding_mode == "trunc": + # Rounds the results of the division towards zero. 
+ # Equivalent to C-style integer division + result = aten_trunc(quotient) + return op.CastLike(result, self) + if rounding_mode == "floor": + result = op.Floor(quotient) + return op.CastLike(result, self) - Variant for integer inputs. - """ - assert rounding_mode in {"trunc", "floor", None} + assert rounding_mode is None + # When rounding_mode is None, the return type is float32 + return quotient - quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) + # Float inputs if rounding_mode == "trunc": # Rounds the results of the division towards zero. # Equivalent to C-style integer division - result = aten_trunc(quotient) - return op.CastLike(result, self) + return aten_trunc(op.Div(self, other)) if rounding_mode == "floor": - result = op.Floor(quotient) - return op.CastLike(result, self) + return op.Floor(op.Div(self, other)) - assert rounding_mode is None - # When rounding_mode is None, the return type is float32 - return quotient + return op.Div(self, other) @torch_op("aten::dot", trace_only=True) @@ -3888,26 +3657,18 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), trace_only=True, ) -def aten_ge(self: TReal, other: TReal) -> BOOL: - """ge.Tensor(Tensor self, Tensor other) -> Tensor""" - - return op.GreaterOrEqual(self, other) - - -@torch_op( - ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), - trace_only=True, -) -def aten_ge_bool(self: BOOL, other: BOOL) -> BOOL: +def aten_ge(self: TTensor, other: TTensor) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" - # self, other, self >= other - # F, F, T - # F, T, F - # T, F, T - # T, T, T + if self.dtype == ir.DataType.BOOL: + # self, other, self >= other + # F, F, T + # F, T, F + # T, F, T + # T, T, T + return op.Or(self, op.Not(other)) - return op.Or(self, op.Not(other)) + return op.GreaterOrEqual(self, other) def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]: @@ -4036,25 +3797,19 @@ def aten_gru_cell( ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), trace_only=True, ) -def aten_gt(self: TReal, other: TReal) -> BOOL: +def aten_gt(self: TTensor, other: TTensor) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Greater(self, other) - + if self.dtype == ir.DataType.BOOL: + # self, other, self > other + # F, F, F + # F, T, F + # T, F, T + # T, T, F -@torch_op( - ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), - trace_only=True, -) -def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL: - """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - # self, other, self > other - # F, F, F - # F, T, F - # T, F, T - # T, T, F + return op.And(self, op.Not(other)) - return op.And(self, op.Not(other)) + return op.Greater(self, other) @torch_op("aten::hamming_window", trace_only=True) @@ -4875,26 +4630,19 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), trace_only=True, ) -def aten_le(self: TReal, other: TReal) -> BOOL: +def aten_le(self: TTensor, other: TTensor) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.LessOrEqual(self, other) - - -@torch_op( - ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), - trace_only=True, -) -def aten_le_bool(self: BOOL, other: BOOL) -> BOOL: - """le.Tensor(Tensor self, Tensor other) -> 
Tensor""" + if self.dtype == ir.DataType.BOOL: + # self, other, self <= other + # F, F, T + # F, T, T + # T, F, F + # T, T, T - # self, other, self <= other - # F, F, T - # F, T, T - # T, F, F - # T, T, T + return op.Or(other, op.Not(self)) - return op.Or(other, op.Not(self)) + return op.LessOrEqual(self, other) @torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar")) @@ -5096,29 +4844,23 @@ def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL: return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op("aten::logit", private=True) -def _aten_logit_onnx(self: TFloat) -> TFloat: - return op.Log(op.Div(self, op.Sub(1.0, self))) +@torch_op("aten::logit", trace_only=True) +def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: + """logit(Tensor self, float? eps=None) -> Tensor""" + one = ir.tensor(1, dtype=self.dtype) + + if eps is None: + return op.Log(op.Div(self, op.Sub(one, self))) + one_minus_eps = ir.tensor(1 - eps, dtype=self.dtype) + eps = ir.tensor(eps, dtype=self.dtype) -@torch_op("aten::logit", private=True) -def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat: - eps = op.CastLike(eps, self) - one = op.CastLike(1.0, self) - temporary_self = op.Where(self <= one - eps, self, one - eps) + temporary_self = op.Where(self <= one_minus_eps, self, one_minus_eps) z = op.Where(temporary_self < eps, eps, temporary_self) return op.Log(op.Div(z, op.Sub(one, z))) -@torch_op("aten::logit", trace_only=True) -def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: - """logit(Tensor self, float? eps=None) -> Tensor""" - if eps is None: - return _aten_logit_onnx(self) - return _aten_logit_clamp_onnx(self, eps) - - def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> TensorType: """logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" @@ -5175,26 +4917,18 @@ def aten_lstm_mps_backward( ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), trace_only=True, ) -def aten_lt(self: TReal, other: TReal) -> BOOL: - """lt.Tensor(Tensor self, Tensor other) -> Tensor""" - - return op.Less(self, other) - - -@torch_op( - ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), - trace_only=True, -) -def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL: +def aten_lt(self: TTensor, other: TTensor) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" - # self, other, self < other - # F, F, F - # F, T, T - # T, F, F - # T, T, F + if self.dtype == ir.DataType.BOOL: + # self, other, self < other + # F, F, F + # F, T, T + # T, F, F + # T, T, F + return op.And(other, op.Not(self)) - return op.And(other, op.Not(self)) + return op.Less(self, other) def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: @@ -5368,18 +5102,14 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I return result, indices -@torch_op("aten::maximum") -def aten_maximum(self: TReal, other: TReal) -> TReal: +@torch_op("aten::maximum", trace_only=True) +def aten_maximum(self: TTensor, other: TTensor) -> TTensor: """maximum(Tensor self, Tensor other) -> Tensor""" - return op.Max(self, other) - - -@torch_op("aten::maximum") -def aten_maximum_bool(self: BOOL, other: BOOL) -> BOOL: - """maximum(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) - return op.Or(self, other) + return op.Max(self, other) @torch_op("aten::mean") @@ -5414,7 +5144,7 @@ def aten_meshgrid(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::min") +@torch_op("aten::min", trace_only=True) def aten_min(self: TReal) -> TReal: """min(Tensor self) -> Tensor""" @@ -5435,18 +5165,14 @@ def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, T return result, indices -@torch_op("aten::minimum") -def aten_minimum(self: TReal, other: TReal) -> TReal: +@torch_op("aten::minimum", trace_only=True) +def aten_minimum(self: TTensor, other: TTensor) -> TTensor: """minimum(Tensor self, Tensor other) -> Tensor""" - return op.Min(self, other) - - -@torch_op("aten::minimum") -def aten_minimum_bool(self: BOOL, other: BOOL) -> BOOL: - """minimum(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) - return op.And(self, other) + return op.Min(self, other) def aten_miopen_batch_norm( @@ -5789,23 +5515,13 @@ def aten_msort(self: TensorType) -> TensorType: ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"), trace_only=True, ) -def aten_mul(self: TReal, other: TReal) -> TReal: +def aten_mul(self: TTensor, other: TTensor) -> TTensor: """mul.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Mul(self, other) - - -@torch_op( - ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"), - trace_only=True, -) -def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: - """ONNX Mul doesn't support Boolean, so use And as an equivalent operator.""" - - # TODO(justinchuby): Handle cases where type reconcilation is not enough, - # since different ONNX operators are used based on different data types. 
+ if self.dtype == ir.DataType.BOOL: + return op.And(self, other) - return op.And(self, other) + return op.Mul(self, other) @torch_op( @@ -6047,7 +5763,6 @@ def aten_native_batch_norm( return norm, input_mean, input_rstd -@torch_op("aten::native_batch_norm", private=True) def _aten_native_batch_norm_training_onnx( input: TFloat, weight: TFloat, @@ -6099,7 +5814,6 @@ def _aten_native_batch_norm_training_onnx( return norm, mean, rstd, running_mean, new_running_var -@torch_op("aten::native_batch_norm", private=True) def _aten_native_batch_norm_inference_onnx( input: TFloat, weight: TFloat, @@ -6269,22 +5983,10 @@ def aten_native_group_norm( if bias is None: # Set to 0.0 as default, the shape is Channel size bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2)) - # Accoding to Torch, return rstd instead of var - norm, mean, rstd = _aten_native_group_norm_onnx(input, weight, bias, group, eps) - return norm, mean, rstd - - -@torch_op("aten::native_group_norm", private=True) -def _aten_native_group_norm_onnx( - input: TFloat, - weight: TFloat, - bias: TFloat, - group: INT64, - eps: float, -) -> Tuple[TFloat, TFloat, TFloat]: # Because onnx.GroupNorm() need size=group for weight and bias # But the torch's aten function's input need size=channel, the size mismatched # So we have to use onnx.InstanceNorm() to simulate + # This implementation should be simplified after opset 21 neg_1 = op.Constant(value_ints=[-1]) # Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter group_tensor = op.Reshape(group, neg_1) @@ -6321,7 +6023,9 @@ def _aten_native_group_norm_onnx( sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean) # In Pytorch, vstd = 1/(sqrt(var + eps)) var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False) - rstd = op.Div(1.0, op.Sqrt(var + eps)) + eps = op.Constant(value=ir.tensor(eps, dtype=input.dtype)) + one = op.Constant(value=ir.tensor(1.0, dtype=input.dtype)) + rstd = op.Div(one, op.Sqrt(op.Add(var, eps))) # Get the correct shape [N, group] for mean again mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=False) return norm_result, mean, rstd @@ -6533,16 +6237,7 @@ def aten_norm_except_dim(v: TensorType, pow: int = 2, dim: int = 0) -> TensorTyp raise NotImplementedError() -@torch_op( - ( - "aten::normal.Tensor_float", - "aten::normal.Tensor_Tensor", - "aten::normal.float_Tensor", - "aten::normal.float_float", - "aten::normal_functional", - ), - trace_only=True, -) +@torch_op("aten::normal_functional", trace_only=True) def aten_normal( self: TTensor, mean: float = 0.0, @@ -6571,7 +6266,7 @@ def aten_normal_float_float( return op.Cast(result, to=dtype) -@torch_op("aten::normal.float_Tensor") +@torch_op("aten::normal.float_Tensor", trace_only=True) def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: """normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -6581,7 +6276,7 @@ def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: return op.Add(op.Mul(std, sampled), mean_casted) -@torch_op("aten::normal.Tensor_float") +@torch_op("aten::normal.Tensor_float", trace_only=True) def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: """normal.Tensor_float(Tensor mean, float std=1, *, Generator? 
generator=None) -> Tensor""" @@ -6590,7 +6285,7 @@ def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: return op.Add(op.Mul(op.CastLike(std, sampled), sampled), mean) -@torch_op("aten::normal.Tensor_Tensor") +@torch_op("aten::normal.Tensor_Tensor", trace_only=True) def aten_normal_tensor_tensor(mean: TFloat, std: TFloat) -> TFloat: """normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -7298,10 +6993,15 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True) -def aten_remainder(self: TFloat, other: TFloat) -> TFloat: +@torch_op( + ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True +) +def aten_remainder(self: TTensor, other: TTensor) -> TTensor: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" + if self.dtype.is_integer(): + return op.Mod(self, other) + # TODO(justinchuby): Improve fp16 precision by following the logic in # https://github.com/pytorch/pytorch/blob/3a823e46170778cc32783f27596c77d0103084a9/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L264-L277 @@ -7311,15 +7011,6 @@ def aten_remainder(self: TFloat, other: TFloat) -> TFloat: return op.Sub(self, op.Mul(rounded_quotient, other)) -@torch_op( - ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True -) -def aten_remainder_int(self: TInt, other: TInt) -> TInt: - """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" - - return op.Mod(self, other) - - def aten_rename(self: TensorType, names: Optional[str]) -> TensorType: """rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)""" @@ -7538,23 +7229,29 @@ def aten_rnn_tanh_cell( def aten_roll(self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" + if isinstance(shifts, int): + shifts = [shifts] + + if isinstance(dims, int): + dims = [dims] + self_rank = len(self.shape) if self_rank == 0: return op.Identity(self) elif self.shape[0] == 0: # empty tensor return op.Identity(self) + + # NOTE: In pytorch, default value of dims is an empty list. + if len(dims) == 0: # Empty sequence + assert len(shifts) == 1, "shifts should be a single integer if dims is empty" + return _aten_roll_shift_no_dim_onnx(self, shifts[0]) else: - # NOTE: In pytorch, default value of dims is an empty list. 
- if len(dims) == 0: # Empty sequence - # assert isinstance(shifts, int) - return _aten_roll_shift_no_dim_onnx(self, shifts) - else: - # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list - result = self - for i, shift in enumerate(shifts): - dim = dims[i] - result = _aten_roll_shift_and_dim_onnx(result, shift, dim) - return result + assert len(shifts) == len(dims) + result = self + for i, shift in enumerate(shifts): + dim = dims[i] + result = _aten_roll_shift_and_dim_onnx(result, shift, dim) + return result @torch_op("aten::roll", trace_only=True, complex=True) @@ -7563,6 +7260,12 @@ def aten_roll_complex( ) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" + if isinstance(shifts, int): + shifts = [shifts] + + if isinstance(dims, int): + dims = [dims] + self_rank = len(self.shape) if self_rank == 1: return op.Identity(self) @@ -7573,37 +7276,34 @@ def aten_roll_complex( self_real = op.Slice(self, [0], [1], axes=[-1]) self_imag = op.Slice(self, [1], [2], axes=[-1]) if not dims: - # assert isinstance(shifts, int) - shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts) - shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts) + assert len(shifts) == 1, "shifts should be a single integer if dims is empty" + shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts[0]) + shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts[0]) result = op.Concat(shift_real, shift_imag, axis=-1) else: - # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list + assert len(shifts) == len(dims) for i, dim in enumerate(dims): - shift = op.Gather(shifts, i, axis=0) - self_real = _aten_roll_shift_and_dim_onnx(self_real, shift, dim) - self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shift, dim) + self_real = _aten_roll_shift_and_dim_onnx(self_real, shifts[i], dim) + self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shifts[i], dim) result = op.Concat(self_real, self_imag, axis=-1) return result -@torch_op("aten::roll", private=True) -def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: +def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: int) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) # flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D] self_flatten = op.Reshape(self, neg_1) # Compute slice length - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: + if shift < 0: # For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end - slice_length = -shift_tensor + slice_length = op.Constant(value_ints=[-shift]) else: # For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end # The effect equals to move [D] to the beginning - slice_length = op.Size(self_flatten) - shift_tensor + slice_length = op.Size(self_flatten) - op.Constant(value_ints=[shift]) # Get second part of the tensor, e.g. [A,B,C] suffix = op.Slice(self_flatten, op.Constant(value_ints=[0]), slice_length) # Get first part of the tensor, e.g. 
[D] @@ -7613,15 +7313,13 @@ def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: return op.Reshape(result, op.Shape(self)) -@torch_op("aten::roll", private=True) -def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: INT64, dim: int) -> TTensor: +def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: int, dim: int) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) - dim_tensor = op.Reshape(op.Constant(value_int=dim), neg_1) - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: - slice_length = -shift_tensor + dim_tensor = op.Constant(value_ints=[dim]) + if shift < 0: + slice_length = op.Constant(value_ints=[-shift]) else: - slice_length = op.Gather(op.Shape(self), dim_tensor, axis=0) - shift_tensor + slice_length = op.Shape(self, start=dim, end=dim + 1) - op.Constant(value_ints=[shift]) # from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor) prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor) @@ -7700,7 +7398,7 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: @torch_op("aten::scalar_tensor", trace_only=True) def aten_scalar_tensor( - s: float, + s: TensorType, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", @@ -7709,8 +7407,7 @@ def aten_scalar_tensor( """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - # Set trace_only=True because different if branches return different dtypes - # which is not supported in an ONNX function + return common_ops.cast_to(s, dtype=dtype) @@ -7739,20 +7436,6 @@ def aten_scalar_tensor_complex( return result -@torch_op("aten::scalar_tensor", trace_only=True) -def aten_scalar_tensor_sym_number( - s: TensorType, - dtype: int = FLOAT.dtype, - layout: str = "", - device: str = "", - pin_memory: bool = False, -) -> RealType: - """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - if dtype == -1: - dtype = FLOAT.dtype - return common_ops.cast_to(s, dtype=dtype) - - @torch_op("aten::scatter.src", trace_only=True) def aten_scatter_src( self: TTensor, @@ -8140,7 +7823,7 @@ def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.Softmax(self, axis=dim) - if dtype != -1: + if dtype != -1 and dtype is not None: result = op.Cast(result, to=dtype) if self_is_scalar: # Convert to scalar when input is scalar @@ -8149,21 +7832,6 @@ def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: return result -@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True) -def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat: - """softmax(Tensor self, int dim, ScalarType? 
dtype=None) -> Tensor""" - - self_is_scalar = len(self.shape) == 0 - if self_is_scalar: - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - result = op.Softmax(self, axis=dim) - if self_is_scalar: - # Convert to scalar when input is scalar - result = op.Squeeze(result) - - return result - - @torch_op("aten::sort", trace_only=True) def aten_sort( self: TReal, dim: int = -1, descending: bool = False, stable: bool = False diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 1a31c9eac8..2a7a46ec28 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -294,20 +294,16 @@ def aten_binary_cross_entropy_backward( @torch_op("aten::celu", trace_only=True) -def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT: +def aten_celu(self: TFloat, alpha: float = 1.0) -> TFloat: """celu(Tensor self, Scalar alpha=1.0) -> Tensor""" - return op.Celu(self, alpha=alpha) # op.Celu only support float32 + if self.dtype != FLOAT.dtype: + self_upcasted = op.Cast(self, to=FLOAT.dtype) + # op.Celu only support float32 + return op.Cast(op.Celu(self_upcasted, alpha=alpha), to=self.dtype) -@torch_op("aten::celu", trace_only=True) -def aten_celu_type_promoted( - self: TFloatUnlessFloat32, alpha: float = 1.0 -) -> TFloatUnlessFloat32: - """celu(Tensor self, Scalar alpha=1.0) -> Tensor""" - - self_upcasted = op.Cast(self, to=FLOAT.dtype) - return op.CastLike(op.Celu(self_upcasted, alpha=alpha), self) + return op.Celu(self, alpha=alpha) @torch_op("aten::col2im", trace_only=True) @@ -1804,7 +1800,7 @@ def aten_scaled_dot_product_attention( query: TFloat, key: TFloat, value: TFloat, - attn_mask: Optional[TFloat] = None, + attn_mask: Optional[TensorType] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -1854,6 +1850,11 @@ def aten_scaled_dot_product_attention( query, key, value, scale, dropout_p ) + if attn_mask.dtype == ir.DataType.BOOL: + return _aten_scaled_dot_product_attention_bool_mask_onnx( + query, key, value, attn_mask, scale, dropout_p + ) + return _aten_scaled_dot_product_attention_float_mask_onnx( query, key, value, attn_mask, scale, dropout_p ) @@ -1921,7 +1922,6 @@ def aten__scaled_dot_product_flash_attention( ) -@torch_op("aten::_scaled_dot_product_efficient_attention", private=True) def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs( query: TFloat, compute_log_sumexp: bool, @@ -2016,64 +2016,6 @@ def aten__scaled_dot_product_efficient_attention( ) -@torch_op("aten::scaled_dot_product_attention", trace_only=True) -def aten_scaled_dot_product_attention_bool_mask( - query: TFloat, - key: TFloat, - value: TFloat, - attn_mask: Optional[BOOL] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, -) -> TFloat: - """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? 
scale=None, bool enable_gqa=False) -> Tensor - - Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - - Equivalent to the PyTorch code:: - scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale - attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask - attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask - attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1) - attn_weight = torch.dropout(attn_weight, dropout_p) - return attn_weight @ V - - where Q, K, V are the query, key, and value tensors, respectively. - L is the target sequence length, S is the source sequence length, and E is the embedding size. - """ - # Use trace_only to handle optional inputs - assert (not is_causal) or (is_causal and attn_mask is None), ( - "is_causal and attn_mask cannot be set at the same time" - ) - assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, ( - "only 4D query, key, and value are supported" - ) - - if scale is None: - scale = _attention_scale(query) - scale = op.CastLike(scale, query) - - if is_causal: - attn_mask = _causal_attention_mask(query, key) - # The causal mask is always float - return _aten_scaled_dot_product_attention_float_mask_onnx( - query, key, value, attn_mask, scale, dropout_p - ) - - if enable_gqa: - key, value = _attention_repeat_kv_for_group_query(query, key, value) - - if attn_mask is None: - return _aten_scaled_dot_product_attention_no_mask_onnx( - query, key, value, scale, dropout_p - ) - - return _aten_scaled_dot_product_attention_bool_mask_onnx( - query, key, value, attn_mask, scale, dropout_p - ) - - def _aten_scaled_dot_product_attention_no_mask_onnx( query: TFloat, key: TFloat, diff --git a/requirements/ci/requirements-ort-nightly.txt b/requirements/ci/requirements-ort-nightly.txt index 4ed908b4e2..b54550738b 100644 --- a/requirements/ci/requirements-ort-nightly.txt +++ b/requirements/ci/requirements-ort-nightly.txt @@ -1,3 +1,3 @@ # https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/onnxruntime/overview --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ -onnxruntime==1.23.0.dev20250517001 +onnxruntime==1.23.0.dev20251001001 diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index 45875043ea..a45050fb22 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -99,7 +99,7 @@ def _should_skip_xfail_test_sample( class TestFunctionValidity(unittest.TestCase): @parameterized.parameterized.expand( - [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] + [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_script_function_passes_checker( self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo @@ -110,10 +110,12 @@ def test_script_function_passes_checker( onnx.checker.check_function(function_proto) # type: ignore[attr-defined] @parameterized.parameterized.expand( - [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] + [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_function_has_op_schema(self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo): func = torchlib_op_info.op + if not hasattr(func, "op_schema"): + raise AssertionError(f"Function {func.__name__} does not have op_schema 
attribute") schema = func.op_schema self.assertIsNotNone(schema) self.assertEqual(schema.name, func.name) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index c8d0bf5786..b60fd8cf31 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -48,7 +48,6 @@ from torch.testing._internal.opinfo import definitions as opinfo_definitions from typing_extensions import Self -from onnxscript._internal import version_utils from onnxscript.function_libs.torch_lib import _flags from onnxscript.function_libs.torch_lib.ops import core as core_ops from onnxscript.function_libs.torch_lib.ops import fft as fft_ops @@ -459,40 +458,13 @@ def _where_input_wrangler( fft_ops.aten__fft_r2c, tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)}, ), + TorchLibOpInfo("ops.aten._local_scalar_dense", core_ops.aten__local_scalar_dense), TorchLibOpInfo( - "ops.aten._local_scalar_dense", - core_ops.aten__local_scalar_dense, - ), - TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax), - TorchLibOpInfo( - "ops.aten._log_softmax_half", - core_ops.aten__log_softmax_half, + "ops.aten._log_softmax", + core_ops.aten__log_softmax, tolerance={torch.float16: (1e-3, 1e-3)}, - ) - .xfail( - reason="PyTorch does not implement _log_softmax for float16 on CPU", - dtypes=(torch.float16,), - enabled_if=version_utils.torch_older_than("2.2"), - ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.17"), - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", ), TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax), - TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half) - .xfail( - reason="PyTorch does not implement _softmax for float16 on CPU", - dtypes=(torch.float16,), - enabled_if=version_utils.torch_older_than("2.2"), - ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.17"), - dtypes=(torch.float16,), - reason="fixme: ORT failed. 
https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ), TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), @@ -503,10 +475,7 @@ def _where_input_wrangler( reason="this overload requires dim to be a tuple", ), TorchLibOpInfo("allclose", core_ops.aten_allclose), - TorchLibOpInfo( - "all", - core_ops.aten_all, - ).skip( + TorchLibOpInfo("all", core_ops.aten_all).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -541,32 +510,14 @@ def _where_input_wrangler( reason="zero sized inputs cannot be compared", ), TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (2e-3, 2e-2)}), - TorchLibOpInfo( - "addr", - core_ops.aten_addr, - tolerance={torch.float16: (3e-3, 4e-3)}, - ), - TorchLibOpInfo( - "amax", - core_ops.aten_amax, - input_wrangler=_amin_amax_input_wrangler, - ), - TorchLibOpInfo( - "amin", - core_ops.aten_amin, - input_wrangler=_amin_amax_input_wrangler, - ), - TorchLibOpInfo( - "any", - core_ops.aten_any, - ).skip( + TorchLibOpInfo("addr", core_ops.aten_addr, tolerance={torch.float16: (3e-3, 4e-3)}), + TorchLibOpInfo("amax", core_ops.aten_amax, input_wrangler=_amin_amax_input_wrangler), + TorchLibOpInfo("amin", core_ops.aten_amin, input_wrangler=_amin_amax_input_wrangler), + TorchLibOpInfo("any", core_ops.aten_any).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), - TorchLibOpInfo( - "any_dim", - core_ops.aten_any_dim, - ).skip( + TorchLibOpInfo("any_dim", core_ops.aten_any_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer", @@ -584,76 +535,46 @@ def _where_input_wrangler( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_1d_Sequence", - core_ops.aten_atleast_1d_sequence, - ) + TorchLibOpInfo("atleast_1d_Sequence", core_ops.aten_atleast_1d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("atleast_2d", core_ops.aten_atleast_2d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_2d_Sequence", - core_ops.aten_atleast_2d_sequence, - ) + TorchLibOpInfo("atleast_2d_Sequence", core_ops.aten_atleast_2d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." 
"https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("atleast_3d", core_ops.aten_atleast_3d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_3d_Sequence", - core_ops.aten_atleast_3d_sequence, - ) + TorchLibOpInfo("atleast_3d_Sequence", core_ops.aten_atleast_3d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}), TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True), @@ -671,16 +592,10 @@ def _where_input_wrangler( ), TorchLibOpInfo("ops.aten.bernoulli.p_deterministic", core_ops.aten_bernoulli_p), TorchLibOpInfo("bitwise_and", core_ops.aten_bitwise_and), - TorchLibOpInfo("bitwise_left_shift_int16", core_ops.aten_bitwise_left_shift_int16), - TorchLibOpInfo("bitwise_left_shift_int32", core_ops.aten_bitwise_left_shift_int32), - TorchLibOpInfo("bitwise_left_shift_int64", core_ops.aten_bitwise_left_shift_int64), - TorchLibOpInfo("bitwise_left_shift_int8", core_ops.aten_bitwise_left_shift_int8), + TorchLibOpInfo("bitwise_left_shift", core_ops.aten_bitwise_left_shift), TorchLibOpInfo("bitwise_not", core_ops.aten_bitwise_not), TorchLibOpInfo("bitwise_or", core_ops.aten_bitwise_or), - TorchLibOpInfo("bitwise_right_shift_int16", core_ops.aten_bitwise_right_shift_int16), - TorchLibOpInfo("bitwise_right_shift_int32", core_ops.aten_bitwise_right_shift_int32), - TorchLibOpInfo("bitwise_right_shift_int64", core_ops.aten_bitwise_right_shift_int64), - TorchLibOpInfo("bitwise_right_shift_int8", core_ops.aten_bitwise_right_shift_int8), + TorchLibOpInfo("bitwise_right_shift", core_ops.aten_bitwise_right_shift), TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor), TorchLibOpInfo("ops.aten.blackman_window", core_ops.aten_blackman_window), TorchLibOpInfo("bmm", core_ops.aten_bmm), @@ -698,10 +613,7 @@ def _where_input_wrangler( reason="fixme: ORT aborts with zero-dim tensors. 
https://github.com/microsoft/onnxruntime/issues/16619", ), TorchLibOpInfo("ceil", core_ops.aten_ceil), - TorchLibOpInfo("chunk", core_ops.aten_chunk).skip( - enabled_if=version_utils.torch_older_than("2.7"), - reason="Test for chunk is not configured for torch<2.7", - ), + TorchLibOpInfo("chunk", core_ops.aten_chunk), TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip( reason="Size 0 inputs are not handled by design", matcher=lambda sample: sample.input.numel() == 0, @@ -737,7 +649,6 @@ def _where_input_wrangler( TorchLibOpInfo("deg2rad", core_ops.aten_deg2rad), # TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB TorchLibOpInfo("diagonal", core_ops.aten_diagonal), - TorchLibOpInfo("diagonal_bool", core_ops.aten_diagonal_bool), TorchLibOpInfo("div", core_ops.aten_div).skip( matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="this variation does not take the rounding_mode argument", @@ -755,7 +666,6 @@ def _where_input_wrangler( # Numbers match sometimes but not other times reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", ), - TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( "empty", @@ -765,8 +675,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("einsum", core_ops.aten_einsum, input_wrangler=_einsum_input_wrangler) .xfail( - reason="fixme: PyTorch produces int64 output with int32 input", - dtypes=(torch.int32,), + reason="fixme: PyTorch produces int64 output with int32 input", dtypes=(torch.int32,) ) .xfail( reason="fixme: ONNX shape inference fails: https://github.com/onnx/onnx/issues/5739", @@ -800,21 +709,15 @@ def _where_input_wrangler( TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("frac", core_ops.aten_frac), TorchLibOpInfo("full", core_ops.aten_full), - TorchLibOpInfo( - "full_like", - core_ops.aten_full_like, - ).skip( - enabled_if=ops_test_common.IS_MACOS, - reason="fixme: memory allocation issue on CI", + TorchLibOpInfo("full_like", core_ops.aten_full_like).skip( + enabled_if=ops_test_common.IS_MACOS, reason="fixme: memory allocation issue on CI" ), TorchLibOpInfo("gather", core_ops.aten_gather).skip( matcher=lambda sample: sample.input.numel() == 0 or sample.args[1].numel() == 0, reason="fixme: ORT does not support empty tensors as input", ), TorchLibOpInfo("ge", core_ops.aten_ge), - TorchLibOpInfo("ge_bool", core_ops.aten_ge_bool), TorchLibOpInfo("gt", core_ops.aten_gt), - TorchLibOpInfo("gt_bool", core_ops.aten_gt_bool), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), @@ -828,9 +731,7 @@ def _where_input_wrangler( reason="this Aten overload only supports tensor(bool) as indices", ), TorchLibOpInfo( - "index_put", - core_ops.aten_index_put, - input_wrangler=_index_put_input_wrangler, + "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler ) .skip( matcher=lambda sample: sample.args[0][0].dtype != torch.int64, @@ -870,20 +771,13 @@ def _where_input_wrangler( dtypes=(torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. 
https://github.com/microsoft/onnxscript/issues/854", ) - .xfail( - variant_name="tensor_overload", - dtypes=(torch.int64, torch.int32), + .skip( + matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - enabled_if=not version_utils.torch_older_than("2.2"), ), TorchLibOpInfo("log", core_ops.aten_log), TorchLibOpInfo("le", core_ops.aten_le), - TorchLibOpInfo("le_bool", core_ops.aten_le_bool), - TorchLibOpInfo( - "lerp", - core_ops.aten_lerp, - tolerance={torch.float16: (2e-3, 2e-1)}, - ), + TorchLibOpInfo("lerp", core_ops.aten_lerp, tolerance={torch.float16: (2e-3, 2e-1)}), TorchLibOpInfo("log10", core_ops.aten_log10), TorchLibOpInfo("log1p", core_ops.aten_log1p), TorchLibOpInfo( @@ -922,7 +816,6 @@ def _where_input_wrangler( TorchLibOpInfo("logdet", core_ops.aten_logdet), TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp), TorchLibOpInfo("lt", core_ops.aten_lt), - TorchLibOpInfo("lt_bool", core_ops.aten_lt_bool), TorchLibOpInfo("masked_fill", core_ops.aten_masked_fill).xfail( dtypes=(torch.bool,), reason="fixme: ORT does not have an implementation for Where with bool inputs.", @@ -938,19 +831,12 @@ def _where_input_wrangler( reason="values of matmul of [m, 0] and [0, n] matrices are undefined", ), TorchLibOpInfo("maximum", core_ops.aten_maximum), - TorchLibOpInfo("maximum_bool", core_ops.aten_maximum_bool), - TorchLibOpInfo( - "mean", - core_ops.aten_mean, - input_wrangler=_mean_input_wrangler, - ).skip( + TorchLibOpInfo("mean", core_ops.aten_mean, input_wrangler=_mean_input_wrangler).skip( matcher=lambda sample: sample.kwargs.get("dim") is not None, reason="this Aten overload only accept 1 inputs: self", ), TorchLibOpInfo( - "mean_dim", - core_ops.aten_mean_dim, - input_wrangler=_mean_input_wrangler, + "mean_dim", core_ops.aten_mean_dim, input_wrangler=_mean_input_wrangler ).skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="this Aten overload can accept 2 inputs:(self, dim)", @@ -962,15 +848,11 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - TorchLibOpInfo( - "min", - core_ops.aten_min, - ).skip( + TorchLibOpInfo("min", core_ops.aten_min).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), TorchLibOpInfo("minimum", core_ops.aten_minimum), - TorchLibOpInfo("minimum_bool", core_ops.aten_minimum_bool), TorchLibOpInfo("mm", core_ops.aten_mm).skip( matcher=lambda sample: torch.numel(sample.input) == 0, reason="values of matmul of [m, 0] and [0, n] matrices are undefined", @@ -979,39 +861,19 @@ def _where_input_wrangler( TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True), TorchLibOpInfo("mul", core_ops.aten_mul), TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True), - TorchLibOpInfo( - "mv", - core_ops.aten_mv, - tolerance={torch.float16: (3e-2, 1e-2)}, - ), + TorchLibOpInfo("mv", core_ops.aten_mv, tolerance={torch.float16: (3e-2, 1e-2)}), TorchLibOpInfo("narrow", core_ops.aten_narrow), TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout), TorchLibOpInfo("ne", core_ops.aten_ne), TorchLibOpInfo("neg", core_ops.aten_neg), + TorchLibOpInfo("new_empty", core_ops.aten_new_empty, nondeterministic=True), TorchLibOpInfo( - "new_empty", - core_ops.aten_new_empty, - nondeterministic=True, - ), - 
TorchLibOpInfo( - "new_empty_strided", - core_ops.aten_new_empty_strided, - nondeterministic=True, - ), - TorchLibOpInfo( - "new_full", - core_ops.aten_new_full, - ), - TorchLibOpInfo( - "new_ones", - core_ops.aten_new_ones, - ), - TorchLibOpInfo( - "new_zeros", - core_ops.aten_new_zeros, + "new_empty_strided", core_ops.aten_new_empty_strided, nondeterministic=True ), + TorchLibOpInfo("new_full", core_ops.aten_new_full), + TorchLibOpInfo("new_ones", core_ops.aten_new_ones), + TorchLibOpInfo("new_zeros", core_ops.aten_new_zeros), TorchLibOpInfo("nn.functional.celu", nn_ops.aten_celu), - TorchLibOpInfo("nn.functional.celu_type_promoted", nn_ops.aten_celu_type_promoted), TorchLibOpInfo( "nn.functional.cross_entropy", # use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB) @@ -1024,9 +886,7 @@ def _where_input_wrangler( reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[target] as int type", ), TorchLibOpInfo( - "nn.functional.dropout", - core_ops.aten_dropout, - input_wrangler=_dropout_input_wrangler, + "nn.functional.dropout", core_ops.aten_dropout, input_wrangler=_dropout_input_wrangler ).skip( matcher=lambda sample: len(sample.kwargs) == 0 or sample.kwargs.get("p", 0.0) > 0.0, reason="dropout is random so the result not match", @@ -1037,10 +897,7 @@ def _where_input_wrangler( core_ops.aten_embedding_bag, tolerance={torch.float32: (1e-4, 5e-4)}, compare_shape_only_for_output=(1, 2, 3), - ).skip( - dtypes=(torch.float16,), - reason="fixme: results mismatch in torch nightly.", - ), + ).skip(dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly."), TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, @@ -1075,10 +932,7 @@ def _where_input_wrangler( tolerance={torch.float16: (5e-2, 1e-2)}, ), TorchLibOpInfo("nn.functional.pad", nn_ops.aten_pad) - .skip( - variant_name="circular", - reason="fixme: ORT does not support the circular mode", - ) + .skip(variant_name="circular", reason="fixme: ORT does not support the circular mode") .skip( variant_name="replicate_negative", reason="fixme: The implementation for negative paddings is not correct", @@ -1100,10 +954,7 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.reflection_pad1d", nn_ops.aten_reflection_pad1d, - ).xfail( - dtypes=(torch.int64,), - reason="Torch not implement reflection_pad1d for int64.", - ), + ).xfail(dtypes=(torch.int64,), reason="Torch not implement reflection_pad1d for int64."), TorchLibOpInfo( "nn.functional.reflection_pad2d", nn_ops.aten_reflection_pad2d, @@ -1112,26 +963,9 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "reflect"), reason="this Aten overload need args[1] == 'reflect' for pad mode", ), - TorchLibOpInfo( - "nn.functional.relu", - nn_ops.aten_relu, - ).xfail( - dtypes=(torch.int64,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo( - "nn.functional.relu6", - nn_ops.aten_relu6, - ).xfail( - dtypes=(torch.int64,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT did not implement Relu for int64. 
https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo( - "ops.aten.replication_pad1d", - nn_ops.aten_replication_pad1d, - ), + TorchLibOpInfo("nn.functional.relu", nn_ops.aten_relu), + TorchLibOpInfo("nn.functional.relu6", nn_ops.aten_relu6), + TorchLibOpInfo("ops.aten.replication_pad1d", nn_ops.aten_replication_pad1d), TorchLibOpInfo( "nn.functional.replication_pad2d", nn_ops.aten_replication_pad2d, @@ -1141,10 +975,9 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "replicate"), reason="this Aten overload need args[1] == 'replicate' for pad mode", ) - .xfail( + .skip( variant_name="replicate_negative", - enabled_if=not version_utils.torch_older_than("2.2"), - reason="fixme: negative padding is not implemented yet", + reason="fixme: The implementation for negative paddings is not correct. Potentially an ORT issue", ), TorchLibOpInfo( "nn.functional.replication_pad3d", @@ -1160,15 +993,9 @@ def _where_input_wrangler( ), TorchLibOpInfo("nn.functional.selu", core_ops.aten_selu), TorchLibOpInfo( - "nn.functional.mse_loss", - nn_ops.aten_mse_loss, - input_wrangler=_mse_loss_input_wrangler, + "nn.functional.mse_loss", nn_ops.aten_mse_loss, input_wrangler=_mse_loss_input_wrangler ), - TorchLibOpInfo( - "nonzero", - core_ops.aten_nonzero, - input_wrangler=_nonzero_input_wrangler, - ) + TorchLibOpInfo("nonzero", core_ops.aten_nonzero, input_wrangler=_nonzero_input_wrangler) .xfail( matcher=lambda sample: sample.kwargs.get("as_tuple"), reason="as_tuple=True is not supported", @@ -1231,26 +1058,19 @@ def _where_input_wrangler( nondeterministic=True, ), TorchLibOpInfo("ops.aten.randn", core_ops.aten_randn, nondeterministic=True).xfail( - dtypes=(torch.float16,), - reason="fixme: Shape inference error", + dtypes=(torch.float16,), reason="fixme: Shape inference error" ), TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn_like, nondeterministic=True), TorchLibOpInfo("rad2deg", core_ops.aten_rad2deg), TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal), - TorchLibOpInfo( - "remainder", - core_ops.aten_remainder, - ), + TorchLibOpInfo("remainder", core_ops.aten_remainder), TorchLibOpInfo("repeat", core_ops.aten_repeat), TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int) .skip( matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int), reason=("ignore cases when repeasts is a Tensor"), ) - .skip( - dtypes=(torch.bool,), - reason="bool not supported", - ) + .skip(dtypes=(torch.bool,), reason="bool not supported") .skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="fixme: conversion not implemented if dim is None", @@ -1264,10 +1084,7 @@ def _where_input_wrangler( matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int), reason=("ignore cases when repeasts is an int"), ) - .skip( - dtypes=(torch.bool,), - reason="bool not supported", - ) + .skip(dtypes=(torch.bool,), reason="bool not supported") .skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="fixme: conversion not implemented if dim is None", @@ -1297,14 +1114,9 @@ def _where_input_wrangler( complex=True, ), TorchLibOpInfo( - "ops.aten.scalar_tensor", - core_ops.aten_scalar_tensor_complex, - complex=True, + "ops.aten.scalar_tensor", core_ops.aten_scalar_tensor_complex, complex=True ), - TorchLibOpInfo( - "scatter_add", - core_ops.aten_scatter_add, - ) + TorchLibOpInfo("scatter_add", core_ops.aten_scatter_add) .xfail( matcher=lambda sample: len(sample.input.shape) == 
0, reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch. https://github.com/onnx/onnx/issues/4986", @@ -1353,48 +1165,10 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="fixme: Tensor-likes are not close. Tests pass for float32.", ), - TorchLibOpInfo( - "split_with_sizes", - core_ops.aten_split_with_sizes, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), - TorchLibOpInfo( - "split", - core_ops.aten_split, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - variant_name="list_args", - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ) - .xfail( - variant_name="list_args", - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), + TorchLibOpInfo("split_with_sizes", core_ops.aten_split_with_sizes), + TorchLibOpInfo("split", core_ops.aten_split), TorchLibOpInfo("sqrt", core_ops.aten_sqrt), - TorchLibOpInfo( - "squeeze_dim", - core_ops.aten_squeeze_dim, - ) + TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim) .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1404,11 +1178,7 @@ def _where_input_wrangler( and sample.input.shape[sample.args[0]] != 1, reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo( - "squeeze_dim", - core_ops.aten_squeeze_dim_complex, - complex=True, - ) + TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim_complex, complex=True) .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1418,10 +1188,7 @@ def _where_input_wrangler( and sample.input.shape[sample.args[0]] != 1, reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo( - "squeeze", - core_ops.aten_squeeze, - ).skip( + TorchLibOpInfo("squeeze", core_ops.aten_squeeze).skip( matcher=lambda sample: len(sample.args) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -1430,20 +1197,14 @@ def _where_input_wrangler( TorchLibOpInfo("sub", core_ops.aten_sub, tolerance={torch.float16: (2e-3, 1e-3)}), TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB - TorchLibOpInfo( - "t", - core_ops.aten_t, - ).xfail( + TorchLibOpInfo("t", core_ops.aten_t).xfail( enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING, reason="fixme: ORT Graph attribute inferencing failed on rank-1 
input. https://github.com/onnx/onnx/issues/4986", test_class_name="TestOutputConsistencyFullGraph", ), TorchLibOpInfo("tan", core_ops.aten_tan), TorchLibOpInfo("tanh", core_ops.aten_tanh), - TorchLibOpInfo( - "tile", - core_ops.aten_tile, - ).skip( + TorchLibOpInfo("tile", core_ops.aten_tile).skip( matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) or not sample.input.shape, reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", @@ -1471,20 +1232,7 @@ def _where_input_wrangler( reason="fixme: ORT does not have an implementation of Trilu for int32.", ), TorchLibOpInfo("trunc", core_ops.aten_trunc), - TorchLibOpInfo( - "unbind", - core_ops.aten_unbind, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - enabled_if=version_utils.torch_older_than("2.7"), - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), + TorchLibOpInfo("unbind", core_ops.aten_unbind), TorchLibOpInfo("unflatten", core_ops.aten_unflatten), TorchLibOpInfo("unfold", core_ops.aten_unfold), TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold), @@ -1503,10 +1251,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("xlogy", special_ops.aten_special_xlogy), TorchLibOpInfo("zeros", core_ops.aten_zeros), - TorchLibOpInfo( - "arange_start_step", - core_ops.aten_arange_start_step, - ) + TorchLibOpInfo("arange_start_step", core_ops.aten_arange_start_step) .skip( matcher=lambda sample: len(sample.args) != 2, reason="arange_start_step overload takes three arguments (input, start, step)", @@ -1516,10 +1261,7 @@ def _where_input_wrangler( reason="dtype needs to be specified for non-float tensors", dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo( - "arange_start", - core_ops.aten_arange_start, - ) + TorchLibOpInfo("arange_start", core_ops.aten_arange_start) .skip( matcher=lambda sample: len(sample.args) != 1, reason="arange_start overload takes two arguments (input, start)", @@ -1529,10 +1271,7 @@ def _where_input_wrangler( reason="dtype needs to be specified for non-float tensors", dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo( - "arange", - core_ops.aten_arange, - ) + TorchLibOpInfo("arange", core_ops.aten_arange) .xfail( dtypes=(torch.int32,), reason="fixme: output shape mismatch in edge cases. 
https://github.com/microsoft/onnxscript/issues/974", @@ -1555,10 +1294,7 @@ def _where_input_wrangler( TorchLibOpInfo( "as_strided", core_ops.aten_as_strided, - ).xfail( - variant_name="partial_views", - reason="ONNX doesn't have partial view for tensor", - ), + ).xfail(variant_name="partial_views", reason="ONNX doesn't have partial view for tensor"), TorchLibOpInfo("clamp", core_ops.aten_clamp_tensor), TorchLibOpInfo( "ops.aten.col2im", @@ -1578,19 +1314,13 @@ def _where_input_wrangler( tolerance={torch.float32: (2e-4, 9e-4)}, ), TorchLibOpInfo("empty_like", core_ops.aten_empty_like, nondeterministic=True), - TorchLibOpInfo( - "grid_sampler_2d", - core_ops.aten_grid_sampler_2d, - ) + TorchLibOpInfo("grid_sampler_2d", core_ops.aten_grid_sampler_2d) .skip( # Torch implemented this using the cubic convolution algorithm with alhpa=-0.75, might be different than ORT matcher=lambda sample: sample.args[1] == 2, reason="fixme: 'bicubic' mode in ORT implemented differently with Torch", ) - .skip( - dtypes=(torch.float16,), - reason="fixme: Accuracy is not high enough", - ), + .skip(dtypes=(torch.float16,), reason="fixme: Accuracy is not high enough"), TorchLibOpInfo( "nn.functional.group_norm", nn_ops.aten_group_norm, @@ -1651,10 +1381,7 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - TorchLibOpInfo( - "max", - core_ops.aten_max, - ).skip( + TorchLibOpInfo("max", core_ops.aten_max).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), @@ -1712,8 +1439,7 @@ def _where_input_wrangler( reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( - "ops.aten._native_batch_norm_legit.no_stats", - core_ops.aten__native_batch_norm_no_stats, + "ops.aten._native_batch_norm_legit.no_stats", core_ops.aten__native_batch_norm_no_stats ), TorchLibOpInfo( "ops.aten._native_batch_norm_legit_functional", @@ -1734,10 +1460,6 @@ def _where_input_wrangler( "ops.aten.native_group_norm", core_ops.aten_native_group_norm, tolerance={torch.float16: (1e-2, 7e-3)}, - ).xfail( - dtypes=(torch.float16,), - reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly", - enabled_if=version_utils.torch_older_than("2.2"), ), TorchLibOpInfo( "native_layer_norm", @@ -1819,9 +1541,7 @@ def _where_input_wrangler( tolerance={torch.float16: (1e-2, 1e-3)}, ), TorchLibOpInfo( - "ops.aten.conv3d", - core_ops.aten_conv3d, - tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + "ops.aten.conv3d", core_ops.aten_conv3d, tolerance={torch.float32: (3.7e-5, 1.8e-4)} ), TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu), TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu), @@ -1902,11 +1622,6 @@ def _where_input_wrangler( nn_ops.aten_scaled_dot_product_attention, tolerance={torch.float32: (3e-4, 1.5e-5)}, ) - .skip( - matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None - and attn_mask.dtype == torch.bool, - reason="this overload takes a non-boolean mask", - ) .skip( matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, reason="dropout is random so the results do not match", @@ -1929,15 +1644,7 @@ def _where_input_wrangler( # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3, 4, 5, 6, 7, 8), - ) - .skip( - enabled_if=version_utils.torch_older_than("2.1"), - reason="The 
operator is not supported in older version.", - ) - .skip( - device_type="cpu", - reason="_scaled_dot_product_flash_attention only supports CUDA", - ), + ).skip(device_type="cpu", reason="_scaled_dot_product_flash_attention only supports CUDA"), TorchLibOpInfo( "ops.aten._scaled_dot_product_efficient_attention", nn_ops.aten__scaled_dot_product_efficient_attention, @@ -1945,40 +1652,10 @@ def _where_input_wrangler( # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3), - ) - .skip( - enabled_if=version_utils.torch_older_than("2.1"), - reason="The operator is not supported in older version.", - ) - .skip( + ).skip( enabled_if=not torch.cuda.is_available(), reason="_scaled_dot_product_efficient_attention only supports CUDA", ), - TorchLibOpInfo( - "nn.functional.scaled_dot_product_attention_bool_mask", - nn_ops.aten_scaled_dot_product_attention_bool_mask, - tolerance={torch.float32: (3e-4, 1.5e-5)}, - ) - .skip( - matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None - and attn_mask.dtype != torch.bool, - reason="this overload takes a boolean mask", - ) - .skip( - matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, - reason="dropout is random so the results do not match", - ) - .xfail( - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ) - .xfail( - matcher=lambda sample: len(sample.input.shape) != 4 - or len(sample.args[0].shape) != 4 - or len(sample.args[1].shape) != 4, - reason="torch sdpa is expected to pass in 4d q, k, and v.", - ), TorchLibOpInfo( "ops.aten.upsample_bilinear2d.default", nn_ops.aten_upsample_bilinear2d, @@ -1998,10 +1675,7 @@ def _where_input_wrangler( # Shape-only comparison is the appropriate testing approach for this case. compare_shape_only_for_output=(0,), ), - TorchLibOpInfo( - "ops.aten.upsample_bilinear2d.vec", - nn_ops.aten_upsample_bilinear2d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec), TorchLibOpInfo( "ops.aten.upsample_bicubic2d.default", nn_ops.aten_upsample_bicubic2d, @@ -2021,10 +1695,7 @@ def _where_input_wrangler( # Shape-only comparison is the appropriate testing approach for this case. 
compare_shape_only_for_output=(0,), ), - TorchLibOpInfo( - "ops.aten.upsample_bicubic2d.vec", - nn_ops.aten_upsample_bicubic2d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_bicubic2d.vec", nn_ops.aten_upsample_bicubic2d_vec), TorchLibOpInfo( "ops.aten.upsample_linear1d", nn_ops.aten_upsample_linear1d, @@ -2033,38 +1704,14 @@ def _where_input_wrangler( and sample.kwargs.get("scales") is not None, reason="fixme: align_corners=False output mismatch when scales are provided", ), - TorchLibOpInfo( - "ops.aten.upsample_nearest1d", - nn_ops.aten_upsample_nearest1d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest1d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest2d", - nn_ops.aten_upsample_nearest2d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest2d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest3d", - nn_ops.aten_upsample_nearest3d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest3d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_trilinear3d.default", - nn_ops.aten_upsample_trilinear3d, - ), - TorchLibOpInfo( - "ops.aten.upsample_trilinear3d.vec", - nn_ops.aten_upsample_trilinear3d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_nearest1d", nn_ops.aten_upsample_nearest1d), + TorchLibOpInfo("ops.aten.upsample_nearest1d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_nearest2d", nn_ops.aten_upsample_nearest2d), + TorchLibOpInfo("ops.aten.upsample_nearest2d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_nearest3d", nn_ops.aten_upsample_nearest3d), + TorchLibOpInfo("ops.aten.upsample_nearest3d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d), + TorchLibOpInfo("ops.aten.upsample_trilinear3d.vec", nn_ops.aten_upsample_trilinear3d_vec), TorchLibOpInfo("ones_like", core_ops.aten_ones_like), TorchLibOpInfo( "roll", @@ -2082,10 +1729,7 @@ def _where_input_wrangler( core_ops.aten_scatter_reduce, input_wrangler=_scatter_reduce_input_wrangler, ) - .xfail( - variant_name="mean", - reason="ONNX doesn't support reduce='mean' option", - ) + .xfail(variant_name="mean", reason="ONNX doesn't support reduce='mean' option") .xfail( variant_name="prod", dtypes=(torch.float16, torch.float64), @@ -2159,40 +1803,13 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",)) -ops_test_common.duplicate_opinfo( - OPS_DB, - "bitwise_left_shift", - ( - "bitwise_left_shift_int8", - "bitwise_left_shift_int16", - "bitwise_left_shift_int32", - "bitwise_left_shift_int64", - ), -) -ops_test_common.duplicate_opinfo( - OPS_DB, - "bitwise_right_shift", - ( - "bitwise_right_shift_int8", - "bitwise_right_shift_int16", - "bitwise_right_shift_int32", - "bitwise_right_shift_int64", - ), -) ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) -ops_test_common.duplicate_opinfo(OPS_DB, "diagonal", ("diagonal_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode", "div_mode_int")) -ops_test_common.duplicate_opinfo(OPS_DB, "ge", ("ge_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "gt", ("gt_bool",)) 
+ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "le", ("le_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "lt", ("lt_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "maximum", ("maximum_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",)) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.pad", @@ -2202,20 +1819,6 @@ def _where_input_wrangler( "nn.functional.replication_pad3d", ), ) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.scaled_dot_product_attention", - ("nn.functional.scaled_dot_product_attention_bool_mask",), -) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.celu", - ("nn.functional.celu_type_promoted",), -) -ops_test_common.duplicate_opinfo( - OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",) -) -ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)) ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",)) ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
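
For readers unfamiliar with the registration style this patch consolidates, below is a minimal standalone sketch of the fluent pattern used throughout ops_test_data.py: one TorchLibOpInfo-style entry per op, with dtype- or sample-specific behavior expressed as chained .skip()/.xfail() rules instead of separate *_int8/*_bool duplicate registrations. All names here (_OpInfoEntry, _Rule) are simplified stand-ins for illustration only, not the actual ops_test_common API.

from __future__ import annotations

import dataclasses
from typing import Any, Callable, Optional, Tuple


@dataclasses.dataclass
class _Rule:
    """One skip/xfail rule attached to a test entry (simplified stand-in)."""

    kind: str  # "skip" or "xfail"
    reason: str
    matcher: Optional[Callable[[Any], bool]] = None
    dtypes: Optional[Tuple[Any, ...]] = None


@dataclasses.dataclass
class _OpInfoEntry:
    """Maps an OpInfo test name to a torchlib implementation plus its test rules."""

    op_info_name: str
    op: Callable[..., Any]
    rules: list = dataclasses.field(default_factory=list)

    def skip(self, *, reason: str, matcher=None, dtypes=None) -> "_OpInfoEntry":
        self.rules.append(_Rule("skip", reason, matcher, dtypes))
        return self  # returning self is what enables .skip(...).xfail(...) chaining

    def xfail(self, *, reason: str, matcher=None, dtypes=None) -> "_OpInfoEntry":
        self.rules.append(_Rule("xfail", reason, matcher, dtypes))
        return self


# Usage mirroring the consolidated entries above: a single registration whose
# exceptions are attached as chained rules rather than duplicated per dtype.
entry = (
    _OpInfoEntry("split", lambda x, sizes: x)  # lambda stands in for core_ops.aten_split
    .skip(reason="example rule", matcher=lambda sample: sample is None)
    .xfail(reason="example expected failure", dtypes=("bool",))
)
assert len(entry.rules) == 2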